LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Func and call op implementations --------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8// Adapted from the LLVM Project's lib/Dialect/Func/IR/FuncOps.cpp
9// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
10// See https://llvm.org/LICENSE.txt for license information.
11// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
12//
13//===----------------------------------------------------------------------===//
14
24
25#include <mlir/IR/IRMapping.h>
26#include <mlir/IR/OpImplementation.h>
27#include <mlir/Interfaces/FunctionImplementation.h>
28
29#include <llvm/ADT/MapVector.h>
30
31// TableGen'd implementation files
32#define GET_OP_CLASSES
34
35using namespace mlir;
36using namespace llzk::component;
37using namespace llzk::polymorphic;
38
39namespace llzk::function {
40
41FunctionKind fnNameToKind(mlir::StringRef name) {
42 if (FUNC_NAME_COMPUTE == name) {
44 } else if (FUNC_NAME_CONSTRAIN == name) {
46 } else if (FUNC_NAME_PRODUCT == name) {
48 } else {
49 return FunctionKind::Free;
50 }
51}
52
53namespace {
55inline LogicalResult
56verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, FunctionType funcType) {
58 tables, origin, ArrayRef<ArrayRef<Type>> {funcType.getInputs(), funcType.getResults()}
59 );
60}
61} // namespace
62
63//===----------------------------------------------------------------------===//
64// FuncDefOp
65//===----------------------------------------------------------------------===//
66
68 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs
69) {
70 return delegate_to_build<FuncDefOp>(location, name, type, attrs);
71}
72
74 Location location, StringRef name, FunctionType type, Operation::dialect_attr_range attrs
75) {
76 SmallVector<NamedAttribute, 8> attrRef(attrs);
77 return create(location, name, type, llvm::ArrayRef(attrRef));
78}
79
81 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs,
82 ArrayRef<DictionaryAttr> argAttrs
83) {
84 FuncDefOp func = create(location, name, type, attrs);
85 func.setAllArgAttrs(argAttrs);
86 return func;
87}
88
90 OpBuilder &builder, OperationState &state, StringRef name, FunctionType type,
91 ArrayRef<NamedAttribute> attrs, ArrayRef<DictionaryAttr> argAttrs
92) {
93 state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name));
94 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
95 state.attributes.append(attrs.begin(), attrs.end());
96 state.addRegion();
97
98 if (argAttrs.empty()) {
99 return;
100 }
101 assert(type.getNumInputs() == argAttrs.size());
102 function_interface_impl::addArgAndResultAttrs(
103 builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name),
104 getResAttrsAttrName(state.name)
105 );
106}
107
108ParseResult FuncDefOp::parse(OpAsmParser &parser, OperationState &result) {
109 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
110 function_interface_impl::VariadicFlag,
111 std::string &) { return builder.getFunctionType(argTypes, results); };
112
113 return function_interface_impl::parseFunctionOp(
114 parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType,
115 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)
116 );
117}
118
119void FuncDefOp::print(OpAsmPrinter &p) {
120 function_interface_impl::printFunctionOp(
121 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(),
123 );
124}
125
128void FuncDefOp::cloneInto(FuncDefOp dest, IRMapping &mapper) {
129 // Add the attributes of this function to dest.
130 llvm::MapVector<StringAttr, Attribute> newAttrMap;
131 for (const auto &attr : dest->getAttrs()) {
132 newAttrMap.insert({attr.getName(), attr.getValue()});
133 }
134 for (const auto &attr : (*this)->getAttrs()) {
135 newAttrMap.insert({attr.getName(), attr.getValue()});
136 }
137
138 auto newAttrs =
139 llvm::to_vector(llvm::map_range(newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
140 return NamedAttribute(attrPair.first, attrPair.second);
141 }));
142 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
143
144 // Clone the body.
145 getBody().cloneInto(&dest.getBody(), mapper);
146}
147
153FuncDefOp FuncDefOp::clone(IRMapping &mapper) {
154 // Create the new function.
155 FuncDefOp newFunc = llvm::cast<FuncDefOp>(getOperation()->cloneWithoutRegions());
156
157 // If the function has a body, then the user might be deleting arguments to
158 // the function by specifying them in the mapper. If so, we don't add the
159 // argument to the input type vector.
160 if (!isExternal()) {
161 FunctionType oldType = getFunctionType();
162
163 unsigned oldNumArgs = oldType.getNumInputs();
164 SmallVector<Type, 4> newInputs;
165 newInputs.reserve(oldNumArgs);
166 for (unsigned i = 0; i != oldNumArgs; ++i) {
167 if (!mapper.contains(getArgument(i))) {
168 newInputs.push_back(oldType.getInput(i));
169 }
170 }
171
174 if (newInputs.size() != oldNumArgs) {
175 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, oldType.getResults()));
176
177 if (ArrayAttr argAttrs = getAllArgAttrs()) {
178 SmallVector<Attribute> newArgAttrs;
179 newArgAttrs.reserve(newInputs.size());
180 for (unsigned i = 0; i != oldNumArgs; ++i) {
181 if (!mapper.contains(getArgument(i))) {
182 newArgAttrs.push_back(argAttrs[i]);
183 }
184 }
185 newFunc.setAllArgAttrs(newArgAttrs);
186 }
187 }
188 }
189
191 cloneInto(newFunc, mapper);
192 return newFunc;
193}
194
196 IRMapping mapper;
197 return clone(mapper);
198}
199
201 if (newValue) {
202 getOperation()->setAttr(AllowConstraintAttr::name, UnitAttr::get(getContext()));
203 } else {
204 getOperation()->removeAttr(AllowConstraintAttr::name);
205 }
206}
207
209 if (newValue) {
210 getOperation()->setAttr(AllowWitnessAttr::name, UnitAttr::get(getContext()));
211 } else {
212 getOperation()->removeAttr(AllowWitnessAttr::name);
213 }
214}
215
216bool FuncDefOp::hasArgPublicAttr(unsigned index) {
217 if (index < this->getNumArguments()) {
218 DictionaryAttr res = function_interface_impl::getArgAttrDict(*this, index);
219 return res ? res.contains(PublicAttr::name) : false;
220 } else {
221 // TODO: print error? requested attribute for non-existant argument index
222 return false;
223 }
224}
225
226LogicalResult FuncDefOp::verify() {
227 OwningEmitErrorFn emitErrorFunc = getEmitOpErrFn(this);
228 // Ensure that only valid LLZK types are used for arguments and return. Additionally, the struct
229 // functions may not use AffineMapAttrs in their parameter types. If such a scenario seems to make
230 // sense when generating LLZK IR, it's likely better to introduce a struct parameter to use
231 // instead and instantiate the struct with that AffineMapAttr.
232 FunctionType type = getFunctionType();
233 for (Type t : type.getInputs()) {
234 if (llzk::checkValidType(emitErrorFunc, t).failed()) {
235 return failure();
236 }
237 if (isInStruct() && hasAffineMapAttr(t)) {
238 return emitErrorFunc().append(
239 "\"@", getName(), "\" parameters cannot contain affine map attributes but found ", t
240 );
241 }
242 }
243 for (Type t : type.getResults()) {
244 if (llzk::checkValidType(emitErrorFunc, t).failed()) {
245 return failure();
246 }
247 }
248 // Ensure that the function does not contain nested modules.
249 // Functions also cannot contain nested structs, but this check is handled
250 // via struct.def's requirement of having module as a parent.
251 WalkResult res = this->walk<WalkOrder::PreOrder>([this](ModuleOp nestedMod) {
252 getEmitOpErrFn(nestedMod)().append(
253 "cannot be nested within '", getOperation()->getName(), "' operations"
254 );
255 return WalkResult::interrupt();
256 });
257 if (res.wasInterrupted()) {
258 return failure();
259 }
260
261 return success();
262}
263
264namespace {
265
266LogicalResult
267verifyFuncTypeCompute(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
268 FunctionType funcType = origin.getFunctionType();
269 llvm::ArrayRef<Type> resTypes = funcType.getResults();
270 // Must return type of parent struct
271 if (resTypes.size() != 1) {
272 return origin.emitOpError().append(
273 "\"@", FUNC_NAME_COMPUTE, "\" must have exactly one return type"
274 );
275 }
276 if (failed(checkSelfType(tables, parent, resTypes.front(), origin, "return"))) {
277 return failure();
278 }
279
280 // After the more specific checks (to ensure more specific error messages would be produced if
281 // necessary), do the general check that all symbol references in the types are valid. The return
282 // types were already checked so just check the input types.
283 return llzk::verifyTypeResolution(tables, origin, funcType.getInputs());
284}
285
286LogicalResult
287verifyFuncTypeProduct(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
288 // The signature for @product is the same as the signature for @compute
289 return verifyFuncTypeCompute(origin, tables, parent);
290}
291
292LogicalResult
293verifyFuncTypeConstrain(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
294 FunctionType funcType = origin.getFunctionType();
295 // Must return '()' type, i.e., have no return types
296 if (funcType.getResults().size() != 0) {
297 return origin.emitOpError() << "\"@" << FUNC_NAME_CONSTRAIN << "\" must have no return type";
298 }
299
300 // Type of the first parameter must match the parent StructDefOp of the current operation.
301 llvm::ArrayRef<Type> inputTypes = funcType.getInputs();
302 if (inputTypes.size() < 1) {
303 return origin.emitOpError() << "\"@" << FUNC_NAME_CONSTRAIN
304 << "\" must have at least one input type";
305 }
306 if (failed(checkSelfType(tables, parent, inputTypes.front(), origin, "first input"))) {
307 return failure();
308 }
309
310 // After the more specific checks (to ensure more specific error messages would be produced if
311 // necessary), do the general check that all symbol references in the types are valid. There are
312 // no return types, just check the remaining input types (the first was already checked via
313 // the checkSelfType() call above).
314 return llzk::verifyTypeResolution(tables, origin, inputTypes.drop_front());
315}
316
317} // namespace
318
319LogicalResult FuncDefOp::verifySymbolUses(SymbolTableCollection &tables) {
320 // Additional checks for the compute/constrain/product functions within a struct
321 FailureOr<StructDefOp> parentStructOpt = getParentOfType<StructDefOp>(*this);
322 if (succeeded(parentStructOpt)) {
323 // Verify return type restrictions for functions within a StructDefOp
324 if (nameIsCompute()) {
325 return verifyFuncTypeCompute(*this, tables, parentStructOpt.value());
326 } else if (nameIsConstrain()) {
327 return verifyFuncTypeConstrain(*this, tables, parentStructOpt.value());
328 } else if (nameIsProduct()) {
329 return verifyFuncTypeProduct(*this, tables, parentStructOpt.value());
330 }
331 }
332 // In the general case, verify symbol resolution in all input and output types.
333 return verifyTypeResolution(tables, *this, getFunctionType());
334}
335
336SymbolRefAttr FuncDefOp::getFullyQualifiedName(bool requireParent) {
337 // If the parent is not present and not required, just return the symbol name
338 if (!requireParent && getOperation()->getParentOp() == nullptr) {
339 return SymbolRefAttr::get(getOperation());
340 }
341 auto res = getPathFromRoot(*this);
342 assert(succeeded(res));
343 return res.value();
344}
345
347 assert(nameIsCompute()); // skip inStruct check to allow dangling functions
348 // Get the single block of the function body
349 Region &body = getBody();
350 assert(!body.empty() && "compute() function body is empty");
351 Block &block = body.back();
352
353 // The terminator should be the return op
354 Operation *terminator = block.getTerminator();
355 assert(terminator && "compute() function has no terminator");
356 auto retOp = llvm::dyn_cast<ReturnOp>(terminator);
357 if (!retOp) {
358 llvm::errs() << "Expected '" << ReturnOp::getOperationName() << "' but found '"
359 << terminator->getName() << "'\n";
360 llvm_unreachable("compute() function must end with ReturnOp");
361 }
362 return retOp.getOperands().front();
363}
364
366 assert(nameIsConstrain()); // skip inStruct check to allow dangling functions
367 return getArguments().front();
368}
369
371 assert(isStructCompute() && "violated implementation pre-condition");
373}
374
375//===----------------------------------------------------------------------===//
376// ReturnOp
377//===----------------------------------------------------------------------===//
378
379LogicalResult ReturnOp::verify() {
380 auto function = getParentOp<FuncDefOp>(); // parent is FuncDefOp per ODS
381
382 // The operand number and types must match the function signature.
383 const auto results = function.getFunctionType().getResults();
384 if (getNumOperands() != results.size()) {
385 return emitOpError("has ") << getNumOperands() << " operands, but enclosing function (@"
386 << function.getName() << ") returns " << results.size();
387 }
388
389 for (unsigned i = 0, e = results.size(); i != e; ++i) {
390 if (!typesUnify(getOperand(i).getType(), results[i])) {
391 return emitError() << "type of return operand " << i << " (" << getOperand(i).getType()
392 << ") doesn't match function result type (" << results[i] << ")"
393 << " in function @" << function.getName();
394 }
395 }
396
397 return success();
398}
399
400//===----------------------------------------------------------------------===//
401// CallOp
402//===----------------------------------------------------------------------===//
403
404void CallOp::build(
405 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
406 ValueRange argOperands
407) {
408 odsState.addTypes(resultTypes);
409 odsState.addOperands(argOperands);
411 odsBuilder, odsState, static_cast<int32_t>(argOperands.size())
412 );
413 props.setCallee(callee);
414}
415
416void CallOp::build(
417 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
418 ArrayRef<ValueRange> mapOperands, DenseI32ArrayAttr numDimsPerMap, ValueRange argOperands
419) {
420 odsState.addTypes(resultTypes);
421 odsState.addOperands(argOperands);
423 odsBuilder, odsState, mapOperands, numDimsPerMap, argOperands.size()
424 );
425 props.setCallee(callee);
426}
427
428namespace {
429
430struct CallOpVerifier {
431 CallOpVerifier(CallOp *c, StringRef tgtName) : callOp(c), tgtKind(fnNameToKind(tgtName)) {}
432 virtual ~CallOpVerifier() = default;
433
434 LogicalResult verify() {
435 // Rather than immediately returning on failure, we check all verifier steps and aggregate to
436 // provide as many errors are possible in a single verifier run.
437 LogicalResult aggregateResult = success();
438 if (failed(verifyTargetAttributes())) {
439 aggregateResult = failure();
440 }
441 if (failed(verifyInputs())) {
442 aggregateResult = failure();
443 }
444 if (failed(verifyOutputs())) {
445 aggregateResult = failure();
446 }
447 if (failed(verifyAffineMapParams())) {
448 aggregateResult = failure();
449 }
450 return aggregateResult;
451 }
452
453protected:
454 CallOp *callOp;
455 FunctionKind tgtKind;
456
457 virtual LogicalResult verifyTargetAttributes() = 0;
458 virtual LogicalResult verifyInputs() = 0;
459 virtual LogicalResult verifyOutputs() = 0;
460 virtual LogicalResult verifyAffineMapParams() = 0;
461
463 LogicalResult verifyTargetAttributesMatch(FuncDefOp target) {
464 LogicalResult aggregateRes = success();
465 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
466 auto emitAttrErr = [&](StringLiteral attrName) {
467 aggregateRes = callOp->emitOpError()
468 << "target '@" << target.getName() << "' has '" << attrName
469 << "' attribute, which is not specified by the caller '@" << caller.getName()
470 << '\'';
471 };
472
473 if (target.hasAllowConstraintAttr() && !caller.hasAllowConstraintAttr()) {
474 emitAttrErr(AllowConstraintAttr::name);
475 }
476 if (target.hasAllowWitnessAttr() && !caller.hasAllowWitnessAttr()) {
477 emitAttrErr(AllowWitnessAttr::name);
478 }
479 }
480 return aggregateRes;
481 }
482
483 LogicalResult verifyNoAffineMapInstantiations() {
484 if (!isNullOrEmpty(callOp->getMapOpGroupSizesAttr())) {
485 // Tested in call_with_affinemap_fail.llzk
486 return callOp->emitOpError().append(
487 "can only have affine map instantiations when targeting a \"@", FUNC_NAME_COMPUTE,
488 "\" function"
489 );
490 }
491 // ASSERT: the check above is sufficient due to VerifySizesForMultiAffineOps trait.
492 assert(isNullOrEmpty(callOp->getNumDimsPerMapAttr()));
493 assert(callOp->getMapOperands().empty());
494 return success();
495 }
496};
497
498struct KnownTargetVerifier : public CallOpVerifier {
499 KnownTargetVerifier(CallOp *c, SymbolLookupResult<FuncDefOp> &&tgtRes)
500 : CallOpVerifier(c, tgtRes.get().getSymName()), tgt(*tgtRes), tgtType(tgt.getFunctionType()),
501 includeSymNames(tgtRes.getIncludeSymNames()) {}
502
503 LogicalResult verifyTargetAttributes() override {
504 return CallOpVerifier::verifyTargetAttributesMatch(tgt);
505 }
506
507 LogicalResult verifyInputs() override {
508 return verifyTypesMatch(callOp->getArgOperands().getTypes(), tgtType.getInputs(), "operand");
509 }
510
511 LogicalResult verifyOutputs() override {
512 return verifyTypesMatch(callOp->getResultTypes(), tgtType.getResults(), "result");
513 }
514
515 LogicalResult verifyAffineMapParams() override {
516 if ((FunctionKind::StructCompute == tgtKind || FunctionKind::StructProduct == tgtKind) &&
517 isInStruct(tgt.getOperation())) {
518 // Return type should be a single StructType. If that is not the case here, just bail without
519 // producing an error. The combination of this KnownTargetVerifier resolving the callee to a
520 // specific FuncDefOp and verifyFuncTypeCompute() ensuring all FUNC_NAME_COMPUTE FuncOps have
521 // a single StructType return value will produce a more relevant error message in that case.
522 if (StructType retTy = callOp->getSingleResultTypeOfWitnessGen()) {
523 if (ArrayAttr params = retTy.getParams()) {
524 // Collect the struct parameters that are defined via AffineMapAttr
525 SmallVector<AffineMapAttr> mapAttrs;
526 for (Attribute a : params) {
527 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(a)) {
528 mapAttrs.push_back(m);
529 }
530 }
532 callOp->getMapOperands(), callOp->getNumDimsPerMap(), mapAttrs, *callOp
533 );
534 }
535 }
536 return success();
537 } else {
538 // Global functions and constrain functions cannot have affine map instantiations.
539 return verifyNoAffineMapInstantiations();
540 }
541 }
542
543private:
544 template <typename T>
545 LogicalResult
546 verifyTypesMatch(ValueTypeRange<T> callOpTypes, ArrayRef<Type> tgtTypes, const char *aspect) {
547 if (tgtTypes.size() != callOpTypes.size()) {
548 return callOp->emitOpError()
549 .append("incorrect number of ", aspect, "s for callee, expected ", tgtTypes.size())
550 .attachNote(tgt.getLoc())
551 .append("callee defined here");
552 }
553 for (unsigned i = 0, e = tgtTypes.size(); i != e; ++i) {
554 if (!typesUnify(callOpTypes[i], tgtTypes[i], includeSymNames)) {
555 return callOp->emitOpError().append(
556 aspect, " type mismatch: expected type ", tgtTypes[i], ", but found ", callOpTypes[i],
557 " for ", aspect, " number ", i
558 );
559 }
560 }
561 return success();
562 }
563
564 FuncDefOp tgt;
565 FunctionType tgtType;
566 std::vector<llvm::StringRef> includeSymNames;
567};
568
571LogicalResult checkSelfTypeUnknownTarget(
572 StringAttr expectedParamName, Type actualType, CallOp *origin, const char *aspect
573) {
574 if (!llvm::isa<TypeVarType>(actualType) ||
575 llvm::cast<TypeVarType>(actualType).getRefName() != expectedParamName) {
576 // Tested in function_restrictions_fail.llzk:
577 // Non-tvar for constrain input via "call_target_constrain_without_self_non_struct"
578 // Non-tvar for compute output via "call_target_compute_wrong_type_ret"
579 // Wrong tvar for constrain input via "call_target_constrain_without_self_wrong_tvar_param"
580 // Wrong tvar for compute output via "call_target_compute_wrong_tvar_param_ret"
581 return origin->emitOpError().append(
582 "target \"@", origin->getCallee().getLeafReference().getValue(), "\" expected ", aspect,
583 " type '!", TypeVarType::name, "<@", expectedParamName.getValue(), ">' but found ",
584 actualType
585 );
586 }
587 return success();
588}
589
599struct UnknownTargetVerifier : public CallOpVerifier {
600 UnknownTargetVerifier(CallOp *c, SymbolRefAttr callee)
601 : CallOpVerifier(c, callee.getLeafReference().getValue()), calleeAttr(callee) {}
602
603 LogicalResult verifyTargetAttributes() override {
604 // Based on the precondition of this verifier, the target must be either a
605 // struct compute, constrain, or product function.
606 LogicalResult aggregateRes = success();
607 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
608 auto emitAttrErr = [&](StringLiteral attrName) {
609 aggregateRes = callOp->emitOpError()
610 << "target '" << calleeAttr << "' has '" << attrName
611 << "' attribute, which is not specified by the caller '@" << caller.getName()
612 << '\'';
613 };
614
615 switch (tgtKind) {
617 if (!caller.hasAllowConstraintAttr()) {
618 emitAttrErr(AllowConstraintAttr::name);
619 }
620 break;
622 if (!caller.hasAllowWitnessAttr()) {
623 emitAttrErr(AllowWitnessAttr::name);
624 }
625 break;
627 if (!caller.hasAllowWitnessAttr()) {
628 emitAttrErr(AllowWitnessAttr::name);
629 }
630 if (!caller.hasAllowConstraintAttr()) {
631 emitAttrErr(AllowConstraintAttr::name);
632 }
633 break;
634 default:
635 break;
636 }
637 }
638 return aggregateRes;
639 }
640
641 LogicalResult verifyInputs() override {
642 if (FunctionKind::StructCompute == tgtKind || FunctionKind::StructProduct == tgtKind) {
643 // Without known target, no additional checks can be done.
644 } else if (FunctionKind::StructConstrain == tgtKind) {
645 // Without known target, this can only check that the first input is VarType using the same
646 // struct parameter as the base of the callee (later replaced with the target struct's type).
647 Operation::operand_type_range inputTypes = callOp->getArgOperands().getTypes();
648 if (inputTypes.size() < 1) {
649 // Tested in function_restrictions_fail.llzk
650 return callOp->emitOpError()
651 << "target \"@" << FUNC_NAME_CONSTRAIN << "\" must have at least one input type";
652 }
653 return checkSelfTypeUnknownTarget(
654 calleeAttr.getRootReference(), inputTypes.front(), callOp, "first input"
655 );
656 }
657 return success();
658 }
659
660 LogicalResult verifyOutputs() override {
661 if (FunctionKind::StructCompute == tgtKind || FunctionKind::StructProduct == tgtKind) {
662 // Without known target, this can only check that the function returns VarType using the same
663 // struct parameter as the base of the callee (later replaced with the target struct's type).
664 Operation::result_type_range resTypes = callOp->getResultTypes();
665 if (resTypes.size() != 1) {
666 // Tested in function_restrictions_fail.llzk
667 return callOp->emitOpError().append(
668 "target \"@", FUNC_NAME_COMPUTE, "\" must have exactly one return type"
669 );
670 }
671 return checkSelfTypeUnknownTarget(
672 calleeAttr.getRootReference(), resTypes.front(), callOp, "return"
673 );
674 } else if (FunctionKind::StructConstrain == tgtKind) {
675 // Without known target, this can only check that the function has no return
676 if (callOp->getNumResults() != 0) {
677 // Tested in function_restrictions_fail.llzk
678 return callOp->emitOpError()
679 << "target \"@" << FUNC_NAME_CONSTRAIN << "\" must have no return type";
680 }
681 }
682 return success();
683 }
684
685 LogicalResult verifyAffineMapParams() override {
686 if (FunctionKind::StructCompute == tgtKind || FunctionKind::StructProduct == tgtKind) {
687 // Without known target, no additional checks can be done.
688 } else if (FunctionKind::StructConstrain == tgtKind) {
689 // Without known target, this can only check that there are no affine map instantiations.
690 return verifyNoAffineMapInstantiations();
691 }
692 return success();
693 }
694
695private:
696 SymbolRefAttr calleeAttr;
697};
698
699} // namespace
700
701LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &tables) {
702 // First, verify symbol resolution in all input and output types.
703 if (failed(verifyTypeResolution(tables, *this, getCalleeType()))) {
704 return failure(); // verifyTypeResolution() already emits a sufficient error message
705 }
706
707 // Check that the callee attribute was specified.
708 SymbolRefAttr calleeAttr = getCalleeAttr();
709 if (!calleeAttr) {
710 return emitOpError("requires a 'callee' symbol reference attribute");
711 }
712
713 // If the callee references a parameter of the struct where this call appears, perform the subset
714 // of checks that can be done even though the target is unknown.
715 if (calleeAttr.getNestedReferences().size() == 1) {
716 FailureOr<StructDefOp> parent = getParentOfType<StructDefOp>(*this);
717 if (succeeded(parent) && parent->hasParamNamed(calleeAttr.getRootReference())) {
718 return UnknownTargetVerifier(this, calleeAttr).verify();
719 }
720 }
721
722 // Otherwise, callee must be specified via full path from the root module. Perform the full set of
723 // checks against the known target function.
724 auto tgtOpt = lookupTopLevelSymbol<FuncDefOp>(tables, calleeAttr, *this);
725 if (failed(tgtOpt)) {
726 return this->emitError() << "expected '" << FuncDefOp::getOperationName() << "' named \""
727 << calleeAttr << '"';
728 }
729 return KnownTargetVerifier(this, std::move(*tgtOpt)).verify();
730}
731
732FunctionType CallOp::getCalleeType() {
733 return FunctionType::get(getContext(), getArgOperands().getTypes(), getResultTypes());
734}
735
736namespace {
737
738bool calleeIsStructFunctionImpl(
739 const char *funcName, SymbolRefAttr callee, llvm::function_ref<StructType()> getType
740) {
741 if (callee.getLeafReference() == funcName) {
742 if (StructType t = getType()) {
743 // If the name ref within the StructType matches the `callee` prefix (i.e., sans the function
744 // name itself), then the `callee` target must be within a StructDefOp because validation
745 // checks elsewhere ensure that every StructType references a StructDefOp (i.e., the `callee`
746 // function is not simply a global function nested within a ModuleOp)
747 return t.getNameRef() == getPrefixAsSymbolRefAttr(callee);
748 }
749 }
750 return false;
751}
752
753} // namespace
754
756 return calleeIsStructFunctionImpl(FUNC_NAME_COMPUTE, getCallee(), [this]() {
757 return this->getSingleResultTypeOfCompute();
758 });
759}
760
762 return calleeIsStructFunctionImpl(FUNC_NAME_CONSTRAIN, getCallee(), [this]() {
763 return getAtIndex<StructType>(this->getArgOperands().getTypes(), 0);
764 });
765}
766
768 assert(calleeIsStructCompute());
769 return getResults().front();
770}
771
773 assert(calleeIsStructConstrain());
774 return getArgOperands().front();
775}
776
777FailureOr<SymbolLookupResult<FuncDefOp>> CallOp::getCalleeTarget(SymbolTableCollection &tables) {
778 Operation *thisOp = this->getOperation();
779 auto root = getRootModule(thisOp);
780 assert(succeeded(root));
781 return llzk::lookupSymbolIn<FuncDefOp>(tables, getCallee(), root->getOperation(), thisOp);
782}
783
785 assert(calleeIsCompute() && "violated implementation pre-condition");
786 return getIfSingleton<StructType>(getResultTypes());
787}
788
790 assert(calleeContainsWitnessGen() && "violated implementation pre-condition");
791 return getIfSingleton<StructType>(getResultTypes());
792}
793
795CallInterfaceCallable CallOp::getCallableForCallee() { return getCalleeAttr(); }
796
798void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
799 setCalleeAttr(llvm::cast<SymbolRefAttr>(callee));
800}
801
802SmallVector<ValueRange> CallOp::toVectorOfValueRange(OperandRangeRange input) {
803 llvm::SmallVector<ValueRange, 4> output;
804 output.reserve(input.size());
805 for (OperandRange r : input) {
806 output.push_back(r);
807 }
808 return output;
809}
810
811} // namespace llzk::function
This file defines methods for symbol lookup across LLZK operations and included files.
bool calleeContainsWitnessGen()
Return true iff the callee function can contain witness generation code (this does not check if the c...
Definition Ops.h.inc:335
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:761
::mlir::CallInterfaceCallable getCallableForCallee()
Return the callee of this operation.
Definition Ops.cpp:795
::mlir::FunctionType getCalleeType()
Definition Ops.cpp:732
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the callee is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:784
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.h.inc:267
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:701
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:755
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:467
bool calleeIsCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE (this does not check if the callee func...
Definition Ops.h.inc:329
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::SymbolRefAttr callee, ::mlir::ValueRange argOperands={})
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
Definition Ops.cpp.inc:472
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:241
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:772
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:245
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
Definition Ops.cpp:802
::llzk::component::StructType getSingleResultTypeOfWitnessGen()
Assuming the callee contains witness generation code, return the single StructType result.
Definition Ops.cpp:789
FoldAdaptor::Properties Properties
Definition Ops.h.inc:192
void setCalleeAttr(::mlir::SymbolRefAttr attr)
Definition Ops.h.inc:282
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:767
void setCalleeFromCallable(::mlir::CallInterfaceCallable callee)
Set the callee for this operation.
Definition Ops.cpp:798
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
Definition Ops.cpp:777
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
Definition Ops.cpp:208
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:319
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:346
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:952
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={}, ::llvm::ArrayRef<::mlir::DictionaryAttr > argAttrs={})
::mlir::StringAttr getFunctionTypeAttrName()
Definition Ops.h.inc:559
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:365
void print(::mlir::OpAsmPrinter &p)
Definition Ops.cpp:119
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:770
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
Definition Ops.h.inc:729
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the name is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:370
::mlir::StringAttr getResAttrsAttrName()
Definition Ops.h.inc:567
::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result)
Definition Ops.cpp:108
void cloneInto(FuncDefOp dest, ::mlir::IRMapping &mapper)
Clone the internal blocks and attributes from this function into dest.
Definition Ops.cpp:128
bool nameIsProduct()
Return true iff the function name is FUNC_NAME_PRODUCT (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:778
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
Definition Ops.h.inc:774
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
Definition Ops.h.inc:784
bool isInStruct()
Return true iff the function is within a StructDefOp.
Definition Ops.h.inc:781
::llvm::ArrayRef<::mlir::Type > getResultTypes()
Returns the result types of this function.
Definition Ops.h.inc:752
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:583
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
Definition Ops.cpp:200
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
::mlir::SymbolRefAttr getFullyQualifiedName(bool requireParent=true)
Return the full name for this function from the root module, including all surrounding symbol table n...
Definition Ops.cpp:336
bool hasArgPublicAttr(unsigned index)
Return true iff the argument at the given index has pub attribute.
Definition Ops.cpp:216
::llvm::LogicalResult verify()
Definition Ops.cpp:226
::mlir::Region & getBody()
Definition Ops.h.inc:607
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
Definition Ops.h.inc:721
::mlir::StringAttr getArgAttrsAttrName()
Definition Ops.h.inc:551
::llvm::LogicalResult verify()
Definition Ops.cpp:379
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:885
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:27
OpClass::Properties & buildInstantiationAttrs(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
OpClass::Properties & buildInstantiationAttrsEmpty(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
bool isInStruct(Operation *op)
Definition Ops.cpp:41
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
Definition Ops.cpp:108
FunctionKind fnNameToKind(mlir::StringRef name)
Given a function name, return the corresponding FunctionKind.
Definition Ops.cpp:41
FunctionKind
Kinds of functions in LLZK.
Definition Ops.h:32
@ StructConstrain
Function within a struct named FUNC_NAME_CONSTRAIN.
Definition Ops.h:36
@ StructProduct
Function within a struct named FUNC_NAME_PRODUCT.
Definition Ops.h:38
@ StructCompute
Function within a struct named FUNC_NAME_COMPUTE.
Definition Ops.h:34
@ Free
Function that is not within a struct.
Definition Ops.h:40
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:27
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:28
TypeClass getIfSingleton(mlir::TypeRange types)
Definition TypeHelper.h:255
FailureOr< ModuleOp > getRootModule(Operation *from)
bool isNullOrEmpty(mlir::ArrayAttr a)
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:29
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
Definition TypeHelper.h:259
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:45
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
OpClass delegate_to_build(mlir::Location location, Args &&...args)
bool hasAffineMapAttr(Type type)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
Definition TypeHelper.h:107
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)