19#include <mlir/IR/IRMapping.h>
20#include <mlir/IR/OpImplementation.h>
22#include <llvm/ADT/MapVector.h>
23#include <llvm/ADT/STLExtras.h>
24#include <llvm/ADT/StringSet.h>
54 if (succeeded(parentFuncOpt)) {
55 FuncDefOp parentFunc = parentFuncOpt.value();
57 if (parentFunc.
getSymName().compare(funcName) == 0) {
67 assert(llvm::isa<StructDefOp>(structOp));
68 Region &bodyRegion = llvm::cast<StructDefOp>(structOp).getBodyRegion();
69 if (!bodyRegion.empty()) {
70 bodyRegion.front().walk([](
FuncDefOp funcDef) {
87 std::string prefix = std::string();
88 if (SymbolOpInterface symbol = llvm::dyn_cast<SymbolOpInterface>(origin)) {
90 prefix += symbol.getName();
93 return origin->emitOpError().append(
99static inline InFlightDiagnostic structFuncDefError(Operation *origin) {
109 SymbolTableCollection &tables,
StructDefOp expectedStruct, Type actualType, Operation *origin,
112 if (
StructType actualStructType = llvm::dyn_cast<StructType>(actualType)) {
113 auto actualStructOpt =
115 if (failed(actualStructOpt)) {
116 return origin->emitError().append(
118 actualStructType.getNameRef(),
'"'
121 StructDefOp actualStruct = actualStructOpt.value().get();
122 if (actualStruct != expectedStruct) {
124 .attachNote(actualStruct.getLoc())
125 .append(
"uses this type instead");
132 if (ArrayAttr tyParams = actualStructType.getParams()) {
133 if (failed(
verifyParamsOfType(tables, tyParams.getValue(), actualStructType, origin))) {
139 .attachNote(actualStruct.getLoc())
154 assert(succeeded(pathRes));
161 if (succeeded(pathToExpected)) {
162 ss << pathToExpected.value();
169 ss <<
'<' << attr <<
'>';
176 for (Attribute attr : params) {
177 assert(llvm::isa<FlatSymbolRefAttr>(attr));
178 if (llvm::cast<FlatSymbolRefAttr>(attr).getRootReference() == find) {
188 assert(succeeded(res));
195 llvm::StringSet<> uniqNames;
196 for (Attribute attr : params) {
197 assert(llvm::isa<FlatSymbolRefAttr>(attr));
198 StringRef name = llvm::cast<FlatSymbolRefAttr>(attr).getValue();
199 if (!uniqNames.insert(name).second) {
200 return this->emitOpError().append(
"has more than one parameter named \"@", name,
'"');
204 for (Attribute attr : params) {
206 if (succeeded(res)) {
207 return this->emitOpError()
208 .append(
"parameter name \"@")
209 .append(llvm::cast<FlatSymbolRefAttr>(attr).getValue())
210 .append(
"\" conflicts with an existing symbol")
211 .attachNote(res->get()->getLoc())
212 .append(
"symbol already defined here");
221inline LogicalResult checkMainFuncParamType(Type pType,
FuncDefOp inFunc,
bool appendSelf) {
224 }
else if (
auto arrayParamTy = llvm::dyn_cast<ArrayType>(pType)) {
232 <<
"\" function parameters must be one of: {";
240 return inFunc.emitError(message);
243inline LogicalResult verifyStructComputeConstrain(
244 StructDefOp structDef, FuncDefOp computeFunc, FuncDefOp constrainFunc
254 ArrayRef<Type> computeParams = computeFunc.
getFunctionType().getInputs();
255 ArrayRef<Type> constrainParams = constrainFunc.
getFunctionType().getInputs().drop_front();
259 return structDef.emitError().append(
266 for (Type t : computeParams) {
267 if (failed(checkMainFuncParamType(t, computeFunc,
false))) {
271 for (Type t : constrainParams) {
272 if (failed(checkMainFuncParamType(t, constrainFunc,
true))) {
279 return constrainFunc.emitError()
282 "\" function argument types (sans the first one) to match \"@",
FUNC_NAME_COMPUTE,
283 "\" function argument types"
285 .attachNote(computeFunc.getLoc())
292inline LogicalResult verifyStructProduct(
StructDefOp structDef, FuncDefOp productFunc) {
298 ArrayRef<Type> productParams = productFunc.
getFunctionType().getInputs();
301 return structDef.emitError().append(
305 for (Type t : productParams) {
306 if (failed(checkMainFuncParamType(t, productFunc,
false))) {
318 std::optional<FuncDefOp> foundCompute = std::nullopt;
319 std::optional<FuncDefOp> foundConstrain = std::nullopt;
320 std::optional<FuncDefOp> foundProduct = std::nullopt;
328 if (!bodyRegion.empty()) {
329 for (Operation &op : bodyRegion.front()) {
330 if (!llvm::isa<FieldDefOp>(op)) {
331 if (
FuncDefOp funcDef = llvm::dyn_cast<FuncDefOp>(op)) {
332 if (funcDef.nameIsCompute()) {
334 return structFuncDefError(funcDef.getOperation())
339 return structFuncDefError(funcDef.getOperation())
342 foundCompute = std::make_optional(funcDef);
343 }
else if (funcDef.nameIsConstrain()) {
345 return structFuncDefError(funcDef.getOperation())
349 if (foundConstrain) {
350 return structFuncDefError(funcDef.getOperation())
353 foundConstrain = std::make_optional(funcDef);
354 }
else if (funcDef.nameIsProduct()) {
356 return structFuncDefError(funcDef.getOperation())
360 if (foundConstrain) {
361 return structFuncDefError(funcDef.getOperation())
366 return structFuncDefError(funcDef.getOperation())
369 foundProduct = std::make_optional(funcDef);
373 return structFuncDefError(funcDef.getOperation())
374 <<
"found \"@" << funcDef.getSymName() <<
'"';
377 return op.emitOpError()
386 if (!foundCompute.has_value() && foundConstrain.has_value()) {
390 if (!foundConstrain.has_value() && foundCompute.has_value()) {
396 if (!foundCompute.has_value() && !foundConstrain.has_value() && !foundProduct.has_value()) {
397 return structFuncDefError(getOperation())
402 if (foundCompute && foundConstrain) {
403 return verifyStructComputeConstrain(*
this, *foundCompute, *foundConstrain);
405 return verifyStructProduct(*
this, *foundProduct);
409 for (Operation &op : *getBody()) {
410 if (
FieldDefOp fieldDef = llvm::dyn_cast_if_present<FieldDefOp>(op)) {
411 if (fieldName.compare(fieldDef.getSymNameAttr()) == 0) {
420 std::vector<FieldDefOp> res;
421 for (Operation &op : *getBody()) {
422 if (
FieldDefOp fieldDef = llvm::dyn_cast_if_present<FieldDefOp>(op)) {
423 res.push_back(fieldDef);
439 return llvm::dyn_cast<FuncDefOp>(computeFunc);
446 return llvm::dyn_cast<FuncDefOp>(constrainFunc);
458 OpBuilder &odsBuilder, OperationState &odsState, StringAttr sym_name, TypeAttr type,
465 props.column = odsBuilder.getUnitAttr();
470 OpBuilder &odsBuilder, OperationState &odsState, StringRef sym_name, Type type,
bool isColumn
472 build(odsBuilder, odsState, odsBuilder.getStringAttr(sym_name), TypeAttr::get(type), isColumn);
476 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, ValueRange operands,
477 ArrayRef<NamedAttribute> attributes,
bool isColumn
479 assert(operands.size() == 0u &&
"mismatched number of parameters");
480 odsState.addOperands(operands);
481 odsState.addAttributes(attributes);
482 assert(resultTypes.size() == 0u &&
"mismatched number of return types");
483 odsState.addTypes(resultTypes);
485 odsState.getOrAddProperties<
Properties>().column = odsBuilder.getUnitAttr();
491 getOperation()->setAttr(PublicAttr::name, UnitAttr::get(getContext()));
493 getOperation()->removeAttr(PublicAttr::name);
498verifyFieldDefTypeImpl(Type fieldType, SymbolTableCollection &tables, Operation *origin) {
499 if (
StructType fieldStructType = llvm::dyn_cast<StructType>(fieldType)) {
503 if (failed(fieldTypeRes)) {
507 assert(succeeded(parentRes) &&
"FieldDefOp parent is always StructDefOp");
508 if (fieldTypeRes.value() == parentRes.value()) {
509 return origin->emitOpError()
510 .append(
"type is circular")
511 .attachNote(parentRes.value().getLoc())
512 .append(
"references parent component defined here");
521 Type fieldType = this->
getType();
522 if (failed(verifyFieldDefTypeImpl(fieldType, tables, *
this))) {
531 return emitOpError() <<
"marked as column can only contain felts, arrays of column types, or "
532 "structs with columns, but field has type "
543FailureOr<SymbolLookupResult<FieldDefOp>>
545 Operation *op = refOp.getOperation();
547 if (failed(structDefRes)) {
551 tables, SymbolRefAttr::get(refOp->getContext(), refOp.
getFieldName()),
552 std::move(*structDefRes), op
559 return std::move(res.value());
562static FailureOr<SymbolLookupResult<FieldDefOp>>
570 return getFieldDefOpImpl(refOp, tables, tyStruct);
573static LogicalResult verifySymbolUsesImpl(
574 FieldRefOpInterface refOp, SymbolTableCollection &tables, SymbolLookupResult<FieldDefOp> &field
577 Type actualType = refOp.
getVal().getType();
578 Type fieldType = field.
get().getType();
580 return refOp->emitOpError() <<
"has wrong type; expected " << fieldType <<
", got "
587LogicalResult verifySymbolUsesImpl(
FieldRefOpInterface refOp, SymbolTableCollection &tables) {
589 auto field = findField(refOp, tables);
593 return verifySymbolUsesImpl(refOp, tables, *field);
598FailureOr<SymbolLookupResult<FieldDefOp>>
604 auto field = findField(*
this, tables);
608 if (failed(verifySymbolUsesImpl(*
this, tables, *field))) {
613 return emitOpError(
"cannot read with table offset from a field that is not a column")
614 .attachNote(field->
get().getLoc())
615 .append(
"field defined here");
624 if (failed(getParentRes)) {
631 return verifySymbolUsesImpl(*
this, tables);
639 OpBuilder &builder, OperationState &state, Type resultType, Value
component, StringAttr field
643 state.addTypes(resultType);
649 OpBuilder &builder, OperationState &state, Type resultType, Value
component, StringAttr field,
650 Attribute dist, ValueRange mapOperands, std::optional<int32_t> numDims
653 assert(mapOperands.empty() || numDims.has_value());
655 state.addTypes(resultType);
656 if (numDims.has_value()) {
658 builder, state, ArrayRef({mapOperands}), builder.getDenseI32ArrayAttr({*numDims})
664 props.setFieldName(FlatSymbolRefAttr::get(field));
665 props.setTableOffset(dist);
669 OpBuilder & , OperationState &odsState, TypeRange resultTypes,
670 ValueRange operands, ArrayRef<NamedAttribute> attrs
672 odsState.addTypes(resultTypes);
673 odsState.addOperands(operands);
674 odsState.addAttributes(attrs);
678 SmallVector<AffineMapAttr, 1> mapAttrs;
679 if (AffineMapAttr map =
680 llvm::dyn_cast_if_present<AffineMapAttr>(
getTableOffset().value_or(
nullptr))) {
681 mapAttrs.push_back(map);
698 if (failed(getParentRes)) {
701 if (failed(
checkSelfType(tables, *getParentRes, this->getType(), *
this,
"result"))) {
std::vector< llvm::StringRef > getIncludeSymNames() const
Return the stack of symbol names from the IncludeOp that were traversed to load this result.
static constexpr ::llvm::StringLiteral name
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
::mlir::TypedValue<::llzk::component::StructType > getResult()
FoldAdaptor::Properties Properties
void setPublicAttr(bool newValue=true)
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::StringAttr sym_name, ::mlir::TypeAttr type, bool isColumn=false)
static constexpr ::llvm::StringLiteral getOperationName()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
FoldAdaptor::Properties Properties
::mlir::OperandRangeRange getMapOperands()
::llvm::LogicalResult verify()
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType, ::mlir::Value component, ::mlir::StringAttr field)
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
::std::optional<::mlir::Attribute > getTableOffset()
::mlir::FailureOr< SymbolLookupResult< FieldDefOp > > getFieldDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the field referenced in this op.
::llzk::component::StructType getStructType()
Gets the struct type of the target component.
::llvm::StringRef getFieldName()
Gets the field name attribute value from the FieldRefOp.
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the FieldRefOp.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
static mlir::LogicalResult verifyTrait(mlir::Operation *op)
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
FieldDefOp getFieldDef(::mlir::StringAttr fieldName)
Gets the FieldDefOp that defines the field in this structure with the given name, if present.
::mlir::Region & getBodyRegion()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
static constexpr ::llvm::StringLiteral getOperationName()
::llzk::function::FuncDefOp getConstrainOrProductFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::std::vector< FieldDefOp > getFieldDefs()
Get all FieldDefOp in this structure.
::llvm::StringRef getSymName()
::mlir::SymbolRefAttr getFullyQualifiedName()
Return the full name for this struct from the root module, including any surrounding module scopes.
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::llzk::function::FuncDefOp getComputeOrProductFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
bool hasParamNamed(::mlir::StringAttr find)
Return true iff this StructDefOp has a parameter with the given name.
::llvm::LogicalResult verifyRegions()
::mlir::ArrayAttr getConstParamsAttr()
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
bool isMainComponent()
Return true iff this StructDefOp is named "Main".
::std::string getHeaderString()
Generate header string, in the same format as the assemblyFormat.
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
::mlir::LogicalResult verifySymbolRef(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op)
static constexpr ::llvm::StringLiteral name
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
::mlir::FunctionType getFunctionType()
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
bool nameIsProduct()
Return true iff the function name is FUNC_NAME_PRODUCT (if needed, a check that this FuncDefOp is loc...
::llvm::StringRef getSymName()
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
static constexpr ::llvm::StringLiteral getOperationName()
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
OpClass::Properties & buildInstantiationAttrsEmptyNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
void buildInstantiationAttrsNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
bool isInStruct(Operation *op)
InFlightDiagnostic genCompareErr(StructDefOp expected, Operation *origin, const char *aspect)
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
FailureOr< StructDefOp > verifyInStruct(Operation *op)
bool isInStructFunctionNamed(Operation *op, char const *funcName)
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
constexpr char COMPONENT_NAME_MAIN[]
Symbol name for the main entry point struct/component (if any).
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
constexpr char FUNC_NAME_CONSTRAIN[]
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
bool isNullOrEmpty(mlir::ArrayAttr a)
constexpr char FUNC_NAME_PRODUCT[]
FailureOr< StructDefOp > verifyStructTypeResolution(SymbolTableCollection &tables, StructType ty, Operation *origin)
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
LogicalResult verifyParamsOfType(SymbolTableCollection &tables, ArrayRef< Attribute > tyParams, Type parameterizedType, Operation *origin)
bool isSignalType(Type type)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
constexpr char COMPONENT_NAME_SIGNAL[]
Symbol name for the struct/component representing a signal.
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)
void setSymName(const ::mlir::StringAttr &propValue)
void setFieldName(const ::mlir::FlatSymbolRefAttr &propValue)