LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Struct op implementations ---------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
15#include "llzk/Util/Constants.h"
18
19#include <mlir/IR/IRMapping.h>
20#include <mlir/IR/OpImplementation.h>
21
22#include <llvm/ADT/MapVector.h>
23#include <llvm/ADT/STLExtras.h>
24#include <llvm/ADT/StringSet.h>
25
26#include <optional>
27
28// TableGen'd implementation files
30
31// TableGen'd implementation files
32#define GET_OP_CLASSES
34
35using namespace mlir;
36using namespace llzk::array;
37using namespace llzk::function;
38
39namespace llzk::component {
40
41bool isInStruct(Operation *op) { return succeeded(getParentOfType<StructDefOp>(op)); }
42
43FailureOr<StructDefOp> verifyInStruct(Operation *op) {
44 FailureOr<StructDefOp> res = getParentOfType<StructDefOp>(op);
45 if (failed(res)) {
46 return op->emitOpError() << "only valid within a '" << StructDefOp::getOperationName()
47 << "' ancestor";
48 }
49 return res;
50}
51
52bool isInStructFunctionNamed(Operation *op, char const *funcName) {
53 FailureOr<FuncDefOp> parentFuncOpt = getParentOfType<FuncDefOp>(op);
54 if (succeeded(parentFuncOpt)) {
55 FuncDefOp parentFunc = parentFuncOpt.value();
56 if (isInStruct(parentFunc.getOperation())) {
57 if (parentFunc.getSymName().compare(funcName) == 0) {
58 return true;
59 }
60 }
61 }
62 return false;
63}
64
65// Again, only valid/implemented for StructDefOp
66template <> LogicalResult SetFuncAllowAttrs<StructDefOp>::verifyTrait(Operation *structOp) {
67 assert(llvm::isa<StructDefOp>(structOp));
68 llvm::cast<StructDefOp>(structOp).getBody()->walk([](FuncDefOp funcDef) {
69 if (funcDef.nameIsConstrain()) {
70 funcDef.setAllowConstraintAttr();
71 funcDef.setAllowWitnessAttr(false);
72 } else if (funcDef.nameIsCompute()) {
73 funcDef.setAllowConstraintAttr(false);
74 funcDef.setAllowWitnessAttr();
75 } else if (funcDef.nameIsProduct()) {
76 funcDef.setAllowConstraintAttr();
77 funcDef.setAllowWitnessAttr();
78 }
79 });
80 return success();
81}
82
83InFlightDiagnostic genCompareErr(StructDefOp expected, Operation *origin, const char *aspect) {
84 std::string prefix = std::string();
85 if (SymbolOpInterface symbol = llvm::dyn_cast<SymbolOpInterface>(origin)) {
86 prefix += "\"@";
87 prefix += symbol.getName();
88 prefix += "\" ";
89 }
90 return origin->emitOpError().append(
91 prefix, "must use type of its ancestor '", StructDefOp::getOperationName(), "' \"",
92 expected.getHeaderString(), "\" as ", aspect, " type"
93 );
94}
95
96static inline InFlightDiagnostic structFuncDefError(Operation *origin) {
97 return origin->emitError() << '\'' << StructDefOp::getOperationName() << "' op "
98 << "must define either only a \"@" << FUNC_NAME_PRODUCT
99 << "\" function, or both \"@" << FUNC_NAME_COMPUTE << "\" and \"@"
100 << FUNC_NAME_CONSTRAIN << "\" functions; ";
101}
102
105LogicalResult checkSelfType(
106 SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin,
107 const char *aspect
108) {
109 if (StructType actualStructType = llvm::dyn_cast<StructType>(actualType)) {
110 auto actualStructOpt =
111 lookupTopLevelSymbol<StructDefOp>(tables, actualStructType.getNameRef(), origin);
112 if (failed(actualStructOpt)) {
113 return origin->emitError().append(
114 "could not find '", StructDefOp::getOperationName(), "' named \"",
115 actualStructType.getNameRef(), '"'
116 );
117 }
118 StructDefOp actualStruct = actualStructOpt.value().get();
119 if (actualStruct != expectedStruct) {
120 return genCompareErr(expectedStruct, origin, aspect)
121 .attachNote(actualStruct.getLoc())
122 .append("uses this type instead");
123 }
124 // Check for an EXACT match in the parameter list since it must reference the "self" type.
125 if (expectedStruct.getConstParamsAttr() != actualStructType.getParams()) {
126 // To make error messages more consistent and meaningful, if the parameters don't match
127 // because the actual type uses symbols that are not defined, generate an error about the
128 // undefined symbol(s).
129 if (ArrayAttr tyParams = actualStructType.getParams()) {
130 if (failed(verifyParamsOfType(tables, tyParams.getValue(), actualStructType, origin))) {
131 return failure();
132 }
133 }
134 // Otherwise, generate an error stating the parent struct type must be used.
135 return genCompareErr(expectedStruct, origin, aspect)
136 .attachNote(actualStruct.getLoc())
137 .append("should be type of this '", StructDefOp::getOperationName(), '\'');
138 }
139 } else {
140 return genCompareErr(expectedStruct, origin, aspect);
141 }
142 return success();
143}
144
145//===------------------------------------------------------------------===//
146// StructDefOp
147//===------------------------------------------------------------------===//
148
149StructType StructDefOp::getType(std::optional<ArrayAttr> constParams) {
150 auto pathRes = getPathFromRoot(*this);
151 assert(succeeded(pathRes)); // consistent with StructType::get() with invalid args
152 return StructType::get(pathRes.value(), constParams.value_or(getConstParamsAttr()));
153}
154
// Body of StructDefOp::getHeaderString() (its signature line is not present in this listing).
// Produces the struct's path from the root module, or "@<sym_name>" if that path cannot be
// resolved, followed by "<params>" when the struct has a const-parameter attribute.
156 return buildStringViaCallback([this](llvm::raw_ostream &ss) {
157 FailureOr<SymbolRefAttr> pathToExpected = getPathFromRoot(*this);
158 if (succeeded(pathToExpected)) {
159 ss << pathToExpected.value();
160 } else {
161 // When there is a failure trying to get the resolved name of the struct,
162 // just print its symbol name directly.
163 ss << '@' << this->getSymName();
164 }
165 if (auto attr = this->getConstParamsAttr()) {
166 ss << '<' << attr << '>';
167 }
168 });
169}
170
171bool StructDefOp::hasParamNamed(StringAttr find) {
172 if (ArrayAttr params = this->getConstParamsAttr()) {
173 for (Attribute attr : params) {
174 assert(llvm::isa<FlatSymbolRefAttr>(attr)); // per ODS
175 if (llvm::cast<FlatSymbolRefAttr>(attr).getRootReference() == find) {
176 return true;
177 }
178 }
179 }
180 return false;
181}
182
// Body of StructDefOp::getFullyQualifiedName() (its signature line is not present in this
// listing): returns this struct's symbol path from the root module. It asserts that the path
// resolves.
184 auto res = getPathFromRoot(*this);
185 assert(succeeded(res));
186 return res.value();
187}
188
189LogicalResult StructDefOp::verifySymbolUses(SymbolTableCollection &tables) {
190 if (ArrayAttr params = this->getConstParamsAttr()) {
191 // Ensure struct parameter names are unique
192 llvm::StringSet<> uniqNames;
193 for (Attribute attr : params) {
194 assert(llvm::isa<FlatSymbolRefAttr>(attr)); // per ODS
195 StringRef name = llvm::cast<FlatSymbolRefAttr>(attr).getValue();
196 if (!uniqNames.insert(name).second) {
197 return this->emitOpError().append("has more than one parameter named \"@", name, '"');
198 }
199 }
200 // Ensure they do not conflict with existing symbols
201 for (Attribute attr : params) {
202 auto res = lookupTopLevelSymbol(tables, llvm::cast<FlatSymbolRefAttr>(attr), *this, false);
203 if (succeeded(res)) {
204 return this->emitOpError()
205 .append("parameter name \"@")
206 .append(llvm::cast<FlatSymbolRefAttr>(attr).getValue())
207 .append("\" conflicts with an existing symbol")
208 .attachNote(res->get()->getLoc())
209 .append("symbol already defined here");
210 }
211 }
212 }
213 return success();
214}
215
216namespace {
217
218inline LogicalResult checkMainFuncParamType(Type pType, FuncDefOp inFunc, bool appendSelf) {
219 if (isSignalType(pType)) {
220 return success();
221 } else if (auto arrayParamTy = llvm::dyn_cast<ArrayType>(pType)) {
222 if (isSignalType(arrayParamTy.getElementType())) {
223 return success();
224 }
225 }
226
227 std::string message = buildStringViaCallback([&inFunc, appendSelf](llvm::raw_ostream &ss) {
228 ss << "\"@" << COMPONENT_NAME_MAIN << "\" component \"@" << inFunc.getSymName()
229 << "\" function parameters must be one of: {";
230 if (appendSelf) {
231 ss << "!" << StructType::name << "<@" << COMPONENT_NAME_MAIN << ">, ";
232 }
233 ss << "!" << StructType::name << "<@" << COMPONENT_NAME_SIGNAL << ">, ";
234 ss << "!" << ArrayType::name << "<.. x !" << StructType::name << "<@" << COMPONENT_NAME_SIGNAL
235 << ">>}";
236 });
237 return inFunc.emitError(message);
238}
239
240inline LogicalResult verifyStructComputeConstrain(
241 StructDefOp structDef, FuncDefOp computeFunc, FuncDefOp constrainFunc
242) {
243 // ASSERT: The `SetFuncAllowAttrs` trait on StructDefOp set the attributes correctly.
244 assert(constrainFunc.hasAllowConstraintAttr());
245 assert(!computeFunc.hasAllowConstraintAttr());
246 assert(!constrainFunc.hasAllowWitnessAttr());
247 assert(computeFunc.hasAllowWitnessAttr());
248
249 // Verify parameter types are valid. Skip the first parameter of the "constrain" function; it is
250 // already checked via verifyFuncTypeConstrain() in Function/IR/Ops.cpp.
251 ArrayRef<Type> computeParams = computeFunc.getFunctionType().getInputs();
252 ArrayRef<Type> constrainParams = constrainFunc.getFunctionType().getInputs().drop_front();
253 if (structDef.isMainComponent()) {
254 // Verify that the Struct has no parameters.
255 if (!isNullOrEmpty(structDef.getConstParamsAttr())) {
256 return structDef.emitError().append(
257 "The \"@", COMPONENT_NAME_MAIN, "\" component must have no parameters"
258 );
259 }
260 // Verify the input parameter types are legal. The error message is explicit about what types
261 // are allowed so there is no benefit to report multiple errors if more than one parameter in
262 // the referenced function has an illegal type.
263 for (Type t : computeParams) {
264 if (failed(checkMainFuncParamType(t, computeFunc, false))) {
265 return failure(); // checkMainFuncParamType() already emits a sufficient error message
266 }
267 }
268 for (Type t : constrainParams) {
269 if (failed(checkMainFuncParamType(t, constrainFunc, true))) {
270 return failure(); // checkMainFuncParamType() already emits a sufficient error message
271 }
272 }
273 }
274
275 if (!typeListsUnify(computeParams, constrainParams)) {
276 return constrainFunc.emitError()
277 .append(
278 "expected \"@", FUNC_NAME_CONSTRAIN,
279 "\" function argument types (sans the first one) to match \"@", FUNC_NAME_COMPUTE,
280 "\" function argument types"
281 )
282 .attachNote(computeFunc.getLoc())
283 .append("\"@", FUNC_NAME_COMPUTE, "\" function defined here");
284 }
285
286 return success();
287}
288
289inline LogicalResult verifyStructProduct(StructDefOp structDef, FuncDefOp productFunc) {
290 // ASSERT: The `SetFuncAllowAttrs` trait on StructDefOp set the attributes correctly
291 assert(productFunc.hasAllowConstraintAttr());
292 assert(productFunc.hasAllowWitnessAttr());
293
294 // Verify parameter types are valid
295 ArrayRef<Type> productParams = productFunc.getFunctionType().getInputs();
296 if (structDef.isMainComponent()) {
297 if (!isNullOrEmpty(structDef.getConstParamsAttr())) {
298 return structDef.emitError().append(
299 "The \"@", COMPONENT_NAME_MAIN, "\" component must have no parameters"
300 );
301 }
302 for (Type t : productParams) {
303 if (failed(checkMainFuncParamType(t, productFunc, false))) {
304 return failure();
305 }
306 }
307 }
308
309 return success();
310}
311
312} // namespace
313
// Body of StructDefOp::verifyRegions() (its signature line is not present in this listing).
// It scans the struct body and records any "@compute", "@constrain", and "@product"
// functions. It rejects any other op or function, duplicate functions, and mixing "@product"
// with "@compute"/"@constrain". It then dispatches to the matching helper verifier.
315 std::optional<FuncDefOp> foundCompute = std::nullopt;
316 std::optional<FuncDefOp> foundConstrain = std::nullopt;
317 std::optional<FuncDefOp> foundProduct = std::nullopt;
318 {
319 // Verify the following:
320 // 1. The only ops within the body are field and function definitions
321 // 2. The only functions defined in the struct are `@compute()` and `@constrain()`, or
322 // `@product()`
// NOTE(review): `emitError` is not referenced anywhere below in this body; it looks like a
// leftover local that could be removed.
323 OwningEmitErrorFn emitError = getEmitOpErrFn(this);
324 for (Operation &op : *getBody()) {
325 if (!llvm::isa<FieldDefOp>(op)) {
326 if (FuncDefOp funcDef = llvm::dyn_cast<FuncDefOp>(op)) {
327 if (funcDef.nameIsCompute()) {
328 if (foundProduct) {
329 return structFuncDefError(funcDef.getOperation())
330 << "found both \"@" << FUNC_NAME_COMPUTE << "\" and \"@" << FUNC_NAME_PRODUCT
331 << "\" functions";
332 }
333 if (foundCompute) {
334 return structFuncDefError(funcDef.getOperation())
335 << "found multiple \"@" << FUNC_NAME_COMPUTE << "\" functions";
336 }
337 foundCompute = std::make_optional(funcDef);
338 } else if (funcDef.nameIsConstrain()) {
339 if (foundProduct) {
340 return structFuncDefError(funcDef.getOperation())
341 << "found both \"@" << FUNC_NAME_CONSTRAIN << "\" and \"@" << FUNC_NAME_PRODUCT
342 << "\" functions";
343 }
344 if (foundConstrain) {
345 return structFuncDefError(funcDef.getOperation())
346 << "found multiple \"@" << FUNC_NAME_CONSTRAIN << "\" functions";
347 }
348 foundConstrain = std::make_optional(funcDef);
349 } else if (funcDef.nameIsProduct()) {
350 if (foundCompute) {
351 return structFuncDefError(funcDef.getOperation())
352 << "found both \"@" << FUNC_NAME_COMPUTE << "\" and \"@" << FUNC_NAME_PRODUCT
353 << "\" functions";
354 }
355 if (foundConstrain) {
356 return structFuncDefError(funcDef.getOperation())
357 << "found both \"@" << FUNC_NAME_CONSTRAIN << "\" and \"@" << FUNC_NAME_PRODUCT
358 << "\" functions";
359 }
360 if (foundProduct) {
361 return structFuncDefError(funcDef.getOperation())
362 << "found multiple \"@" << FUNC_NAME_PRODUCT << "\" functions";
363 }
364 foundProduct = std::make_optional(funcDef);
365 } else {
366 // Must do a little more than a simple call to '?.emitOpError()' to
367 // tag the error with correct location and correct op name.
368 return structFuncDefError(funcDef.getOperation())
369 << "found \"@" << funcDef.getSymName() << '"';
370 }
371 } else {
372 return op.emitOpError() << "invalid operation in '" << StructDefOp::getOperationName()
373 << "'; only '" << FieldDefOp::getOperationName() << '\''
374 << " and '" << FuncDefOp::getOperationName()
375 << "' operations are permitted";
376 }
377 }
378 }
379
// "@compute" and "@constrain" must appear together.
380 if (!foundCompute.has_value() && foundConstrain.has_value()) {
381 return structFuncDefError(getOperation()) << "found \"@" << FUNC_NAME_CONSTRAIN
382 << "\", missing \"@" << FUNC_NAME_COMPUTE << "\"";
383 }
384 if (!foundConstrain.has_value() && foundCompute.has_value()) {
385 return structFuncDefError(getOperation()) << "found \"@" << FUNC_NAME_COMPUTE
386 << "\", missing \"@" << FUNC_NAME_CONSTRAIN << "\"";
387 }
388 }
389
390 if (!foundCompute.has_value() && !foundConstrain.has_value() && !foundProduct.has_value()) {
391 return structFuncDefError(getOperation())
392 << "could not find \"@" << FUNC_NAME_PRODUCT << "\", \"@" << FUNC_NAME_COMPUTE
393 << "\", or \"@" << FUNC_NAME_CONSTRAIN << "\"";
394 }
395
// After the checks above, either both compute and constrain are present, or only product is.
// Every other combination has already returned, so dereferencing foundProduct below is safe.
396 if (foundCompute && foundConstrain) {
397 return verifyStructComputeConstrain(*this, *foundCompute, *foundConstrain);
398 }
399 return verifyStructProduct(*this, *foundProduct);
400}
401
// Body of StructDefOp::getFieldDef(StringAttr fieldName) (its signature line is not present
// in this listing). It scans the struct body linearly and returns the first FieldDefOp whose
// symbol name equals `fieldName`, or a null FieldDefOp if none matches.
403 for (Operation &op : *getBody()) {
404 if (FieldDefOp fieldDef = llvm::dyn_cast_if_present<FieldDefOp>(op)) {
405 if (fieldName.compare(fieldDef.getSymNameAttr()) == 0) {
406 return fieldDef;
407 }
408 }
409 }
410 return nullptr;
411}
412
413std::vector<FieldDefOp> StructDefOp::getFieldDefs() {
414 std::vector<FieldDefOp> res;
415 for (Operation &op : *getBody()) {
416 if (FieldDefOp fieldDef = llvm::dyn_cast_if_present<FieldDefOp>(op)) {
417 res.push_back(fieldDef);
418 }
419 }
420 return res;
421}
422
// Body of StructDefOp::getComputeFuncOp() (signature line not present in this listing):
// looks up the FUNC_NAME_COMPUTE symbol; null if absent or not a FuncDefOp.
424 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_COMPUTE));
425}
426
// Body of StructDefOp::getConstrainFuncOp() (signature line not present in this listing):
// looks up the FUNC_NAME_CONSTRAIN symbol; null if absent or not a FuncDefOp.
428 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_CONSTRAIN));
429}
430
// Body of StructDefOp::getComputeOrProductFuncOp() (signature line not present in this
// listing). It prefers FUNC_NAME_COMPUTE and falls back to FUNC_NAME_PRODUCT only when no
// compute symbol exists. If a compute symbol exists but is not a FuncDefOp, null is returned
// without consulting product.
432 if (auto *computeFunc = lookupSymbol(FUNC_NAME_COMPUTE)) {
433 return llvm::dyn_cast<FuncDefOp>(computeFunc);
434 }
435 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_PRODUCT));
436}
437
// Body of StructDefOp::getConstrainOrProductFuncOp() (signature line not present in this
// listing). It prefers FUNC_NAME_CONSTRAIN and falls back to FUNC_NAME_PRODUCT only when no
// constrain symbol exists, mirroring the compute variant above.
439 if (auto *constrainFunc = lookupSymbol(FUNC_NAME_CONSTRAIN)) {
440 return llvm::dyn_cast<FuncDefOp>(constrainFunc);
441 }
442 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_PRODUCT));
443}
444
446
447//===------------------------------------------------------------------===//
448// FieldDefOp
449//===------------------------------------------------------------------===//
450
452 OpBuilder &odsBuilder, OperationState &odsState, StringAttr sym_name, TypeAttr type,
453 bool isColumn
454) {
// FieldDefOp::build (attribute form; the leading signature line is not present in this
// listing). It stores sym_name and type as properties, and sets the `column` property to a
// UnitAttr only when isColumn is true.
455 Properties &props = odsState.getOrAddProperties<Properties>();
456 props.setSymName(sym_name);
457 props.setType(type);
458 if (isColumn) {
459 props.column = odsBuilder.getUnitAttr();
460 }
461}
462
464 OpBuilder &odsBuilder, OperationState &odsState, StringRef sym_name, Type type, bool isColumn
465) {
// Convenience overload: wraps the name and type in attributes and forwards to the
// attribute-form builder.
466 build(odsBuilder, odsState, odsBuilder.getStringAttr(sym_name), TypeAttr::get(type), isColumn);
467}
468
470 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, ValueRange operands,
471 ArrayRef<NamedAttribute> attributes, bool isColumn
472) {
// Generic-form builder. FieldDefOp has no operands and no results, which the asserts
// enforce. `column` is set only when isColumn is true.
473 assert(operands.size() == 0u && "mismatched number of parameters");
474 odsState.addOperands(operands);
475 odsState.addAttributes(attributes);
476 assert(resultTypes.size() == 0u && "mismatched number of return types");
477 odsState.addTypes(resultTypes);
478 if (isColumn) {
479 odsState.getOrAddProperties<Properties>().column = odsBuilder.getUnitAttr();
480 }
481}
482
483void FieldDefOp::setPublicAttr(bool newValue) {
484 if (newValue) {
485 getOperation()->setAttr(PublicAttr::name, UnitAttr::get(getContext()));
486 } else {
487 getOperation()->removeAttr(PublicAttr::name);
488 }
489}
490
491static LogicalResult
492verifyFieldDefTypeImpl(Type fieldType, SymbolTableCollection &tables, Operation *origin) {
493 if (StructType fieldStructType = llvm::dyn_cast<StructType>(fieldType)) {
494 // Special case for StructType verifies that the field type can resolve and that it is NOT the
495 // parent struct (i.e., struct fields cannot create circular references).
496 auto fieldTypeRes = verifyStructTypeResolution(tables, fieldStructType, origin);
497 if (failed(fieldTypeRes)) {
498 return failure(); // above already emits a sufficient error message
499 }
500 FailureOr<StructDefOp> parentRes = getParentOfType<StructDefOp>(origin);
501 assert(succeeded(parentRes) && "FieldDefOp parent is always StructDefOp"); // per ODS def
502 if (fieldTypeRes.value() == parentRes.value()) {
503 return origin->emitOpError()
504 .append("type is circular")
505 .attachNote(parentRes.value().getLoc())
506 .append("references parent component defined here");
507 }
508 return success();
509 } else {
510 return verifyTypeResolution(tables, origin, fieldType);
511 }
512}
513
514LogicalResult FieldDefOp::verifySymbolUses(SymbolTableCollection &tables) {
515 Type fieldType = this->getType();
516 if (failed(verifyFieldDefTypeImpl(fieldType, tables, *this))) {
517 return failure();
518 }
519
520 if (!getColumn()) {
521 return success();
522 }
523 // If the field is marked as a column only a small subset of types are allowed.
524 if (!isValidColumnType(getType(), tables, *this)) {
525 return emitOpError() << "marked as column can only contain felts, arrays of column types, or "
526 "structs with columns, but field has type "
527 << getType();
528 }
529 return success();
530}
531
532//===------------------------------------------------------------------===//
533// FieldRefOp implementations
534//===------------------------------------------------------------------===//
535namespace {
536
// Resolves the FieldDefOp that `refOp` names inside the struct type `tyStruct`. It first
// resolves `tyStruct` to its definition, then looks the field name up inside that definition.
// An error is emitted on `refOp` if the field cannot be found.
537FailureOr<SymbolLookupResult<FieldDefOp>>
538getFieldDefOpImpl(FieldRefOpInterface refOp, SymbolTableCollection &tables, StructType tyStruct) {
539 Operation *op = refOp.getOperation();
540 auto structDefRes = tyStruct.getDefinition(tables, op);
541 if (failed(structDefRes)) {
542 return failure(); // getDefinition() already emits a sufficient error message
543 }
// NOTE(review): the line that declares `res` (source line 544) is missing from this listing.
// It presumably calls lookupSymbolIn<FieldDefOp>( with the arguments below; confirm against
// the real source.
545 tables, SymbolRefAttr::get(refOp->getContext(), refOp.getFieldName()),
546 std::move(*structDefRes), op
547 );
548 if (failed(res)) {
549 return refOp->emitError() << "could not find '" << FieldDefOp::getOperationName()
550 << "' named \"@" << refOp.getFieldName() << "\" in \""
551 << tyStruct.getNameRef() << '"';
552 }
553 return std::move(res.value());
554}
555
556static FailureOr<SymbolLookupResult<FieldDefOp>>
557findField(FieldRefOpInterface refOp, SymbolTableCollection &tables) {
558 // Ensure the base component/struct type reference can be resolved.
559 StructType tyStruct = refOp.getStructType();
560 if (failed(tyStruct.verifySymbolRef(tables, refOp.getOperation()))) {
561 return failure();
562 }
563 // Ensure the field name can be resolved in that struct.
564 return getFieldDefOpImpl(refOp, tables, tyStruct);
565}
566
567static LogicalResult verifySymbolUsesImpl(
568 FieldRefOpInterface refOp, SymbolTableCollection &tables, SymbolLookupResult<FieldDefOp> &field
569) {
570 // Ensure the type of the referenced field declaration matches the type used in this op.
571 Type actualType = refOp.getVal().getType();
572 Type fieldType = field.get().getType();
573 if (!typesUnify(actualType, fieldType, field.getIncludeSymNames())) {
574 return refOp->emitOpError() << "has wrong type; expected " << fieldType << ", got "
575 << actualType;
576 }
577 // Ensure any SymbolRef used in the type are valid
578 return verifyTypeResolution(tables, refOp.getOperation(), actualType);
579}
580
581LogicalResult verifySymbolUsesImpl(FieldRefOpInterface refOp, SymbolTableCollection &tables) {
582 // Ensure the field name can be resolved in that struct.
583 auto field = findField(refOp, tables);
584 if (failed(field)) {
585 return field; // getFieldDefOp() already emits a sufficient error message
586 }
587 return verifySymbolUsesImpl(refOp, tables, *field);
588}
589
590} // namespace
591
592FailureOr<SymbolLookupResult<FieldDefOp>>
593FieldRefOpInterface::getFieldDefOp(SymbolTableCollection &tables) {
594 return getFieldDefOpImpl(*this, tables, getStructType());
595}
596
597LogicalResult FieldReadOp::verifySymbolUses(SymbolTableCollection &tables) {
598 auto field = findField(*this, tables);
599 if (failed(field)) {
600 return failure();
601 }
602 if (failed(verifySymbolUsesImpl(*this, tables, *field))) {
603 return failure();
604 }
605 // If the field is not a column and an offset was specified then fail to validate
606 if (!field->get().getColumn() && getTableOffset().has_value()) {
607 return emitOpError("cannot read with table offset from a field that is not a column")
608 .attachNote(field->get().getLoc())
609 .append("field defined here");
610 }
611
612 return success();
613}
614
615LogicalResult FieldWriteOp::verifySymbolUses(SymbolTableCollection &tables) {
616 // Ensure the write op only targets fields in the current struct.
617 FailureOr<StructDefOp> getParentRes = verifyInStruct(*this);
618 if (failed(getParentRes)) {
619 return failure(); // verifyInStruct() already emits a sufficient error message
620 }
621 if (failed(checkSelfType(tables, *getParentRes, getComponent().getType(), *this, "base value"))) {
622 return failure(); // checkSelfType() already emits a sufficient error message
623 }
624 // Perform the standard field ref checks.
625 return verifySymbolUsesImpl(*this, tables);
626}
627
628//===------------------------------------------------------------------===//
629// FieldReadOp
630//===------------------------------------------------------------------===//
631
633 OpBuilder &builder, OperationState &state, Type resultType, Value component, StringAttr field
634) {
// FieldReadOp::build without a table offset (the leading signature line and source line 639
// are not present in this listing). It records the field name, the result type, and the
// component operand.
635 Properties &props = state.getOrAddProperties<Properties>();
636 props.setFieldName(FlatSymbolRefAttr::get(field));
637 state.addTypes(resultType);
638 state.addOperands(component);
640}
641
643 OpBuilder &builder, OperationState &state, Type resultType, Value component, StringAttr field,
644 Attribute dist, ValueRange mapOperands, std::optional<int32_t> numDims
645) {
// FieldReadOp::build with a table offset `dist`, which may be an affine map with operands
// `mapOperands`.
// NOTE(review): source lines 651 and 655 are missing from this listing. They presumably call
// buildInstantiationAttrsNoSegments / buildInstantiationAttrsEmptyNoSegments respectively;
// confirm against the real source.
646 // '!mapOperands.empty()' implies 'numDims.has_value()'
647 assert(mapOperands.empty() || numDims.has_value());
648 state.addOperands(component);
649 state.addTypes(resultType);
650 if (numDims.has_value()) {
652 builder, state, ArrayRef({mapOperands}), builder.getDenseI32ArrayAttr({*numDims})
653 );
654 } else {
656 }
657 Properties &props = state.getOrAddProperties<Properties>();
658 props.setFieldName(FlatSymbolRefAttr::get(field));
659 props.setTableOffset(dist);
660}
661
663 OpBuilder & /*odsBuilder*/, OperationState &odsState, TypeRange resultTypes,
664 ValueRange operands, ArrayRef<NamedAttribute> attrs
665) {
// Generic-form builder: forwards result types, operands, and attributes unchanged.
666 odsState.addTypes(resultTypes);
667 odsState.addOperands(operands);
668 odsState.addAttributes(attrs);
669}
670
// Verifies the affine-map instantiation operands of a field read. Only a table offset that is
// an AffineMapAttr contributes a map. Any other (or absent) offset yields an empty map list.
671LogicalResult FieldReadOp::verify() {
672 SmallVector<AffineMapAttr, 1> mapAttrs;
673 if (AffineMapAttr map =
674 llvm::dyn_cast_if_present<AffineMapAttr>(getTableOffset().value_or(nullptr))) {
675 mapAttrs.push_back(map);
676 }
// NOTE(review): source line 677 is missing from this listing. It presumably reads
// `return verifyAffineMapInstantiations(`; confirm against the real source.
678 getMapOperands(), getNumDimsPerMap(), mapAttrs, *this
679 );
680}
681
682//===------------------------------------------------------------------===//
683// CreateStructOp
684//===------------------------------------------------------------------===//
685
686void CreateStructOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
687 setNameFn(getResult(), "self");
688}
689
690LogicalResult CreateStructOp::verifySymbolUses(SymbolTableCollection &tables) {
691 FailureOr<StructDefOp> getParentRes = verifyInStruct(*this);
692 if (failed(getParentRes)) {
693 return failure(); // verifyInStruct() already emits a sufficient error message
694 }
695 if (failed(checkSelfType(tables, *getParentRes, this->getType(), *this, "result"))) {
696 return failure();
697 }
698 return success();
699}
700
701} // namespace llzk::component
MlirStringRef name
Definition Poly.cpp:48
std::vector< llvm::StringRef > getIncludeSymNames() const
Return the stack of symbol names from the IncludeOp that were traversed to load this result.
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:51
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:690
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
Definition Ops.cpp:686
::mlir::TypedValue<::llzk::component::StructType > getResult()
Definition Ops.h.inc:143
FoldAdaptor::Properties Properties
Definition Ops.h.inc:302
void setPublicAttr(bool newValue=true)
Definition Ops.cpp:483
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::StringAttr sym_name, ::mlir::TypeAttr type, bool isColumn=false)
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:332
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:514
FoldAdaptor::Properties Properties
Definition Ops.h.inc:601
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:654
::llvm::LogicalResult verify()
Definition Ops.cpp:671
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType, ::mlir::Value component, ::mlir::StringAttr field)
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
Definition Ops.cpp.inc:910
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:597
::std::optional<::mlir::Attribute > getTableOffset()
Definition Ops.cpp.inc:905
::mlir::FailureOr< SymbolLookupResult< FieldDefOp > > getFieldDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the field referenced in this op.
Definition Ops.cpp:593
::llzk::component::StructType getStructType()
Gets the struct type of the target component.
::llvm::StringRef getFieldName()
Gets the field name attribute value from the FieldRefOp.
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the FieldRefOp.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
Definition Ops.h.inc:914
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:615
static mlir::LogicalResult verifyTrait(mlir::Operation *op)
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
FieldDefOp getFieldDef(::mlir::StringAttr fieldName)
Gets the FieldDefOp that defines the field in this structure with the given name, if present.
Definition Ops.cpp:402
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:189
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:1152
::llzk::function::FuncDefOp getConstrainOrProductFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:438
::std::vector< FieldDefOp > getFieldDefs()
Get all FieldDefOp in this structure.
Definition Ops.cpp:413
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:1590
::mlir::SymbolRefAttr getFullyQualifiedName()
Return the full name for this struct from the root module, including any surrounding module scopes.
Definition Ops.cpp:183
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:427
::llzk::function::FuncDefOp getComputeOrProductFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
Definition Ops.cpp:431
bool hasParamNamed(::mlir::StringAttr find)
Return true iff this StructDefOp has a parameter with the given name.
::llvm::LogicalResult verifyRegions()
Definition Ops.cpp:314
::mlir::ArrayAttr getConstParamsAttr()
Definition Ops.h.inc:1194
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
Definition Ops.cpp:423
bool isMainComponent()
Return true iff this StructDefOp is named "Main".
Definition Ops.cpp:445
::std::string getHeaderString()
Generate header string, in the same format as the assemblyFormat.
Definition Ops.cpp:155
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:79
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
Definition Types.cpp:47
::mlir::LogicalResult verifySymbolRef(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op)
Definition Types.cpp:72
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:38
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
Definition Ops.cpp:195
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:952
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:758
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
Definition Ops.h.inc:717
bool nameIsProduct()
Return true iff the function name is FUNC_NAME_PRODUCT (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:766
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:947
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
Definition Ops.h.inc:762
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:571
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
Definition Ops.cpp:187
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
Definition Ops.h.inc:709
OpClass::Properties & buildInstantiationAttrsEmptyNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
void buildInstantiationAttrsNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
bool isInStruct(Operation *op)
Definition Ops.cpp:41
InFlightDiagnostic genCompareErr(StructDefOp expected, Operation *origin, const char *aspect)
Definition Ops.cpp:83
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
Definition Ops.cpp:105
FailureOr< StructDefOp > verifyInStruct(Operation *op)
Definition Ops.cpp:43
bool isInStructFunctionNamed(Operation *op, char const *funcName)
Definition Ops.cpp:52
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:27
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
Definition TypeHelper.h:225
constexpr char COMPONENT_NAME_MAIN[]
Symbol name for the main entry point struct/component (if any).
Definition Constants.h:23
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:28
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
bool isNullOrEmpty(mlir::ArrayAttr a)
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:29
FailureOr< StructDefOp > verifyStructTypeResolution(SymbolTableCollection &tables, StructType ty, Operation *origin)
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
LogicalResult verifyParamsOfType(SymbolTableCollection &tables, ArrayRef< Attribute > tyParams, Type parameterizedType, Operation *origin)
bool isSignalType(Type type)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:45
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
constexpr char COMPONENT_NAME_SIGNAL[]
Symbol name for the struct/component representing a signal.
Definition Constants.h:16
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)
void setSymName(const ::mlir::StringAttr &propValue)
Definition Ops.h.inc:190
void setFieldName(const ::mlir::FlatSymbolRefAttr &propValue)
Definition Ops.h.inc:449