19#include <mlir/IR/IRMapping.h>
20#include <mlir/IR/OpImplementation.h>
22#include <llvm/ADT/MapVector.h>
23#include <llvm/ADT/STLExtras.h>
24#include <llvm/ADT/StringSet.h>
54 if (succeeded(parentFuncOpt)) {
55 FuncDefOp parentFunc = parentFuncOpt.value();
57 if (parentFunc.
getSymName().compare(funcName) == 0) {
67 assert(llvm::isa<StructDefOp>(structOp));
68 llvm::cast<StructDefOp>(structOp).getBody()->walk([](
FuncDefOp funcDef) {
84 std::string prefix = std::string();
85 if (SymbolOpInterface symbol = llvm::dyn_cast<SymbolOpInterface>(origin)) {
87 prefix += symbol.getName();
90 return origin->emitOpError().append(
96static inline InFlightDiagnostic structFuncDefError(Operation *origin) {
106 SymbolTableCollection &tables,
StructDefOp expectedStruct, Type actualType, Operation *origin,
109 if (
StructType actualStructType = llvm::dyn_cast<StructType>(actualType)) {
110 auto actualStructOpt =
112 if (failed(actualStructOpt)) {
113 return origin->emitError().append(
115 actualStructType.getNameRef(),
'"'
118 StructDefOp actualStruct = actualStructOpt.value().get();
119 if (actualStruct != expectedStruct) {
121 .attachNote(actualStruct.getLoc())
122 .append(
"uses this type instead");
129 if (ArrayAttr tyParams = actualStructType.getParams()) {
130 if (failed(
verifyParamsOfType(tables, tyParams.getValue(), actualStructType, origin))) {
136 .attachNote(actualStruct.getLoc())
151 assert(succeeded(pathRes));
158 if (succeeded(pathToExpected)) {
159 ss << pathToExpected.value();
166 ss <<
'<' << attr <<
'>';
173 for (Attribute attr : params) {
174 assert(llvm::isa<FlatSymbolRefAttr>(attr));
175 if (llvm::cast<FlatSymbolRefAttr>(attr).getRootReference() == find) {
185 assert(succeeded(res));
192 llvm::StringSet<> uniqNames;
193 for (Attribute attr : params) {
194 assert(llvm::isa<FlatSymbolRefAttr>(attr));
195 StringRef
name = llvm::cast<FlatSymbolRefAttr>(attr).getValue();
196 if (!uniqNames.insert(
name).second) {
197 return this->emitOpError().append(
"has more than one parameter named \"@",
name,
'"');
201 for (Attribute attr : params) {
203 if (succeeded(res)) {
204 return this->emitOpError()
205 .append(
"parameter name \"@")
206 .append(llvm::cast<FlatSymbolRefAttr>(attr).getValue())
207 .append(
"\" conflicts with an existing symbol")
208 .attachNote(res->get()->getLoc())
209 .append(
"symbol already defined here");
218inline LogicalResult checkMainFuncParamType(Type pType,
FuncDefOp inFunc,
bool appendSelf) {
221 }
else if (
auto arrayParamTy = llvm::dyn_cast<ArrayType>(pType)) {
229 <<
"\" function parameters must be one of: {";
237 return inFunc.emitError(message);
240inline LogicalResult verifyStructComputeConstrain(
241 StructDefOp structDef, FuncDefOp computeFunc, FuncDefOp constrainFunc
251 ArrayRef<Type> computeParams = computeFunc.
getFunctionType().getInputs();
252 ArrayRef<Type> constrainParams = constrainFunc.
getFunctionType().getInputs().drop_front();
256 return structDef.emitError().append(
263 for (Type t : computeParams) {
264 if (failed(checkMainFuncParamType(t, computeFunc,
false))) {
268 for (Type t : constrainParams) {
269 if (failed(checkMainFuncParamType(t, constrainFunc,
true))) {
276 return constrainFunc.emitError()
279 "\" function argument types (sans the first one) to match \"@",
FUNC_NAME_COMPUTE,
280 "\" function argument types"
282 .attachNote(computeFunc.getLoc())
289inline LogicalResult verifyStructProduct(
StructDefOp structDef, FuncDefOp productFunc) {
295 ArrayRef<Type> productParams = productFunc.
getFunctionType().getInputs();
298 return structDef.emitError().append(
302 for (Type t : productParams) {
303 if (failed(checkMainFuncParamType(t, productFunc,
false))) {
315 std::optional<FuncDefOp> foundCompute = std::nullopt;
316 std::optional<FuncDefOp> foundConstrain = std::nullopt;
317 std::optional<FuncDefOp> foundProduct = std::nullopt;
324 for (Operation &op : *getBody()) {
325 if (!llvm::isa<FieldDefOp>(op)) {
326 if (
FuncDefOp funcDef = llvm::dyn_cast<FuncDefOp>(op)) {
327 if (funcDef.nameIsCompute()) {
329 return structFuncDefError(funcDef.getOperation())
334 return structFuncDefError(funcDef.getOperation())
337 foundCompute = std::make_optional(funcDef);
338 }
else if (funcDef.nameIsConstrain()) {
340 return structFuncDefError(funcDef.getOperation())
344 if (foundConstrain) {
345 return structFuncDefError(funcDef.getOperation())
348 foundConstrain = std::make_optional(funcDef);
349 }
else if (funcDef.nameIsProduct()) {
351 return structFuncDefError(funcDef.getOperation())
355 if (foundConstrain) {
356 return structFuncDefError(funcDef.getOperation())
361 return structFuncDefError(funcDef.getOperation())
364 foundProduct = std::make_optional(funcDef);
368 return structFuncDefError(funcDef.getOperation())
369 <<
"found \"@" << funcDef.getSymName() <<
'"';
375 <<
"' operations are permitted";
380 if (!foundCompute.has_value() && foundConstrain.has_value()) {
384 if (!foundConstrain.has_value() && foundCompute.has_value()) {
390 if (!foundCompute.has_value() && !foundConstrain.has_value() && !foundProduct.has_value()) {
391 return structFuncDefError(getOperation())
396 if (foundCompute && foundConstrain) {
397 return verifyStructComputeConstrain(*
this, *foundCompute, *foundConstrain);
399 return verifyStructProduct(*
this, *foundProduct);
403 for (Operation &op : *getBody()) {
404 if (
FieldDefOp fieldDef = llvm::dyn_cast_if_present<FieldDefOp>(op)) {
405 if (fieldName.compare(fieldDef.getSymNameAttr()) == 0) {
414 std::vector<FieldDefOp> res;
415 for (Operation &op : *getBody()) {
416 if (
FieldDefOp fieldDef = llvm::dyn_cast_if_present<FieldDefOp>(op)) {
417 res.push_back(fieldDef);
433 return llvm::dyn_cast<FuncDefOp>(computeFunc);
440 return llvm::dyn_cast<FuncDefOp>(constrainFunc);
452 OpBuilder &odsBuilder, OperationState &odsState, StringAttr sym_name, TypeAttr type,
459 props.column = odsBuilder.getUnitAttr();
464 OpBuilder &odsBuilder, OperationState &odsState, StringRef sym_name, Type type,
bool isColumn
466 build(odsBuilder, odsState, odsBuilder.getStringAttr(sym_name), TypeAttr::get(type), isColumn);
470 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, ValueRange operands,
471 ArrayRef<NamedAttribute> attributes,
bool isColumn
473 assert(operands.size() == 0u &&
"mismatched number of parameters");
474 odsState.addOperands(operands);
475 odsState.addAttributes(attributes);
476 assert(resultTypes.size() == 0u &&
"mismatched number of return types");
477 odsState.addTypes(resultTypes);
479 odsState.getOrAddProperties<
Properties>().column = odsBuilder.getUnitAttr();
485 getOperation()->setAttr(PublicAttr::name, UnitAttr::get(getContext()));
487 getOperation()->removeAttr(PublicAttr::name);
492verifyFieldDefTypeImpl(Type fieldType, SymbolTableCollection &tables, Operation *origin) {
493 if (
StructType fieldStructType = llvm::dyn_cast<StructType>(fieldType)) {
497 if (failed(fieldTypeRes)) {
501 assert(succeeded(parentRes) &&
"FieldDefOp parent is always StructDefOp");
502 if (fieldTypeRes.value() == parentRes.value()) {
503 return origin->emitOpError()
504 .append(
"type is circular")
505 .attachNote(parentRes.value().getLoc())
506 .append(
"references parent component defined here");
515 Type fieldType = this->
getType();
516 if (failed(verifyFieldDefTypeImpl(fieldType, tables, *
this))) {
525 return emitOpError() <<
"marked as column can only contain felts, arrays of column types, or "
526 "structs with columns, but field has type "
537FailureOr<SymbolLookupResult<FieldDefOp>>
539 Operation *op = refOp.getOperation();
541 if (failed(structDefRes)) {
545 tables, SymbolRefAttr::get(refOp->getContext(), refOp.
getFieldName()),
546 std::move(*structDefRes), op
553 return std::move(res.value());
556static FailureOr<SymbolLookupResult<FieldDefOp>>
564 return getFieldDefOpImpl(refOp, tables, tyStruct);
567static LogicalResult verifySymbolUsesImpl(
568 FieldRefOpInterface refOp, SymbolTableCollection &tables, SymbolLookupResult<FieldDefOp> &field
571 Type actualType = refOp.
getVal().getType();
572 Type fieldType = field.
get().getType();
574 return refOp->emitOpError() <<
"has wrong type; expected " << fieldType <<
", got "
581LogicalResult verifySymbolUsesImpl(
FieldRefOpInterface refOp, SymbolTableCollection &tables) {
583 auto field = findField(refOp, tables);
587 return verifySymbolUsesImpl(refOp, tables, *field);
592FailureOr<SymbolLookupResult<FieldDefOp>>
598 auto field = findField(*
this, tables);
602 if (failed(verifySymbolUsesImpl(*
this, tables, *field))) {
607 return emitOpError(
"cannot read with table offset from a field that is not a column")
608 .attachNote(field->
get().getLoc())
609 .append(
"field defined here");
618 if (failed(getParentRes)) {
625 return verifySymbolUsesImpl(*
this, tables);
633 OpBuilder &builder, OperationState &state, Type resultType, Value
component, StringAttr field
637 state.addTypes(resultType);
643 OpBuilder &builder, OperationState &state, Type resultType, Value
component, StringAttr field,
644 Attribute dist, ValueRange mapOperands, std::optional<int32_t> numDims
647 assert(mapOperands.empty() || numDims.has_value());
649 state.addTypes(resultType);
650 if (numDims.has_value()) {
652 builder, state, ArrayRef({mapOperands}), builder.getDenseI32ArrayAttr({*numDims})
658 props.setFieldName(FlatSymbolRefAttr::get(field));
659 props.setTableOffset(dist);
663 OpBuilder & , OperationState &odsState, TypeRange resultTypes,
664 ValueRange operands, ArrayRef<NamedAttribute> attrs
666 odsState.addTypes(resultTypes);
667 odsState.addOperands(operands);
668 odsState.addAttributes(attrs);
672 SmallVector<AffineMapAttr, 1> mapAttrs;
673 if (AffineMapAttr map =
674 llvm::dyn_cast_if_present<AffineMapAttr>(
getTableOffset().value_or(
nullptr))) {
675 mapAttrs.push_back(map);
692 if (failed(getParentRes)) {
695 if (failed(
checkSelfType(tables, *getParentRes, this->getType(), *
this,
"result"))) {
std::vector< llvm::StringRef > getIncludeSymNames() const
Return the stack of symbol names from the IncludeOp that were traversed to load this result.
static constexpr ::llvm::StringLiteral name
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
::mlir::TypedValue<::llzk::component::StructType > getResult()
FoldAdaptor::Properties Properties
void setPublicAttr(bool newValue=true)
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::StringAttr sym_name, ::mlir::TypeAttr type, bool isColumn=false)
static constexpr ::llvm::StringLiteral getOperationName()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
FoldAdaptor::Properties Properties
::mlir::OperandRangeRange getMapOperands()
::llvm::LogicalResult verify()
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType, ::mlir::Value component, ::mlir::StringAttr field)
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
::std::optional<::mlir::Attribute > getTableOffset()
::mlir::FailureOr< SymbolLookupResult< FieldDefOp > > getFieldDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the field referenced in this op.
::llzk::component::StructType getStructType()
Gets the struct type of the target component.
::llvm::StringRef getFieldName()
Gets the field name attribute value from the FieldRefOp.
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the FieldRefOp.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
static mlir::LogicalResult verifyTrait(mlir::Operation *op)
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
FieldDefOp getFieldDef(::mlir::StringAttr fieldName)
Gets the FieldDefOp that defines the field in this structure with the given name, if present.
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
static constexpr ::llvm::StringLiteral getOperationName()
::llzk::function::FuncDefOp getConstrainOrProductFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::std::vector< FieldDefOp > getFieldDefs()
Get all FieldDefOp in this structure.
::llvm::StringRef getSymName()
::mlir::SymbolRefAttr getFullyQualifiedName()
Return the full name for this struct from the root module, including any surrounding module scopes.
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::llzk::function::FuncDefOp getComputeOrProductFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
bool hasParamNamed(::mlir::StringAttr find)
Return true iff this StructDefOp has a parameter with the given name.
::llvm::LogicalResult verifyRegions()
::mlir::ArrayAttr getConstParamsAttr()
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
bool isMainComponent()
Return true iff this StructDefOp is named "Main".
::std::string getHeaderString()
Generate header string, in the same format as the assemblyFormat.
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
::mlir::LogicalResult verifySymbolRef(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op)
static constexpr ::llvm::StringLiteral name
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
::mlir::FunctionType getFunctionType()
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
bool nameIsProduct()
Return true iff the function name is FUNC_NAME_PRODUCT (if needed, a check that this FuncDefOp is loc...
::llvm::StringRef getSymName()
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
static constexpr ::llvm::StringLiteral getOperationName()
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
OpClass::Properties & buildInstantiationAttrsEmptyNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
void buildInstantiationAttrsNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
bool isInStruct(Operation *op)
InFlightDiagnostic genCompareErr(StructDefOp expected, Operation *origin, const char *aspect)
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
FailureOr< StructDefOp > verifyInStruct(Operation *op)
bool isInStructFunctionNamed(Operation *op, char const *funcName)
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
constexpr char COMPONENT_NAME_MAIN[]
Symbol name for the main entry point struct/component (if any).
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
constexpr char FUNC_NAME_CONSTRAIN[]
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
bool isNullOrEmpty(mlir::ArrayAttr a)
constexpr char FUNC_NAME_PRODUCT[]
FailureOr< StructDefOp > verifyStructTypeResolution(SymbolTableCollection &tables, StructType ty, Operation *origin)
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
LogicalResult verifyParamsOfType(SymbolTableCollection &tables, ArrayRef< Attribute > tyParams, Type parameterizedType, Operation *origin)
bool isSignalType(Type type)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
constexpr char COMPONENT_NAME_SIGNAL[]
Symbol name for the struct/component representing a signal.
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)
void setSymName(const ::mlir::StringAttr &propValue)
void setFieldName(const ::mlir::FlatSymbolRefAttr &propValue)