LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Struct op implementations ---------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
15#include "llzk/Util/Constants.h"
18
19#include <mlir/IR/IRMapping.h>
20#include <mlir/IR/OpImplementation.h>
21
22#include <llvm/ADT/MapVector.h>
23#include <llvm/ADT/STLExtras.h>
24#include <llvm/ADT/StringSet.h>
25
26#include <optional>
27
28// TableGen'd implementation files
30
31// TableGen'd implementation files
32#define GET_OP_CLASSES
34
35using namespace mlir;
36using namespace llzk::array;
37using namespace llzk::function;
38
39namespace llzk::component {
40
41bool isInStruct(Operation *op) { return succeeded(getParentOfType<StructDefOp>(op)); }
42
43FailureOr<StructDefOp> verifyInStruct(Operation *op) {
44 FailureOr<StructDefOp> res = getParentOfType<StructDefOp>(op);
45 if (failed(res)) {
46 return op->emitOpError() << "only valid within a '" << StructDefOp::getOperationName()
47 << "' ancestor";
48 }
49 return res;
50}
51
52bool isInStructFunctionNamed(Operation *op, char const *funcName) {
53 FailureOr<FuncDefOp> parentFuncOpt = getParentOfType<FuncDefOp>(op);
54 if (succeeded(parentFuncOpt)) {
55 FuncDefOp parentFunc = parentFuncOpt.value();
56 if (isInStruct(parentFunc.getOperation())) {
57 if (parentFunc.getSymName().compare(funcName) == 0) {
58 return true;
59 }
60 }
61 }
62 return false;
63}
64
65// Again, only valid/implemented for StructDefOp
66template <> LogicalResult SetFuncAllowAttrs<StructDefOp>::verifyTrait(Operation *structOp) {
67 assert(llvm::isa<StructDefOp>(structOp));
68 Region &bodyRegion = llvm::cast<StructDefOp>(structOp).getBodyRegion();
69 if (!bodyRegion.empty()) {
70 bodyRegion.front().walk([](FuncDefOp funcDef) {
71 if (funcDef.nameIsConstrain()) {
72 funcDef.setAllowConstraintAttr();
73 funcDef.setAllowWitnessAttr(false);
74 } else if (funcDef.nameIsCompute()) {
75 funcDef.setAllowConstraintAttr(false);
76 funcDef.setAllowWitnessAttr();
77 } else if (funcDef.nameIsProduct()) {
78 funcDef.setAllowConstraintAttr();
79 funcDef.setAllowWitnessAttr();
80 }
81 });
82 }
83 return success();
84}
85
86InFlightDiagnostic genCompareErr(StructDefOp expected, Operation *origin, const char *aspect) {
87 std::string prefix = std::string();
88 if (SymbolOpInterface symbol = llvm::dyn_cast<SymbolOpInterface>(origin)) {
89 prefix += "\"@";
90 prefix += symbol.getName();
91 prefix += "\" ";
92 }
93 return origin->emitOpError().append(
94 prefix, "must use type of its ancestor '", StructDefOp::getOperationName(), "' \"",
95 expected.getHeaderString(), "\" as ", aspect, " type"
96 );
97}
98
99static inline InFlightDiagnostic structFuncDefError(Operation *origin) {
100 return origin->emitError() << '\'' << StructDefOp::getOperationName() << "' op "
101 << "must define either only a \"@" << FUNC_NAME_PRODUCT
102 << "\" function, or both \"@" << FUNC_NAME_COMPUTE << "\" and \"@"
103 << FUNC_NAME_CONSTRAIN << "\" functions; ";
104}
105
108LogicalResult checkSelfType(
109 SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin,
110 const char *aspect
111) {
112 if (StructType actualStructType = llvm::dyn_cast<StructType>(actualType)) {
113 auto actualStructOpt =
114 lookupTopLevelSymbol<StructDefOp>(tables, actualStructType.getNameRef(), origin);
115 if (failed(actualStructOpt)) {
116 return origin->emitError().append(
117 "could not find '", StructDefOp::getOperationName(), "' named \"",
118 actualStructType.getNameRef(), '"'
119 );
120 }
121 StructDefOp actualStruct = actualStructOpt.value().get();
122 if (actualStruct != expectedStruct) {
123 return genCompareErr(expectedStruct, origin, aspect)
124 .attachNote(actualStruct.getLoc())
125 .append("uses this type instead");
126 }
127 // Check for an EXACT match in the parameter list since it must reference the "self" type.
128 if (expectedStruct.getConstParamsAttr() != actualStructType.getParams()) {
129 // To make error messages more consistent and meaningful, if the parameters don't match
130 // because the actual type uses symbols that are not defined, generate an error about the
131 // undefined symbol(s).
132 if (ArrayAttr tyParams = actualStructType.getParams()) {
133 if (failed(verifyParamsOfType(tables, tyParams.getValue(), actualStructType, origin))) {
134 return failure();
135 }
136 }
137 // Otherwise, generate an error stating the parent struct type must be used.
138 return genCompareErr(expectedStruct, origin, aspect)
139 .attachNote(actualStruct.getLoc())
140 .append("should be type of this '", StructDefOp::getOperationName(), '\'');
141 }
142 } else {
143 return genCompareErr(expectedStruct, origin, aspect);
144 }
145 return success();
146}
147
148//===------------------------------------------------------------------===//
149// StructDefOp
150//===------------------------------------------------------------------===//
151
152StructType StructDefOp::getType(std::optional<ArrayAttr> constParams) {
153 auto pathRes = getPathFromRoot(*this);
154 assert(succeeded(pathRes)); // consistent with StructType::get() with invalid args
155 return StructType::get(pathRes.value(), constParams.value_or(getConstParamsAttr()));
156}
157
159 return buildStringViaCallback([this](llvm::raw_ostream &ss) {
160 FailureOr<SymbolRefAttr> pathToExpected = getPathFromRoot(*this);
161 if (succeeded(pathToExpected)) {
162 ss << pathToExpected.value();
163 } else {
164 // When there is a failure trying to get the resolved name of the struct,
165 // just print its symbol name directly.
166 ss << '@' << this->getSymName();
167 }
168 if (auto attr = this->getConstParamsAttr()) {
169 ss << '<' << attr << '>';
170 }
171 });
172}
173
174bool StructDefOp::hasParamNamed(StringAttr find) {
175 if (ArrayAttr params = this->getConstParamsAttr()) {
176 for (Attribute attr : params) {
177 assert(llvm::isa<FlatSymbolRefAttr>(attr)); // per ODS
178 if (llvm::cast<FlatSymbolRefAttr>(attr).getRootReference() == find) {
179 return true;
180 }
181 }
182 }
183 return false;
184}
185
187 auto res = getPathFromRoot(*this);
188 assert(succeeded(res));
189 return res.value();
190}
191
192LogicalResult StructDefOp::verifySymbolUses(SymbolTableCollection &tables) {
193 if (ArrayAttr params = this->getConstParamsAttr()) {
194 // Ensure struct parameter names are unique
195 llvm::StringSet<> uniqNames;
196 for (Attribute attr : params) {
197 assert(llvm::isa<FlatSymbolRefAttr>(attr)); // per ODS
198 StringRef name = llvm::cast<FlatSymbolRefAttr>(attr).getValue();
199 if (!uniqNames.insert(name).second) {
200 return this->emitOpError().append("has more than one parameter named \"@", name, '"');
201 }
202 }
203 // Ensure they do not conflict with existing symbols
204 for (Attribute attr : params) {
205 auto res = lookupTopLevelSymbol(tables, llvm::cast<FlatSymbolRefAttr>(attr), *this, false);
206 if (succeeded(res)) {
207 return this->emitOpError()
208 .append("parameter name \"@")
209 .append(llvm::cast<FlatSymbolRefAttr>(attr).getValue())
210 .append("\" conflicts with an existing symbol")
211 .attachNote(res->get()->getLoc())
212 .append("symbol already defined here");
213 }
214 }
215 }
216 return success();
217}
218
219namespace {
220
221inline LogicalResult checkMainFuncParamType(Type pType, FuncDefOp inFunc, bool appendSelf) {
222 if (isSignalType(pType)) {
223 return success();
224 } else if (auto arrayParamTy = llvm::dyn_cast<ArrayType>(pType)) {
225 if (isSignalType(arrayParamTy.getElementType())) {
226 return success();
227 }
228 }
229
230 std::string message = buildStringViaCallback([&inFunc, appendSelf](llvm::raw_ostream &ss) {
231 ss << "\"@" << COMPONENT_NAME_MAIN << "\" component \"@" << inFunc.getSymName()
232 << "\" function parameters must be one of: {";
233 if (appendSelf) {
234 ss << "!" << StructType::name << "<@" << COMPONENT_NAME_MAIN << ">, ";
235 }
236 ss << "!" << StructType::name << "<@" << COMPONENT_NAME_SIGNAL << ">, ";
237 ss << "!" << ArrayType::name << "<.. x !" << StructType::name << "<@" << COMPONENT_NAME_SIGNAL
238 << ">>}";
239 });
240 return inFunc.emitError(message);
241}
242
243inline LogicalResult verifyStructComputeConstrain(
244 StructDefOp structDef, FuncDefOp computeFunc, FuncDefOp constrainFunc
245) {
246 // ASSERT: The `SetFuncAllowAttrs` trait on StructDefOp set the attributes correctly.
247 assert(constrainFunc.hasAllowConstraintAttr());
248 assert(!computeFunc.hasAllowConstraintAttr());
249 assert(!constrainFunc.hasAllowWitnessAttr());
250 assert(computeFunc.hasAllowWitnessAttr());
251
252 // Verify parameter types are valid. Skip the first parameter of the "constrain" function; it is
253 // already checked via verifyFuncTypeConstrain() in Function/IR/Ops.cpp.
254 ArrayRef<Type> computeParams = computeFunc.getFunctionType().getInputs();
255 ArrayRef<Type> constrainParams = constrainFunc.getFunctionType().getInputs().drop_front();
256 if (structDef.isMainComponent()) {
257 // Verify that the Struct has no parameters.
258 if (!isNullOrEmpty(structDef.getConstParamsAttr())) {
259 return structDef.emitError().append(
260 "The \"@", COMPONENT_NAME_MAIN, "\" component must have no parameters"
261 );
262 }
263 // Verify the input parameter types are legal. The error message is explicit about what types
264 // are allowed so there is no benefit to report multiple errors if more than one parameter in
265 // the referenced function has an illegal type.
266 for (Type t : computeParams) {
267 if (failed(checkMainFuncParamType(t, computeFunc, false))) {
268 return failure(); // checkMainFuncParamType() already emits a sufficient error message
269 }
270 }
271 for (Type t : constrainParams) {
272 if (failed(checkMainFuncParamType(t, constrainFunc, true))) {
273 return failure(); // checkMainFuncParamType() already emits a sufficient error message
274 }
275 }
276 }
277
278 if (!typeListsUnify(computeParams, constrainParams)) {
279 return constrainFunc.emitError()
280 .append(
281 "expected \"@", FUNC_NAME_CONSTRAIN,
282 "\" function argument types (sans the first one) to match \"@", FUNC_NAME_COMPUTE,
283 "\" function argument types"
284 )
285 .attachNote(computeFunc.getLoc())
286 .append("\"@", FUNC_NAME_COMPUTE, "\" function defined here");
287 }
288
289 return success();
290}
291
292inline LogicalResult verifyStructProduct(StructDefOp structDef, FuncDefOp productFunc) {
293 // ASSERT: The `SetFuncAllowAttrs` trait on StructDefOp set the attributes correctly
294 assert(productFunc.hasAllowConstraintAttr());
295 assert(productFunc.hasAllowWitnessAttr());
296
297 // Verify parameter types are valid
298 ArrayRef<Type> productParams = productFunc.getFunctionType().getInputs();
299 if (structDef.isMainComponent()) {
300 if (!isNullOrEmpty(structDef.getConstParamsAttr())) {
301 return structDef.emitError().append(
302 "The \"@", COMPONENT_NAME_MAIN, "\" component must have no parameters"
303 );
304 }
305 for (Type t : productParams) {
306 if (failed(checkMainFuncParamType(t, productFunc, false))) {
307 return failure();
308 }
309 }
310 }
311
312 return success();
313}
314
315} // namespace
316
318 std::optional<FuncDefOp> foundCompute = std::nullopt;
319 std::optional<FuncDefOp> foundConstrain = std::nullopt;
320 std::optional<FuncDefOp> foundProduct = std::nullopt;
321 {
322 // Verify the following:
323 // 1. The only ops within the body are field and function definitions
324 // 2. The only functions defined in the struct are `@compute()` and `@constrain()`, or
325 // `@product()`
326 OwningEmitErrorFn emitError = getEmitOpErrFn(this);
327 Region &bodyRegion = getBodyRegion();
328 if (!bodyRegion.empty()) {
329 for (Operation &op : bodyRegion.front()) {
330 if (!llvm::isa<FieldDefOp>(op)) {
331 if (FuncDefOp funcDef = llvm::dyn_cast<FuncDefOp>(op)) {
332 if (funcDef.nameIsCompute()) {
333 if (foundProduct) {
334 return structFuncDefError(funcDef.getOperation())
335 << "found both \"@" << FUNC_NAME_COMPUTE << "\" and \"@" << FUNC_NAME_PRODUCT
336 << "\" functions";
337 }
338 if (foundCompute) {
339 return structFuncDefError(funcDef.getOperation())
340 << "found multiple \"@" << FUNC_NAME_COMPUTE << "\" functions";
341 }
342 foundCompute = std::make_optional(funcDef);
343 } else if (funcDef.nameIsConstrain()) {
344 if (foundProduct) {
345 return structFuncDefError(funcDef.getOperation())
346 << "found both \"@" << FUNC_NAME_CONSTRAIN << "\" and \"@"
347 << FUNC_NAME_PRODUCT << "\" functions";
348 }
349 if (foundConstrain) {
350 return structFuncDefError(funcDef.getOperation())
351 << "found multiple \"@" << FUNC_NAME_CONSTRAIN << "\" functions";
352 }
353 foundConstrain = std::make_optional(funcDef);
354 } else if (funcDef.nameIsProduct()) {
355 if (foundCompute) {
356 return structFuncDefError(funcDef.getOperation())
357 << "found both \"@" << FUNC_NAME_COMPUTE << "\" and \"@" << FUNC_NAME_PRODUCT
358 << "\" functions";
359 }
360 if (foundConstrain) {
361 return structFuncDefError(funcDef.getOperation())
362 << "found both \"@" << FUNC_NAME_CONSTRAIN << "\" and \"@"
363 << FUNC_NAME_PRODUCT << "\" functions";
364 }
365 if (foundProduct) {
366 return structFuncDefError(funcDef.getOperation())
367 << "found multiple \"@" << FUNC_NAME_PRODUCT << "\" functions";
368 }
369 foundProduct = std::make_optional(funcDef);
370 } else {
371 // Must do a little more than a simple call to '?.emitOpError()' to
372 // tag the error with correct location and correct op name.
373 return structFuncDefError(funcDef.getOperation())
374 << "found \"@" << funcDef.getSymName() << '"';
375 }
376 } else {
377 return op.emitOpError()
378 << "invalid operation in '" << StructDefOp::getOperationName() << "'; only '"
379 << FieldDefOp::getOperationName() << '\'' << " and '"
380 << FuncDefOp::getOperationName() << "' operations are permitted";
381 }
382 }
383 }
384 }
385
386 if (!foundCompute.has_value() && foundConstrain.has_value()) {
387 return structFuncDefError(getOperation()) << "found \"@" << FUNC_NAME_CONSTRAIN
388 << "\", missing \"@" << FUNC_NAME_COMPUTE << "\"";
389 }
390 if (!foundConstrain.has_value() && foundCompute.has_value()) {
391 return structFuncDefError(getOperation()) << "found \"@" << FUNC_NAME_COMPUTE
392 << "\", missing \"@" << FUNC_NAME_CONSTRAIN << "\"";
393 }
394 }
395
396 if (!foundCompute.has_value() && !foundConstrain.has_value() && !foundProduct.has_value()) {
397 return structFuncDefError(getOperation())
398 << "could not find \"@" << FUNC_NAME_PRODUCT << "\", \"@" << FUNC_NAME_COMPUTE
399 << "\", or \"@" << FUNC_NAME_CONSTRAIN << "\"";
400 }
401
402 if (foundCompute && foundConstrain) {
403 return verifyStructComputeConstrain(*this, *foundCompute, *foundConstrain);
404 }
405 return verifyStructProduct(*this, *foundProduct);
406}
407
409 for (Operation &op : *getBody()) {
410 if (FieldDefOp fieldDef = llvm::dyn_cast_if_present<FieldDefOp>(op)) {
411 if (fieldName.compare(fieldDef.getSymNameAttr()) == 0) {
412 return fieldDef;
413 }
414 }
415 }
416 return nullptr;
417}
418
419std::vector<FieldDefOp> StructDefOp::getFieldDefs() {
420 std::vector<FieldDefOp> res;
421 for (Operation &op : *getBody()) {
422 if (FieldDefOp fieldDef = llvm::dyn_cast_if_present<FieldDefOp>(op)) {
423 res.push_back(fieldDef);
424 }
425 }
426 return res;
427}
428
430 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_COMPUTE));
431}
432
434 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_CONSTRAIN));
435}
436
438 if (auto *computeFunc = lookupSymbol(FUNC_NAME_COMPUTE)) {
439 return llvm::dyn_cast<FuncDefOp>(computeFunc);
440 }
441 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_PRODUCT));
442}
443
445 if (auto *constrainFunc = lookupSymbol(FUNC_NAME_CONSTRAIN)) {
446 return llvm::dyn_cast<FuncDefOp>(constrainFunc);
447 }
448 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_PRODUCT));
449}
450
452
453//===------------------------------------------------------------------===//
454// FieldDefOp
455//===------------------------------------------------------------------===//
456
458 OpBuilder &odsBuilder, OperationState &odsState, StringAttr sym_name, TypeAttr type,
459 bool isColumn
460) {
461 Properties &props = odsState.getOrAddProperties<Properties>();
462 props.setSymName(sym_name);
463 props.setType(type);
464 if (isColumn) {
465 props.column = odsBuilder.getUnitAttr();
466 }
467}
468
470 OpBuilder &odsBuilder, OperationState &odsState, StringRef sym_name, Type type, bool isColumn
471) {
472 build(odsBuilder, odsState, odsBuilder.getStringAttr(sym_name), TypeAttr::get(type), isColumn);
473}
474
476 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, ValueRange operands,
477 ArrayRef<NamedAttribute> attributes, bool isColumn
478) {
479 assert(operands.size() == 0u && "mismatched number of parameters");
480 odsState.addOperands(operands);
481 odsState.addAttributes(attributes);
482 assert(resultTypes.size() == 0u && "mismatched number of return types");
483 odsState.addTypes(resultTypes);
484 if (isColumn) {
485 odsState.getOrAddProperties<Properties>().column = odsBuilder.getUnitAttr();
486 }
487}
488
489void FieldDefOp::setPublicAttr(bool newValue) {
490 if (newValue) {
491 getOperation()->setAttr(PublicAttr::name, UnitAttr::get(getContext()));
492 } else {
493 getOperation()->removeAttr(PublicAttr::name);
494 }
495}
496
497static LogicalResult
498verifyFieldDefTypeImpl(Type fieldType, SymbolTableCollection &tables, Operation *origin) {
499 if (StructType fieldStructType = llvm::dyn_cast<StructType>(fieldType)) {
500 // Special case for StructType verifies that the field type can resolve and that it is NOT the
501 // parent struct (i.e., struct fields cannot create circular references).
502 auto fieldTypeRes = verifyStructTypeResolution(tables, fieldStructType, origin);
503 if (failed(fieldTypeRes)) {
504 return failure(); // above already emits a sufficient error message
505 }
506 FailureOr<StructDefOp> parentRes = getParentOfType<StructDefOp>(origin);
507 assert(succeeded(parentRes) && "FieldDefOp parent is always StructDefOp"); // per ODS def
508 if (fieldTypeRes.value() == parentRes.value()) {
509 return origin->emitOpError()
510 .append("type is circular")
511 .attachNote(parentRes.value().getLoc())
512 .append("references parent component defined here");
513 }
514 return success();
515 } else {
516 return verifyTypeResolution(tables, origin, fieldType);
517 }
518}
519
520LogicalResult FieldDefOp::verifySymbolUses(SymbolTableCollection &tables) {
521 Type fieldType = this->getType();
522 if (failed(verifyFieldDefTypeImpl(fieldType, tables, *this))) {
523 return failure();
524 }
525
526 if (!getColumn()) {
527 return success();
528 }
529 // If the field is marked as a column only a small subset of types are allowed.
530 if (!isValidColumnType(getType(), tables, *this)) {
531 return emitOpError() << "marked as column can only contain felts, arrays of column types, or "
532 "structs with columns, but field has type "
533 << getType();
534 }
535 return success();
536}
537
538//===------------------------------------------------------------------===//
539// FieldRefOp implementations
540//===------------------------------------------------------------------===//
541namespace {
542
543FailureOr<SymbolLookupResult<FieldDefOp>>
544getFieldDefOpImpl(FieldRefOpInterface refOp, SymbolTableCollection &tables, StructType tyStruct) {
545 Operation *op = refOp.getOperation();
546 auto structDefRes = tyStruct.getDefinition(tables, op);
547 if (failed(structDefRes)) {
548 return failure(); // getDefinition() already emits a sufficient error message
549 }
551 tables, SymbolRefAttr::get(refOp->getContext(), refOp.getFieldName()),
552 std::move(*structDefRes), op
553 );
554 if (failed(res)) {
555 return refOp->emitError() << "could not find '" << FieldDefOp::getOperationName()
556 << "' named \"@" << refOp.getFieldName() << "\" in \""
557 << tyStruct.getNameRef() << '"';
558 }
559 return std::move(res.value());
560}
561
562static FailureOr<SymbolLookupResult<FieldDefOp>>
563findField(FieldRefOpInterface refOp, SymbolTableCollection &tables) {
564 // Ensure the base component/struct type reference can be resolved.
565 StructType tyStruct = refOp.getStructType();
566 if (failed(tyStruct.verifySymbolRef(tables, refOp.getOperation()))) {
567 return failure();
568 }
569 // Ensure the field name can be resolved in that struct.
570 return getFieldDefOpImpl(refOp, tables, tyStruct);
571}
572
573static LogicalResult verifySymbolUsesImpl(
574 FieldRefOpInterface refOp, SymbolTableCollection &tables, SymbolLookupResult<FieldDefOp> &field
575) {
576 // Ensure the type of the referenced field declaration matches the type used in this op.
577 Type actualType = refOp.getVal().getType();
578 Type fieldType = field.get().getType();
579 if (!typesUnify(actualType, fieldType, field.getIncludeSymNames())) {
580 return refOp->emitOpError() << "has wrong type; expected " << fieldType << ", got "
581 << actualType;
582 }
583 // Ensure any SymbolRef used in the type are valid
584 return verifyTypeResolution(tables, refOp.getOperation(), actualType);
585}
586
587LogicalResult verifySymbolUsesImpl(FieldRefOpInterface refOp, SymbolTableCollection &tables) {
588 // Ensure the field name can be resolved in that struct.
589 auto field = findField(refOp, tables);
590 if (failed(field)) {
591 return field; // getFieldDefOp() already emits a sufficient error message
592 }
593 return verifySymbolUsesImpl(refOp, tables, *field);
594}
595
596} // namespace
597
598FailureOr<SymbolLookupResult<FieldDefOp>>
599FieldRefOpInterface::getFieldDefOp(SymbolTableCollection &tables) {
600 return getFieldDefOpImpl(*this, tables, getStructType());
601}
602
603LogicalResult FieldReadOp::verifySymbolUses(SymbolTableCollection &tables) {
604 auto field = findField(*this, tables);
605 if (failed(field)) {
606 return failure();
607 }
608 if (failed(verifySymbolUsesImpl(*this, tables, *field))) {
609 return failure();
610 }
611 // If the field is not a column and an offset was specified then fail to validate
612 if (!field->get().getColumn() && getTableOffset().has_value()) {
613 return emitOpError("cannot read with table offset from a field that is not a column")
614 .attachNote(field->get().getLoc())
615 .append("field defined here");
616 }
617
618 return success();
619}
620
621LogicalResult FieldWriteOp::verifySymbolUses(SymbolTableCollection &tables) {
622 // Ensure the write op only targets fields in the current struct.
623 FailureOr<StructDefOp> getParentRes = verifyInStruct(*this);
624 if (failed(getParentRes)) {
625 return failure(); // verifyInStruct() already emits a sufficient error message
626 }
627 if (failed(checkSelfType(tables, *getParentRes, getComponent().getType(), *this, "base value"))) {
628 return failure(); // checkSelfType() already emits a sufficient error message
629 }
630 // Perform the standard field ref checks.
631 return verifySymbolUsesImpl(*this, tables);
632}
633
634//===------------------------------------------------------------------===//
635// FieldReadOp
636//===------------------------------------------------------------------===//
637
639 OpBuilder &builder, OperationState &state, Type resultType, Value component, StringAttr field
640) {
641 Properties &props = state.getOrAddProperties<Properties>();
642 props.setFieldName(FlatSymbolRefAttr::get(field));
643 state.addTypes(resultType);
644 state.addOperands(component);
646}
647
649 OpBuilder &builder, OperationState &state, Type resultType, Value component, StringAttr field,
650 Attribute dist, ValueRange mapOperands, std::optional<int32_t> numDims
651) {
652 // '!mapOperands.empty()' implies 'numDims.has_value()'
653 assert(mapOperands.empty() || numDims.has_value());
654 state.addOperands(component);
655 state.addTypes(resultType);
656 if (numDims.has_value()) {
658 builder, state, ArrayRef({mapOperands}), builder.getDenseI32ArrayAttr({*numDims})
659 );
660 } else {
662 }
663 Properties &props = state.getOrAddProperties<Properties>();
664 props.setFieldName(FlatSymbolRefAttr::get(field));
665 props.setTableOffset(dist);
666}
667
669 OpBuilder & /*odsBuilder*/, OperationState &odsState, TypeRange resultTypes,
670 ValueRange operands, ArrayRef<NamedAttribute> attrs
671) {
672 odsState.addTypes(resultTypes);
673 odsState.addOperands(operands);
674 odsState.addAttributes(attrs);
675}
676
677LogicalResult FieldReadOp::verify() {
678 SmallVector<AffineMapAttr, 1> mapAttrs;
679 if (AffineMapAttr map =
680 llvm::dyn_cast_if_present<AffineMapAttr>(getTableOffset().value_or(nullptr))) {
681 mapAttrs.push_back(map);
682 }
684 getMapOperands(), getNumDimsPerMap(), mapAttrs, *this
685 );
686}
687
688//===------------------------------------------------------------------===//
689// CreateStructOp
690//===------------------------------------------------------------------===//
691
692void CreateStructOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
693 setNameFn(getResult(), "self");
694}
695
696LogicalResult CreateStructOp::verifySymbolUses(SymbolTableCollection &tables) {
697 FailureOr<StructDefOp> getParentRes = verifyInStruct(*this);
698 if (failed(getParentRes)) {
699 return failure(); // verifyInStruct() already emits a sufficient error message
700 }
701 if (failed(checkSelfType(tables, *getParentRes, this->getType(), *this, "result"))) {
702 return failure();
703 }
704 return success();
705}
706
707} // namespace llzk::component
std::vector< llvm::StringRef > getIncludeSymNames() const
Return the stack of symbol names from the IncludeOp that were traversed to load this result.
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:51
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:696
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
Definition Ops.cpp:692
::mlir::TypedValue<::llzk::component::StructType > getResult()
Definition Ops.h.inc:143
FoldAdaptor::Properties Properties
Definition Ops.h.inc:302
void setPublicAttr(bool newValue=true)
Definition Ops.cpp:489
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::StringAttr sym_name, ::mlir::TypeAttr type, bool isColumn=false)
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:332
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:520
FoldAdaptor::Properties Properties
Definition Ops.h.inc:601
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:654
::llvm::LogicalResult verify()
Definition Ops.cpp:677
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType, ::mlir::Value component, ::mlir::StringAttr field)
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
Definition Ops.cpp.inc:910
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:603
::std::optional<::mlir::Attribute > getTableOffset()
Definition Ops.cpp.inc:905
::mlir::FailureOr< SymbolLookupResult< FieldDefOp > > getFieldDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the field referenced in this op.
Definition Ops.cpp:599
::llzk::component::StructType getStructType()
Gets the struct type of the target component.
::llvm::StringRef getFieldName()
Gets the field name attribute value from the FieldRefOp.
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the FieldRefOp.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
Definition Ops.h.inc:914
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:621
static mlir::LogicalResult verifyTrait(mlir::Operation *op)
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
FieldDefOp getFieldDef(::mlir::StringAttr fieldName)
Gets the FieldDefOp that defines the field in this structure with the given name, if present.
Definition Ops.cpp:408
::mlir::Region & getBodyRegion()
Definition Ops.h.inc:1176
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:192
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:1152
::llzk::function::FuncDefOp getConstrainOrProductFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:444
::std::vector< FieldDefOp > getFieldDefs()
Get all FieldDefOp in this structure.
Definition Ops.cpp:419
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:1590
::mlir::SymbolRefAttr getFullyQualifiedName()
Return the full name for this struct from the root module, including any surrounding module scopes.
Definition Ops.cpp:186
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:433
::llzk::function::FuncDefOp getComputeOrProductFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
Definition Ops.cpp:437
bool hasParamNamed(::mlir::StringAttr find)
Return true iff this StructDefOp has a parameter with the given name.
::llvm::LogicalResult verifyRegions()
Definition Ops.cpp:317
::mlir::ArrayAttr getConstParamsAttr()
Definition Ops.h.inc:1194
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
Definition Ops.cpp:429
bool isMainComponent()
Return true iff this StructDefOp is named "Main".
Definition Ops.cpp:451
::std::string getHeaderString()
Generate header string, in the same format as the assemblyFormat.
Definition Ops.cpp:158
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:79
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
Definition Types.cpp:47
::mlir::LogicalResult verifySymbolRef(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op)
Definition Types.cpp:72
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:38
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
Definition Ops.cpp:208
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:952
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:770
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
Definition Ops.h.inc:729
bool nameIsProduct()
Return true iff the function name is FUNC_NAME_PRODUCT (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:778
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:947
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
Definition Ops.h.inc:774
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:583
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
Definition Ops.cpp:200
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
Definition Ops.h.inc:721
OpClass::Properties & buildInstantiationAttrsEmptyNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
void buildInstantiationAttrsNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
bool isInStruct(Operation *op)
Definition Ops.cpp:41
InFlightDiagnostic genCompareErr(StructDefOp expected, Operation *origin, const char *aspect)
Definition Ops.cpp:86
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
Definition Ops.cpp:108
FailureOr< StructDefOp > verifyInStruct(Operation *op)
Definition Ops.cpp:43
bool isInStructFunctionNamed(Operation *op, char const *funcName)
Definition Ops.cpp:52
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:27
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
Definition TypeHelper.h:225
constexpr char COMPONENT_NAME_MAIN[]
Symbol name for the main entry point struct/component (if any).
Definition Constants.h:23
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:28
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
bool isNullOrEmpty(mlir::ArrayAttr a)
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:29
FailureOr< StructDefOp > verifyStructTypeResolution(SymbolTableCollection &tables, StructType ty, Operation *origin)
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
LogicalResult verifyParamsOfType(SymbolTableCollection &tables, ArrayRef< Attribute > tyParams, Type parameterizedType, Operation *origin)
bool isSignalType(Type type)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:45
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
constexpr char COMPONENT_NAME_SIGNAL[]
Symbol name for the struct/component representing a signal.
Definition Constants.h:16
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)
void setSymName(const ::mlir::StringAttr &propValue)
Definition Ops.h.inc:190
void setFieldName(const ::mlir::FlatSymbolRefAttr &propValue)
Definition Ops.h.inc:449