18#include <mlir/IR/IRMapping.h>
19#include <mlir/IR/OpImplementation.h>
21#include <llvm/ADT/MapVector.h>
22#include <llvm/ADT/StringSet.h>
50 if (succeeded(parentFuncOpt)) {
51 FuncDefOp parentFunc = parentFuncOpt.value();
53 if (parentFunc.
getSymName().compare(funcName) == 0) {
63 assert(llvm::isa<StructDefOp>(structOp));
64 llvm::cast<StructDefOp>(structOp).getBody().walk([](
FuncDefOp funcDef) {
77 std::string prefix = std::string();
78 if (SymbolOpInterface symbol = llvm::dyn_cast<SymbolOpInterface>(origin)) {
80 prefix += symbol.getName();
83 return origin->emitOpError().append(
92 SymbolTableCollection &tables,
StructDefOp &expectedStruct, Type actualType, Operation *origin,
95 if (
StructType actualStructType = llvm::dyn_cast<StructType>(actualType)) {
96 auto actualStructOpt =
98 if (failed(actualStructOpt)) {
99 return origin->emitError().append(
101 actualStructType.getNameRef(),
"\""
104 StructDefOp actualStruct = actualStructOpt.value().get();
105 if (actualStruct != expectedStruct) {
107 .attachNote(actualStruct.getLoc())
108 .append(
"uses this type instead");
115 if (ArrayAttr tyParams = actualStructType.getParams()) {
116 if (failed(
verifyParamsOfType(tables, tyParams.getValue(), actualStructType, origin))) {
122 .attachNote(actualStruct.getLoc())
136inline LogicalResult msgOneFunction(
EmitErrorFn emitError,
const Twine &name) {
137 return emitError() <<
"must define exactly one '" << name <<
"' function";
144 assert(succeeded(pathRes));
151 if (succeeded(pathToExpected)) {
152 ss << pathToExpected.value();
159 ss <<
'<' << attr <<
'>';
166 for (Attribute attr : params) {
167 assert(llvm::isa<FlatSymbolRefAttr>(attr));
168 if (llvm::cast<FlatSymbolRefAttr>(attr).getRootReference() == find) {
178 assert(succeeded(res));
185 llvm::StringSet<> uniqNames;
186 for (Attribute attr : params) {
187 assert(llvm::isa<FlatSymbolRefAttr>(attr));
188 StringRef name = llvm::cast<FlatSymbolRefAttr>(attr).getValue();
189 if (!uniqNames.insert(name).second) {
190 return this->emitOpError().append(
"has more than one parameter named \"@", name,
"\"");
194 for (Attribute attr : params) {
196 if (succeeded(res)) {
197 return this->emitOpError()
198 .append(
"parameter name \"@")
199 .append(llvm::cast<FlatSymbolRefAttr>(attr).getValue())
200 .append(
"\" conflicts with an existing symbol")
201 .attachNote(res->get()->getLoc())
202 .append(
"symbol already defined here");
211inline LogicalResult checkMainFuncParamType(Type pType,
FuncDefOp inFunc,
bool appendSelf) {
214 }
else if (
auto arrayParamTy = llvm::dyn_cast<ArrayType>(pType)) {
222 <<
"\" function parameters must be one of: {";
230 return inFunc.emitError(message);
236 assert(
getBody().hasOneBlock());
237 std::optional<FuncDefOp> foundCompute = std::nullopt;
238 std::optional<FuncDefOp> foundConstrain = std::nullopt;
244 for (Operation &op :
getBody().front()) {
245 if (!llvm::isa<FieldDefOp>(op)) {
246 if (
FuncDefOp funcDef = llvm::dyn_cast<FuncDefOp>(op)) {
247 if (funcDef.nameIsCompute()) {
251 foundCompute = std::make_optional(funcDef);
252 }
else if (funcDef.nameIsConstrain()) {
253 if (foundConstrain) {
256 foundConstrain = std::make_optional(funcDef);
260 return op.emitError() <<
"'" <<
getOperationName() <<
"' op " <<
"must define only \"@"
262 <<
"\" functions;" <<
" found \"@" << funcDef.getSymName()
269 <<
"' operations are permitted";
273 if (!foundCompute.has_value()) {
276 if (!foundConstrain.has_value()) {
282 assert(foundConstrain->hasAllowConstraintAttr());
283 assert(!foundCompute->hasAllowConstraintAttr());
284 assert(!foundConstrain->hasAllowWitnessAttr());
285 assert(foundCompute->hasAllowWitnessAttr());
289 ArrayRef<Type> computeParams = foundCompute->getFunctionType().getInputs();
290 ArrayRef<Type> constrainParams = foundConstrain->getFunctionType().getInputs().drop_front();
294 return this->emitError().append(
301 for (Type t : computeParams) {
302 if (failed(checkMainFuncParamType(t, *foundCompute,
false))) {
306 for (Type t : constrainParams) {
307 if (failed(checkMainFuncParamType(t, *foundConstrain,
true))) {
315 return foundConstrain->emitError()
318 "\" function argument types (sans the first one) to match \"@",
FUNC_NAME_COMPUTE,
319 "\" function argument types"
321 .attachNote(foundCompute->getLoc())
329 assert(
getBody().hasOneBlock());
331 for (Operation &op :
getBody().front()) {
332 if (
FieldDefOp fieldDef = llvm::dyn_cast_if_present<FieldDefOp>(op)) {
333 if (fieldName.compare(fieldDef.getSymNameAttr()) == 0) {
342 assert(
getBody().hasOneBlock());
344 std::vector<FieldDefOp> res;
345 for (Operation &op :
getBody().front()) {
346 if (
FieldDefOp fieldDef = llvm::dyn_cast_if_present<FieldDefOp>(op)) {
347 res.push_back(fieldDef);
368 OpBuilder &odsBuilder, OperationState &odsState, StringAttr sym_name, TypeAttr type,
375 props.column = odsBuilder.getUnitAttr();
380 OpBuilder &odsBuilder, OperationState &odsState, StringRef sym_name, Type type,
bool isColumn
382 build(odsBuilder, odsState, odsBuilder.getStringAttr(sym_name), TypeAttr::get(type), isColumn);
386 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, ValueRange operands,
387 ArrayRef<NamedAttribute> attributes,
bool isColumn
389 assert(operands.size() == 0u &&
"mismatched number of parameters");
390 odsState.addOperands(operands);
391 odsState.addAttributes(attributes);
392 assert(resultTypes.size() == 0u &&
"mismatched number of return types");
393 odsState.addTypes(resultTypes);
395 odsState.getOrAddProperties<
Properties>().column = odsBuilder.getUnitAttr();
401 getOperation()->setAttr(PublicAttr::name, UnitAttr::get(getContext()));
403 getOperation()->removeAttr(PublicAttr::name);
408verifyFieldDefTypeImpl(Type fieldType, SymbolTableCollection &tables, Operation *origin) {
409 if (
StructType fieldStructType = llvm::dyn_cast<StructType>(fieldType)) {
413 if (failed(fieldTypeRes)) {
417 assert(succeeded(parentRes) &&
"FieldDefOp parent is always StructDefOp");
418 if (fieldTypeRes.value() == parentRes.value()) {
419 return origin->emitOpError()
420 .append(
"type is circular")
421 .attachNote(parentRes.value().getLoc())
422 .append(
"references parent component defined here");
431 Type fieldType = this->
getType();
432 if (failed(verifyFieldDefTypeImpl(fieldType, tables, *
this))) {
441 return emitOpError() <<
"marked as column can only contain felts, arrays of column types, or "
442 "structs with columns, but field has type "
453FailureOr<SymbolLookupResult<FieldDefOp>>
455 Operation *op = refOp.getOperation();
457 if (failed(structDefRes)) {
461 tables, SymbolRefAttr::get(refOp->getContext(), refOp.
getFieldName()),
462 std::move(*structDefRes), op
469 return std::move(res.value());
472static FailureOr<SymbolLookupResult<FieldDefOp>>
480 return getFieldDefOpImpl(refOp, tables, tyStruct);
483static LogicalResult verifySymbolUsesImpl(
484 FieldRefOpInterface refOp, SymbolTableCollection &tables, SymbolLookupResult<FieldDefOp> &field
487 Type actualType = refOp.
getVal().getType();
488 Type fieldType = field.
get().getType();
490 return refOp->emitOpError() <<
"has wrong type; expected " << fieldType <<
", got "
497LogicalResult verifySymbolUsesImpl(
FieldRefOpInterface refOp, SymbolTableCollection &tables) {
499 auto field = findField(refOp, tables);
503 return verifySymbolUsesImpl(refOp, tables, *field);
508FailureOr<SymbolLookupResult<FieldDefOp>>
514 auto field = findField(*
this, tables);
518 if (failed(verifySymbolUsesImpl(*
this, tables, *field))) {
523 return emitOpError(
"cannot read with table offset from a field that is not a column")
524 .attachNote(field->
get().getLoc())
525 .append(
"field defined here");
534 if (failed(getParentRes)) {
541 return verifySymbolUsesImpl(*
this, tables);
549 OpBuilder &builder, OperationState &state, Type resultType, Value
component, StringAttr field
553 state.addTypes(resultType);
559 OpBuilder &builder, OperationState &state, Type resultType, Value
component, StringAttr field,
560 Attribute dist, ValueRange mapOperands, std::optional<int32_t> numDims
562 assert(numDims.has_value() != mapOperands.empty());
564 state.addTypes(resultType);
565 if (numDims.has_value()) {
567 builder, state, ArrayRef({mapOperands}), builder.getDenseI32ArrayAttr({*numDims})
573 props.setFieldName(FlatSymbolRefAttr::get(field));
574 props.setTableOffset(dist);
578 OpBuilder &builder, OperationState &state, TypeRange resultTypes, ValueRange operands,
579 ArrayRef<NamedAttribute> attrs
581 state.addTypes(resultTypes);
582 state.addOperands(operands);
583 state.addAttributes(attrs);
587 SmallVector<AffineMapAttr, 1> mapAttrs;
588 if (AffineMapAttr map =
589 llvm::dyn_cast_if_present<AffineMapAttr>(
getTableOffset().value_or(
nullptr))) {
590 mapAttrs.push_back(map);
607 if (failed(getParentRes)) {
610 if (failed(
checkSelfType(tables, *getParentRes, this->getType(), *
this,
"result"))) {
std::vector< llvm::StringRef > getIncludeSymNames()
static constexpr ::llvm::StringLiteral name
::mlir::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
::mlir::TypedValue<::llzk::component::StructType > getResult()
FoldAdaptor::Properties Properties
void setPublicAttr(bool newValue=true)
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::StringAttr sym_name, ::mlir::TypeAttr type, bool isColumn=false)
static constexpr ::llvm::StringLiteral getOperationName()
::mlir::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
FoldAdaptor::Properties Properties
::mlir::LogicalResult verify()
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType, ::mlir::Value component, ::mlir::StringAttr field)
::mlir::OperandRangeRange getMapOperands()
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
::mlir::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
::std::optional<::mlir::Attribute > getTableOffset()
::mlir::FailureOr< SymbolLookupResult< FieldDefOp > > getFieldDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the field referenced in this op.
::llzk::component::StructType getStructType()
Gets the struct type of the target component.
::llvm::StringRef getFieldName()
Gets the field name attribute value from the FieldRefOp.
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the FieldRefOp.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
::mlir::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
static mlir::LogicalResult verifyTrait(mlir::Operation *op)
::mlir::Region & getBody()
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
FieldDefOp getFieldDef(::mlir::StringAttr fieldName)
Gets the FieldDefOp that defines the field in this structure with the given name, if present.
::mlir::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
::mlir::ArrayAttr getConstParamsAttr()
static constexpr ::llvm::StringLiteral getOperationName()
::std::vector< FieldDefOp > getFieldDefs()
Get all FieldDefOp in this structure.
::llvm::StringRef getSymName()
::mlir::SymbolRefAttr getFullyQualifiedName()
Return the full name for this struct from the root module, including any surrounding module scopes.
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
bool hasParamNamed(::mlir::StringAttr find)
Return true iff this StructDefOp has a parameter with the given name.
::mlir::LogicalResult verifyRegions()
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present.
bool isMainComponent()
Return true iff this StructDefOp is named "Main".
::std::string getHeaderString()
Generate header string, in the same format as the assemblyFormat.
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
::mlir::LogicalResult verifySymbolRef(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op)
static constexpr ::llvm::StringLiteral name
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
::llvm::StringRef getSymName()
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
static constexpr ::llvm::StringLiteral getOperationName()
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
OpClass::Properties & buildInstantiationAttrsEmptyNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
void buildInstantiationAttrsNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
InFlightDiagnostic genCompareErr(StructDefOp &expected, Operation *origin, const char *aspect)
bool isInStruct(Operation *op)
FailureOr< StructDefOp > verifyInStruct(Operation *op)
bool isInStructFunctionNamed(Operation *op, char const *funcName)
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp &expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e.
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
FailureOr< SymbolRefAttr > getPathFromRoot(StructDefOp &to)
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
constexpr char COMPONENT_NAME_MAIN[]
Symbol name for the main entry point struct/component (if any).
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
constexpr char FUNC_NAME_CONSTRAIN[]
llvm::function_ref< mlir::InFlightDiagnostic()> EmitErrorFn
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
bool isNullOrEmpty(mlir::ArrayAttr a)
std::function< mlir::InFlightDiagnostic()> OwningEmitErrorFn
FailureOr< StructDefOp > verifyStructTypeResolution(SymbolTableCollection &tables, StructType ty, Operation *origin)
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
LogicalResult verifyParamsOfType(SymbolTableCollection &tables, ArrayRef< Attribute > tyParams, Type parameterizedType, Operation *origin)
bool isSignalType(Type type)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
constexpr char COMPONENT_NAME_SIGNAL[]
Symbol name for the struct/component representing a signal.
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
void setSymName(const ::mlir::StringAttr &propValue)
void setFieldName(const ::mlir::FlatSymbolRefAttr &propValue)