LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Global value operation implementations --------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
17
18// TableGen'd implementation files
20
21// TableGen'd implementation files
22#define GET_OP_CLASSES
24
25using namespace mlir;
26using namespace llzk::array;
27using namespace llzk::felt;
28using namespace llzk::string;
29
30namespace llzk::global {
31
32//===------------------------------------------------------------------===//
33// GlobalDefOp
34//===------------------------------------------------------------------===//
35
// Parser hook for the optional "= <value>" initializer that may follow a global's type.
// NOTE(review): the line holding the return type and function name (listing line 36) is
// missing from this extract; the cross-reference index at the end of the listing names the
// function defined there `parseGlobalInitialValue` -- confirm against the real file.
//
// `parser`       : assembly parser positioned where the optional '=' may appear.
// `initialValue` : out-parameter; assigned only when an initializer is actually parsed.
// `typeAttr`     : the declared type of the global; selects which attribute syntax is used.
// Returns success() when no '=' is present or when the initializer parses; failure() otherwise.
37 OpAsmParser &parser, Attribute &initialValue, TypeAttr typeAttr
38) {
// A ParseResult converts to `true` on failure, so this branch is taken when '=' is absent.
39 if (parser.parseOptionalEqual()) {
40 // When there's no equal sign, there's no initial value to parse.
41 return success();
42 }
43 Type specifiedType = typeAttr.getValue();
44
45 // Special case for parsing LLZK FeltType to match format of FeltConstantOp.
46 // Not actually necessary but the default format is verbose. ex: "#llzk<felt.const 35>"
47 if (isa<FeltType>(specifiedType)) {
48 FeltConstAttr feltConstAttr;
49 if (parser.parseCustomAttributeWithFallback<FeltConstAttr>(feltConstAttr)) {
50 return failure();
51 }
52 initialValue = feltConstAttr;
53 return success();
54 }
55 // Fallback to default parser for all other types.
// The declared type is passed along so the attribute is parsed against `specifiedType`.
56 if (failed(parser.parseAttribute(initialValue, specifiedType))) {
57 return failure();
58 }
59 return success();
60}
61
// Printer hook that emits the optional " = <value>" initializer of a global.
// NOTE(review): the line holding the return type and function name (listing line 62) is
// missing from this extract; the cross-reference index at the end of the listing names the
// function defined there `printGlobalInitialValue` -- confirm against the real file.
//
// `p`            : assembly printer receiving the output.
// (unnamed op)   : the GlobalDefOp being printed; not used by this body.
// `initialValue` : the initializer attribute; nothing is printed when it is null.
// `typeAttr`     : the declared type of the global; not read by this body.
63 OpAsmPrinter &p, GlobalDefOp /*op*/, Attribute initialValue, TypeAttr typeAttr
64) {
65 if (initialValue) {
66 p << " = ";
67 // Special case for LLZK FeltType to match format of FeltConstantOp.
68 // Not actually necessary but the default format is verbose. ex: "#llzk<felt.const 35>"
69 if (FeltConstAttr feltConstAttr = llvm::dyn_cast<FeltConstAttr>(initialValue)) {
70 p.printStrippedAttrOrType<FeltConstAttr>(feltConstAttr);
71 } else {
// Every other attribute kind is printed without a trailing type annotation.
72 p.printAttributeWithoutType(initialValue);
73 }
74 }
75}
76
77LogicalResult GlobalDefOp::verifySymbolUses(SymbolTableCollection &tables) {
78 // Ensure any SymbolRef used in the type are valid
79 return verifyTypeResolution(tables, *this, getType());
80}
81
82namespace {
83
84inline InFlightDiagnosticWrapper reportMismatch(
85 EmitErrorFn errFn, Type rootType, const Twine &aspect, const Twine &expected, const Twine &found
86) {
87 return errFn().append(
88 "with type ", rootType, " expected ", expected, " ", aspect, " but found ", found
89 );
90}
91
92inline InFlightDiagnosticWrapper reportMismatch(
93 EmitErrorFn errFn, Type rootType, const Twine &aspect, const Twine &expected, Attribute found
94) {
95 return reportMismatch(errFn, rootType, aspect, expected, found.getAbstractAttribute().getName());
96}
97
98LogicalResult ensureAttrTypeMatch(
99 Type type, Attribute valAttr, const OwningEmitErrorFn &errFn, Type rootType, const Twine &aspect
100) {
101 if (!isValidGlobalType(type)) {
102 // Same error message ODS-generated code would produce
103 return errFn().append(
104 "attribute 'type' failed to satisfy constraint: type attribute of "
105 "any LLZK type except non-constant types"
106 );
107 }
108 if (type.isSignlessInteger(1)) {
109 if (IntegerAttr ia = llvm::dyn_cast<IntegerAttr>(valAttr)) {
110 APInt val = ia.getValue();
111 if (!val.isZero() && !val.isOne()) {
112 return errFn().append("integer constant out of range for attribute");
113 }
114 } else if (!llvm::isa<BoolAttr>(valAttr)) {
115 return reportMismatch(errFn, rootType, aspect, "builtin.bool or builtin.integer", valAttr);
116 }
117 } else if (llvm::isa<IndexType>(type)) {
118 // The explicit check for BoolAttr is needed because the LLVM isa/cast functions treat
119 // BoolAttr as a subtype of IntegerAttr but this scenario should not allow BoolAttr.
120 bool isBool = llvm::isa<BoolAttr>(valAttr);
121 if (isBool || !llvm::isa<IntegerAttr>(valAttr)) {
122 return reportMismatch(
123 errFn, rootType, aspect, "builtin.index",
124 isBool ? "builtin.bool" : valAttr.getAbstractAttribute().getName()
125 );
126 }
127 } else if (llvm::isa<FeltType>(type)) {
128 if (!llvm::isa<FeltConstAttr, IntegerAttr>(valAttr)) {
129 return reportMismatch(errFn, rootType, aspect, "felt.type", valAttr);
130 }
131 } else if (llvm::isa<StringType>(type)) {
132 if (!llvm::isa<StringAttr>(valAttr)) {
133 return reportMismatch(errFn, rootType, aspect, "builtin.string", valAttr);
134 }
135 } else if (ArrayType arrTy = llvm::dyn_cast<ArrayType>(type)) {
136 if (ArrayAttr arrVal = llvm::dyn_cast<ArrayAttr>(valAttr)) {
137 // Ensure the number of elements is correct for the ArrayType
138 assert(arrTy.hasStaticShape() && "implied by earlier isValidGlobalType() check");
139 int64_t expectedCount = arrTy.getNumElements();
140 size_t actualCount = arrVal.size();
141 if (std::cmp_not_equal(actualCount, expectedCount)) {
142 return reportMismatch(
143 errFn, rootType, Twine(aspect) + " to contain " + Twine(expectedCount) + " elements",
144 "builtin.array", Twine(actualCount)
145 );
146 }
147 // Ensure the type of each element is correct for the ArrayType.
148 // Rather than immediately returning on failure, check all elements and aggregate to provide
149 // as many errors are possible in a single verifier run.
150 bool hasFailure = false;
151 Type expectedElemTy = arrTy.getElementType();
152 for (Attribute e : arrVal.getValue()) {
153 hasFailure |=
154 failed(ensureAttrTypeMatch(expectedElemTy, e, errFn, rootType, "array element"));
155 }
156 if (hasFailure) {
157 return failure();
158 }
159 } else {
160 return reportMismatch(errFn, rootType, aspect, "builtin.array", valAttr);
161 }
162 } else {
163 return errFn().append("expected a valid LLZK type but found ", type);
164 }
165 return success();
166}
167
168} // namespace
169
170LogicalResult GlobalDefOp::verify() {
171 if (Attribute initValAttr = getInitialValueAttr()) {
172 Type ty = getType();
173 OwningEmitErrorFn errFn = getEmitOpErrFn(this);
174 return ensureAttrTypeMatch(ty, initValAttr, errFn, ty, "attribute value");
175 }
176 // If there is no initial value, it cannot have "const".
177 if (isConstant()) {
178 return emitOpError("marked as 'const' must be assigned a value");
179 }
180 return success();
181}
182
183//===------------------------------------------------------------------===//
184// GlobalReadOp / GlobalWriteOp
185//===------------------------------------------------------------------===//
186
187FailureOr<SymbolLookupResult<GlobalDefOp>>
188GlobalRefOpInterface::getGlobalDefOp(SymbolTableCollection &tables) {
189 return lookupTopLevelSymbol<GlobalDefOp>(tables, getNameRef(), getOperation());
190}
191
192namespace {
193
194FailureOr<SymbolLookupResult<GlobalDefOp>>
195verifySymbolUsesImpl(GlobalRefOpInterface refOp, SymbolTableCollection &tables) {
196 // Ensure this op references a valid GlobalDefOp name
197 auto tgt = refOp.getGlobalDefOp(tables);
198 if (failed(tgt)) {
199 return failure();
200 }
201 // Ensure the SSA Value type matches the GlobalDefOp type
202 Type globalType = tgt->get().getType();
203 if (!typesUnify(refOp.getVal().getType(), globalType, tgt->getIncludeSymNames())) {
204 return refOp->emitOpError() << "has wrong type; expected " << globalType << ", got "
205 << refOp.getVal().getType();
206 }
207 return tgt;
208}
209
210} // namespace
211
212LogicalResult GlobalReadOp::verifySymbolUses(SymbolTableCollection &tables) {
213 if (failed(verifySymbolUsesImpl(*this, tables))) {
214 return failure();
215 }
216 // Ensure any SymbolRef used in the type are valid
217 return verifyTypeResolution(tables, *this, getType());
218}
219
220LogicalResult GlobalWriteOp::verifySymbolUses(SymbolTableCollection &tables) {
221 auto tgt = verifySymbolUsesImpl(*this, tables);
222 if (failed(tgt)) {
223 return failure();
224 }
225 if (tgt->get().isConstant()) {
226 return emitOpError().append(
227 "cannot target '", GlobalDefOp::getOperationName(), "' marked as 'const'"
228 );
229 }
230 return success();
231}
232
233} // namespace llzk::global
Wrapper around InFlightDiagnostic that can either be a regular InFlightDiagnostic or a special versio...
Definition ErrorHelper.h:26
InFlightDiagnosticWrapper & append(Args &&...args) &
Append arguments to the diagnostic.
Definition ErrorHelper.h:83
::mlir::Type getType()
Definition Ops.cpp.inc:401
::mlir::ParseResult parseGlobalInitialValue(::mlir::OpAsmParser &parser, ::mlir::Attribute &initialValue, ::mlir::TypeAttr typeAttr)
Definition Ops.cpp:36
static void printGlobalInitialValue(::mlir::OpAsmPrinter &printer, GlobalDefOp op, ::mlir::Attribute initialValue, ::mlir::TypeAttr typeAttr)
Definition Ops.cpp:62
::mlir::Attribute getInitialValueAttr()
Definition Ops.h.inc:267
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:77
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:219
::llvm::LogicalResult verify()
Definition Ops.cpp:170
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:212
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the GlobalRefOp.
::mlir::FailureOr< SymbolLookupResult< GlobalDefOp > > getGlobalDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the global referenced in this op.
Definition Ops.cpp:188
::mlir::SymbolRefAttr getNameRef()
Gets the global name attribute from the GlobalRefOp.
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:220
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
bool isValidGlobalType(Type type)
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)