LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Global value operation implementations --------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
17
18// TableGen'd implementation files
20
21// TableGen'd implementation files
22#define GET_OP_CLASSES
24
25using namespace mlir;
26using namespace llzk::array;
27using namespace llzk::felt;
28using namespace llzk::string;
29
30namespace llzk::global {
31
32//===------------------------------------------------------------------===//
33// GlobalDefOp
34//===------------------------------------------------------------------===//
35
37 OpAsmParser &parser, Attribute &initialValue, TypeAttr typeAttr
38) {
39 if (parser.parseOptionalEqual()) {
40 // When there's no equal sign, there's no initial value to parse.
41 return success();
42 }
43 Type specifiedType = typeAttr.getValue();
44
45 // Special case for parsing LLZK FeltType to match format of FeltConstantOp.
46 // Not actually necessary but the default format is verbose. ex: "#llzk<felt.const 35>"
47 if (isa<FeltType>(specifiedType)) {
48 FeltConstAttr feltConstAttr;
49 if (parser.parseCustomAttributeWithFallback<FeltConstAttr>(feltConstAttr)) {
50 return failure();
51 }
52 initialValue = feltConstAttr;
53 return success();
54 }
55 // Fallback to default parser for all other types.
56 if (failed(parser.parseAttribute(initialValue, specifiedType))) {
57 return failure();
58 }
59 return success();
60}
61
63 OpAsmPrinter &p, GlobalDefOp op, Attribute initialValue, TypeAttr typeAttr
64) {
65 if (initialValue) {
66 p << " = ";
67 // Special case for LLZK FeltType to match format of FeltConstantOp.
68 // Not actually necessary but the default format is verbose. ex: "#llzk<felt.const 35>"
69 if (FeltConstAttr feltConstAttr = llvm::dyn_cast<FeltConstAttr>(initialValue)) {
70 p.printStrippedAttrOrType<FeltConstAttr>(feltConstAttr);
71 } else {
72 p.printAttributeWithoutType(initialValue);
73 }
74 }
75}
76
77LogicalResult GlobalDefOp::verifySymbolUses(SymbolTableCollection &tables) {
78 // Ensure any SymbolRef used in the type are valid
79 return verifyTypeResolution(tables, *this, getType());
80}
81
82namespace {
83
84inline InFlightDiagnostic reportMismatch(
85 EmitErrorFn errFn, Type rootType, const Twine &aspect, const Twine &expected, const Twine &found
86) {
87 return errFn().append(
88 "with type ", rootType, " expected ", expected, " ", aspect, " but found ", found
89 );
90}
91
92inline InFlightDiagnostic reportMismatch(
93 EmitErrorFn errFn, Type rootType, const Twine &aspect, const Twine &expected, Attribute found
94) {
95 return reportMismatch(errFn, rootType, aspect, expected, found.getAbstractAttribute().getName());
96}
97
98LogicalResult ensureAttrTypeMatch(
99 Type type, Attribute valAttr, const OwningEmitErrorFn &errFn, Type rootType, const Twine &aspect
100) {
101 if (!isValidGlobalType(type)) {
102 // Same error message ODS-generated code would produce
103 return errFn().append("attribute 'type' failed to satisfy constraint: type attribute of "
104 "any LLZK type except non-constant types");
105 }
106 if (type.isSignlessInteger(1)) {
107 if (IntegerAttr ia = llvm::dyn_cast<IntegerAttr>(valAttr)) {
108 APInt val = ia.getValue();
109 if (!val.isZero() && !val.isOne()) {
110 return errFn().append("integer constant out of range for attribute");
111 }
112 } else if (!llvm::isa<BoolAttr>(valAttr)) {
113 return reportMismatch(errFn, rootType, aspect, "builtin.bool or builtin.integer", valAttr);
114 }
115 } else if (llvm::isa<IndexType>(type)) {
116 // The explicit check for BoolAttr is needed because the LLVM isa/cast functions treat
117 // BoolAttr as a subtype of IntegerAttr but this scenario should not allow BoolAttr.
118 bool isBool = llvm::isa<BoolAttr>(valAttr);
119 if (isBool || !llvm::isa<IntegerAttr>(valAttr)) {
120 return reportMismatch(
121 errFn, rootType, aspect, "builtin.index",
122 isBool ? "builtin.bool" : valAttr.getAbstractAttribute().getName()
123 );
124 }
125 } else if (llvm::isa<FeltType>(type)) {
126 if (!llvm::isa<FeltConstAttr, IntegerAttr>(valAttr)) {
127 return reportMismatch(errFn, rootType, aspect, "felt.type", valAttr);
128 }
129 } else if (llvm::isa<StringType>(type)) {
130 if (!llvm::isa<StringAttr>(valAttr)) {
131 return reportMismatch(errFn, rootType, aspect, "builtin.string", valAttr);
132 }
133 } else if (ArrayType arrTy = llvm::dyn_cast<ArrayType>(type)) {
134 if (ArrayAttr arrVal = llvm::dyn_cast<ArrayAttr>(valAttr)) {
135 // Ensure the number of elements is correct for the ArrayType
136 assert(arrTy.hasStaticShape() && "implied by earlier isValidGlobalType() check");
137 int64_t expectedCount = arrTy.getNumElements();
138 size_t actualCount = arrVal.size();
139 if (std::cmp_not_equal(actualCount, expectedCount)) {
140 return reportMismatch(
141 errFn, rootType, Twine(aspect) + " to contain " + Twine(expectedCount) + " elements",
142 "builtin.array", Twine(actualCount)
143 );
144 }
145 // Ensure the type of each element is correct for the ArrayType.
146 // Rather than immediately returning on failure, check all elements and aggregate to provide
147 // as many errors are possible in a single verifier run.
148 bool hasFailure = false;
149 Type expectedElemTy = arrTy.getElementType();
150 for (Attribute e : arrVal.getValue()) {
151 hasFailure |=
152 failed(ensureAttrTypeMatch(expectedElemTy, e, errFn, rootType, "array element"));
153 }
154 if (hasFailure) {
155 return failure();
156 }
157 } else {
158 return reportMismatch(errFn, rootType, aspect, "builtin.array", valAttr);
159 }
160 } else {
161 return errFn().append("expected a valid LLZK type but found ", type);
162 }
163 return success();
164}
165
166} // namespace
167
168LogicalResult GlobalDefOp::verify() {
169 if (Attribute initValAttr = getInitialValueAttr()) {
170 Type ty = getType();
171 OwningEmitErrorFn errFn = getEmitOpErrFn(this);
172 return ensureAttrTypeMatch(ty, initValAttr, errFn, ty, "attribute value");
173 }
174 // If there is no initial value, it cannot have "const".
175 if (isConstant()) {
176 return emitOpError("marked as 'const' must be assigned a value");
177 }
178 return success();
179}
180
181//===------------------------------------------------------------------===//
182// GlobalReadOp / GlobalWriteOp
183//===------------------------------------------------------------------===//
184
185FailureOr<SymbolLookupResult<GlobalDefOp>>
186GlobalRefOpInterface::getGlobalDefOp(SymbolTableCollection &tables) {
187 return lookupTopLevelSymbol<GlobalDefOp>(tables, getNameRef(), getOperation());
188}
189
190namespace {
191
192FailureOr<SymbolLookupResult<GlobalDefOp>>
193verifySymbolUsesImpl(GlobalRefOpInterface refOp, SymbolTableCollection &tables) {
194 // Ensure this op references a valid GlobalDefOp name
195 auto tgt = refOp.getGlobalDefOp(tables);
196 if (failed(tgt)) {
197 return failure();
198 }
199 // Ensure the SSA Value type matches the GlobalDefOp type
200 Type globalType = tgt->get().getType();
201 if (!typesUnify(refOp.getVal().getType(), globalType, tgt->getIncludeSymNames())) {
202 return refOp->emitOpError() << "has wrong type; expected " << globalType << ", got "
203 << refOp.getVal().getType();
204 }
205 return tgt;
206}
207
208} // namespace
209
210LogicalResult GlobalReadOp::verifySymbolUses(SymbolTableCollection &tables) {
211 if (failed(verifySymbolUsesImpl(*this, tables))) {
212 return failure();
213 }
214 // Ensure any SymbolRef used in the type are valid
215 return verifyTypeResolution(tables, *this, getType());
216}
217
218LogicalResult GlobalWriteOp::verifySymbolUses(SymbolTableCollection &tables) {
219 auto tgt = verifySymbolUsesImpl(*this, tables);
220 if (failed(tgt)) {
221 return failure();
222 }
223 if (tgt->get().isConstant()) {
224 return emitOpError().append(
225 "cannot target '", GlobalDefOp::getOperationName(), "' marked as 'const'"
226 );
227 }
228 return success();
229}
230
231} // namespace llzk::global
::mlir::Type getType()
Definition Ops.cpp.inc:475
::mlir::ParseResult parseGlobalInitialValue(::mlir::OpAsmParser &parser, ::mlir::Attribute &initialValue, ::mlir::TypeAttr typeAttr)
Definition Ops.cpp:36
static void printGlobalInitialValue(::mlir::OpAsmPrinter &printer, GlobalDefOp op, ::mlir::Attribute initialValue, ::mlir::TypeAttr typeAttr)
Definition Ops.cpp:62
::mlir::Attribute getInitialValueAttr()
Definition Ops.cpp.inc:480
::mlir::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:77
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:208
::mlir::LogicalResult verify()
Definition Ops.cpp:168
::mlir::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:210
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the GlobalRefOp.
::mlir::FailureOr< SymbolLookupResult< GlobalDefOp > > getGlobalDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the global referenced in this op.
Definition Ops.cpp:186
::mlir::SymbolRefAttr getNameRef()
Gets the global name attribute from the GlobalRefOp.
::mlir::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:218
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
llvm::function_ref< mlir::InFlightDiagnostic()> EmitErrorFn
Definition ErrorHelper.h:18
bool isValidGlobalType(Type type)
std::function< mlir::InFlightDiagnostic()> OwningEmitErrorFn
Definition ErrorHelper.h:22
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
Definition ErrorHelper.h:24
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)