LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Func and call op implementations --------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8// Adapted from the LLVM Project's lib/Dialect/Func/IR/FuncOps.cpp
9// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
10// See https://llvm.org/LICENSE.txt for license information.
11// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
12//
13//===----------------------------------------------------------------------===//
14
23
24#include <mlir/IR/IRMapping.h>
25#include <mlir/IR/OpImplementation.h>
26#include <mlir/Interfaces/FunctionImplementation.h>
27
28#include <llvm/ADT/MapVector.h>
29
30// TableGen'd implementation files
31#define GET_OP_CLASSES
33
34using namespace mlir;
35using namespace llzk::component;
36using namespace llzk::polymorphic;
37
38namespace llzk::function {
39
40namespace {
// Verify that every type in `funcType` (the inputs, then the results, in that order) resolves,
// passing `origin` along for diagnostics.
// NOTE(review): the line opening the delegating call (original line 44) is missing from this
// listing; from the visible arguments it appears to forward both type lists to an
// llzk::verifyTypeResolution overload taking ArrayRef<ArrayRef<Type>> — confirm.
42inline LogicalResult
43verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, FunctionType funcType) {
45 tables, origin, ArrayRef<ArrayRef<Type>> {funcType.getInputs(), funcType.getResults()}
46 );
47}
48} // namespace
49
50//===----------------------------------------------------------------------===//
51// FuncDefOp
52//===----------------------------------------------------------------------===//
53
55 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs
56) {
// FuncDefOp::create overload (its opening signature line, original line 54, is missing from this
// listing): constructs the op from name, type and attribute list via delegate_to_build.
57 return delegate_to_build<FuncDefOp>(location, name, type, attrs);
58}
59
61 Location location, StringRef name, FunctionType type, Operation::dialect_attr_range attrs
62) {
// Overload taking a dialect attribute range: copies the range into a local SmallVector and
// forwards it to the ArrayRef<NamedAttribute> overload.
63 SmallVector<NamedAttribute, 8> attrRef(attrs);
64 return create(location, name, type, llvm::ArrayRef(attrRef));
65}
66
68 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs,
69 ArrayRef<DictionaryAttr> argAttrs
70) {
// Overload that additionally installs per-argument attribute dictionaries on the newly created
// function via setAllArgAttrs().
71 FuncDefOp func = create(location, name, type, attrs);
72 func.setAllArgAttrs(argAttrs);
73 return func;
74}
75
77 OpBuilder &builder, OperationState &state, StringRef name, FunctionType type,
78 ArrayRef<NamedAttribute> attrs, ArrayRef<DictionaryAttr> argAttrs
79) {
// FuncDefOp::build (its opening signature line, original line 76, is missing from this listing):
// records the symbol name and function type as attributes, appends `attrs`, and adds one region
// for the body.
80 state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name));
81 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
82 state.attributes.append(attrs.begin(), attrs.end());
83 state.addRegion();
84
// Argument attributes are optional; when provided there must be exactly one dictionary per input.
85 if (argAttrs.empty()) {
86 return;
87 }
88 assert(type.getNumInputs() == argAttrs.size());
// Only argument attributes are attached here; result attributes are passed as std::nullopt.
89 function_interface_impl::addArgAndResultAttrs(
90 builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name),
91 getResAttrsAttrName(state.name)
92 );
93}
94
95ParseResult FuncDefOp::parse(OpAsmParser &parser, OperationState &result) {
96 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
97 function_interface_impl::VariadicFlag,
98 std::string &) { return builder.getFunctionType(argTypes, results); };
99
100 return function_interface_impl::parseFunctionOp(
101 parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType,
102 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)
103 );
104}
105
/// Print this op with MLIR's shared function-op syntax (non-variadic), using this op's
/// function-type and argument-attribute attribute names.
106void FuncDefOp::print(OpAsmPrinter &p) {
107 function_interface_impl::printFunctionOp(
108 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(),
// NOTE(review): the final argument (original line 109) is missing from this listing; by symmetry
// with parse() it is presumably getResAttrsAttrName() — confirm.
110 );
111}
112
115void FuncDefOp::cloneInto(FuncDefOp dest, IRMapping &mapper) {
116 // Add the attributes of this function to dest.
117 llvm::MapVector<StringAttr, Attribute> newAttrMap;
118 for (const auto &attr : dest->getAttrs()) {
119 newAttrMap.insert({attr.getName(), attr.getValue()});
120 }
121 for (const auto &attr : (*this)->getAttrs()) {
122 newAttrMap.insert({attr.getName(), attr.getValue()});
123 }
124
125 auto newAttrs =
126 llvm::to_vector(llvm::map_range(newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
127 return NamedAttribute(attrPair.first, attrPair.second);
128 }));
129 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
130
131 // Clone the body.
132 getBody().cloneInto(&dest.getBody(), mapper);
133}
134
140FuncDefOp FuncDefOp::clone(IRMapping &mapper) {
141 // Create the new function.
142 FuncDefOp newFunc = llvm::cast<FuncDefOp>(getOperation()->cloneWithoutRegions());
143
144 // If the function has a body, then the user might be deleting arguments to
145 // the function by specifying them in the mapper. If so, we don't add the
146 // argument to the input type vector.
147 if (!isExternal()) {
148 FunctionType oldType = getFunctionType();
149
150 unsigned oldNumArgs = oldType.getNumInputs();
151 SmallVector<Type, 4> newInputs;
152 newInputs.reserve(oldNumArgs);
153 for (unsigned i = 0; i != oldNumArgs; ++i) {
154 if (!mapper.contains(getArgument(i))) {
155 newInputs.push_back(oldType.getInput(i));
156 }
157 }
158
161 if (newInputs.size() != oldNumArgs) {
162 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, oldType.getResults()));
163
164 if (ArrayAttr argAttrs = getAllArgAttrs()) {
165 SmallVector<Attribute> newArgAttrs;
166 newArgAttrs.reserve(newInputs.size());
167 for (unsigned i = 0; i != oldNumArgs; ++i) {
168 if (!mapper.contains(getArgument(i))) {
169 newArgAttrs.push_back(argAttrs[i]);
170 }
171 }
172 newFunc.setAllArgAttrs(newArgAttrs);
173 }
174 }
175 }
176
178 cloneInto(newFunc, mapper);
179 return newFunc;
180}
181
// FuncDefOp::clone() without a mapper (signature line, original line 182, missing from this
// listing): clones with a fresh, empty IRMapping, so no arguments are dropped from the signature.
183 IRMapping mapper;
184 return clone(mapper);
185}
186
// setAllowConstraintAttr (signature line, original line 187, missing from this listing): sets the
// unit attribute named AllowConstraintAttr::name when `newValue` is true, removes it otherwise.
188 if (newValue) {
189 getOperation()->setAttr(AllowConstraintAttr::name, UnitAttr::get(getContext()));
190 } else {
191 getOperation()->removeAttr(AllowConstraintAttr::name);
192 }
193}
194
// setAllowWitnessAttr (signature line, original line 195, missing from this listing): sets the
// unit attribute named AllowWitnessAttr::name when `newValue` is true, removes it otherwise.
196 if (newValue) {
197 getOperation()->setAttr(AllowWitnessAttr::name, UnitAttr::get(getContext()));
198 } else {
199 getOperation()->removeAttr(AllowWitnessAttr::name);
200 }
201}
202
203bool FuncDefOp::hasArgPublicAttr(unsigned index) {
204 if (index < this->getNumArguments()) {
205 DictionaryAttr res = function_interface_impl::getArgAttrDict(*this, index);
206 return res ? res.contains(PublicAttr::name) : false;
207 } else {
208 // TODO: print error? requested attribute for non-existant argument index
209 return false;
210 }
211}
212
213LogicalResult FuncDefOp::verify() {
214 OwningEmitErrorFn emitErrorFunc = getEmitOpErrFn(this);
215 // Ensure that only valid LLZK types are used for arguments and return. Additionally, the struct
216 // functions may not use AffineMapAttrs in their parameter types. If such a scenario seems to make
217 // sense when generating LLZK IR, it's likely better to introduce a struct parameter to use
218 // instead and instantiate the struct with that AffineMapAttr.
219 FunctionType type = getFunctionType();
220 for (Type t : type.getInputs()) {
221 if (llzk::checkValidType(emitErrorFunc, t).failed()) {
222 return failure();
223 }
224 if (isInStruct() && hasAffineMapAttr(t)) {
225 return emitErrorFunc().append(
226 "\"@", getName(), "\" parameters cannot contain affine map attributes but found ", t
227 );
228 }
229 }
230 for (Type t : type.getResults()) {
231 if (llzk::checkValidType(emitErrorFunc, t).failed()) {
232 return failure();
233 }
234 }
235 // Ensure that the function does not contain nested modules.
236 // Functions also cannot contain nested structs, but this check is handled
237 // via struct.def's requirement of having module as a parent.
238 WalkResult res = this->walk<WalkOrder::PreOrder>([this](ModuleOp nestedMod) {
239 getEmitOpErrFn(nestedMod)().append(
240 "cannot be nested within '", getOperation()->getName(), "' operations"
241 );
242 return WalkResult::interrupt();
243 });
244 if (res.wasInterrupted()) {
245 return failure();
246 }
247
248 return success();
249}
250
251namespace {
252
253LogicalResult
254verifyFuncTypeCompute(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
255 FunctionType funcType = origin.getFunctionType();
256 llvm::ArrayRef<Type> resTypes = funcType.getResults();
257 // Must return type of parent struct
258 if (resTypes.size() != 1) {
259 return origin.emitOpError().append(
260 "\"@", FUNC_NAME_COMPUTE, "\" must have exactly one return type"
261 );
262 }
263 if (failed(checkSelfType(tables, parent, resTypes.front(), origin, "return"))) {
264 return failure();
265 }
266
267 // After the more specific checks (to ensure more specific error messages would be produced if
268 // necessary), do the general check that all symbol references in the types are valid. The return
269 // types were already checked so just check the input types.
270 return llzk::verifyTypeResolution(tables, origin, funcType.getInputs());
271}
272
273LogicalResult
274verifyFuncTypeProduct(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
275 // The signature for @product is the same as the signature for @compute
276 return verifyFuncTypeCompute(origin, tables, parent);
277}
278
279LogicalResult
280verifyFuncTypeConstrain(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
281 FunctionType funcType = origin.getFunctionType();
282 // Must return '()' type, i.e., have no return types
283 if (funcType.getResults().size() != 0) {
284 return origin.emitOpError() << "\"@" << FUNC_NAME_CONSTRAIN << "\" must have no return type";
285 }
286
287 // Type of the first parameter must match the parent StructDefOp of the current operation.
288 llvm::ArrayRef<Type> inputTypes = funcType.getInputs();
289 if (inputTypes.size() < 1) {
290 return origin.emitOpError() << "\"@" << FUNC_NAME_CONSTRAIN
291 << "\" must have at least one input type";
292 }
293 if (failed(checkSelfType(tables, parent, inputTypes.front(), origin, "first input"))) {
294 return failure();
295 }
296
297 // After the more specific checks (to ensure more specific error messages would be produced if
298 // necessary), do the general check that all symbol references in the types are valid. There are
299 // no return types, just check the remaining input types (the first was already checked via
300 // the checkSelfType() call above).
301 return llzk::verifyTypeResolution(tables, origin, inputTypes.drop_front());
302}
303
304} // namespace
305
306LogicalResult FuncDefOp::verifySymbolUses(SymbolTableCollection &tables) {
307 // Additional checks for the compute/constrain/product functions within a struct
308 FailureOr<StructDefOp> parentStructOpt = getParentOfType<StructDefOp>(*this);
309 if (succeeded(parentStructOpt)) {
310 // Verify return type restrictions for functions within a StructDefOp
311 if (nameIsCompute()) {
312 return verifyFuncTypeCompute(*this, tables, parentStructOpt.value());
313 } else if (nameIsConstrain()) {
314 return verifyFuncTypeConstrain(*this, tables, parentStructOpt.value());
315 } else if (nameIsProduct()) {
316 return verifyFuncTypeProduct(*this, tables, parentStructOpt.value());
317 }
318 }
319 // In the general case, verify symbol resolution in all input and output types.
320 return verifyTypeResolution(tables, *this, getFunctionType());
321}
322
323SymbolRefAttr FuncDefOp::getFullyQualifiedName(bool requireParent) {
324 // If the parent is not present and not required, just return the symbol name
325 if (!requireParent && getOperation()->getParentOp() == nullptr) {
326 return SymbolRefAttr::get(getOperation());
327 }
328 auto res = getPathFromRoot(*this);
329 assert(succeeded(res));
330 return res.value();
331}
332
// FuncDefOp::getSelfValueFromCompute(). The signature line (original line 333) is missing from
// this listing. Returns the first operand of the ReturnOp that terminates the body's last block.
334 assert(nameIsCompute()); // skip inStruct check to allow dangling functions
335 // Get the single block of the function body
336 Region &body = getBody();
337 assert(!body.empty() && "compute() function body is empty");
// NOTE(review): back() takes the last block; the comment above assumes it is the only one.
338 Block &block = body.back();
339
340 // The terminator should be the return op
341 Operation *terminator = block.getTerminator();
342 assert(terminator && "compute() function has no terminator");
343 auto retOp = llvm::dyn_cast<ReturnOp>(terminator);
344 if (!retOp) {
345 llvm::errs() << "Expected '" << ReturnOp::getOperationName() << "' but found '"
346 << terminator->getName() << "'\n";
// NOTE(review): in builds without assertions llvm_unreachable is typically only an optimizer
// hint, so reaching this branch there is undefined behavior rather than a guaranteed abort.
347 llvm_unreachable("compute() function must end with ReturnOp");
348 }
// NOTE(review): front() assumes the return has at least one operand. The exactly-one-result rule
// (verifyFuncTypeCompute) is only enforced for functions inside a struct. This accessor
// deliberately skips that check, so a dangling compute function returning nothing would call
// front() on an empty range.
349 return retOp.getOperands().front();
350}
351
// FuncDefOp::getSelfValueFromConstrain(). The signature line (original line 352) is missing from
// this listing. Returns the function's first argument.
353 assert(nameIsConstrain()); // skip inStruct check to allow dangling functions
354 return getArguments().front();
355}
356
// FuncDefOp::getSingleResultTypeOfCompute(). Both the signature line (original line 357) and the
// return statement (original line 359) are missing from this listing; only the precondition
// assertion is visible.
358 assert(isStructCompute() && "violated implementation pre-condition");
360}
361
362//===----------------------------------------------------------------------===//
363// ReturnOp
364//===----------------------------------------------------------------------===//
365
366LogicalResult ReturnOp::verify() {
367 auto function = getParentOp<FuncDefOp>(); // parent is FuncDefOp per ODS
368
369 // The operand number and types must match the function signature.
370 const auto results = function.getFunctionType().getResults();
371 if (getNumOperands() != results.size()) {
372 return emitOpError("has ") << getNumOperands() << " operands, but enclosing function (@"
373 << function.getName() << ") returns " << results.size();
374 }
375
376 for (unsigned i = 0, e = results.size(); i != e; ++i) {
377 if (!typesUnify(getOperand(i).getType(), results[i])) {
378 return emitError() << "type of return operand " << i << " (" << getOperand(i).getType()
379 << ") doesn't match function result type (" << results[i] << ")"
380 << " in function @" << function.getName();
381 }
382 }
383
384 return success();
385}
386
387//===----------------------------------------------------------------------===//
388// CallOp
389//===----------------------------------------------------------------------===//
390
/// Build a CallOp without affine map instantiations. It adds the result types and argument
/// operands, then sets the callee on the properties object.
391void CallOp::build(
392 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
393 ValueRange argOperands
394) {
395 odsState.addTypes(resultTypes);
396 odsState.addOperands(argOperands);
// NOTE(review): the line opening this call (original line 397) is missing from this listing;
// presumably it obtains the `props` Properties reference from a helper that also records the
// operand segment sizes from the argument count — confirm.
398 odsBuilder, odsState, static_cast<int32_t>(argOperands.size())
399 );
400 props.setCallee(callee);
401}
402
/// Build a CallOp with affine map instantiations. `mapOperands` and `numDimsPerMap` describe the
/// instantiations; the argument operands and callee are set as in the overload above.
403void CallOp::build(
404 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
405 ArrayRef<ValueRange> mapOperands, DenseI32ArrayAttr numDimsPerMap, ValueRange argOperands
406) {
407 odsState.addTypes(resultTypes);
408 odsState.addOperands(argOperands);
// NOTE(review): the line opening this call (original line 409) is missing from this listing;
// presumably a helper that adds the map operands and returns the `props` reference — confirm.
410 odsBuilder, odsState, mapOperands, numDimsPerMap, argOperands.size()
411 );
412 props.setCallee(callee);
413}
414
415namespace {
416enum class CalleeKind : std::uint8_t { Compute, Constrain, Product, Other };
417
418CalleeKind calleeNameToKind(StringRef tgtName) {
419 if (FUNC_NAME_COMPUTE == tgtName) {
420 return CalleeKind::Compute;
421 } else if (FUNC_NAME_CONSTRAIN == tgtName) {
422 return CalleeKind::Constrain;
423 } else if (FUNC_NAME_PRODUCT == tgtName) {
424 return CalleeKind::Product;
425 } else {
426 return CalleeKind::Other;
427 }
428}
429
430struct CallOpVerifier {
431 CallOpVerifier(CallOp *c, StringRef tgtName) : callOp(c), tgtKind(calleeNameToKind(tgtName)) {}
432 virtual ~CallOpVerifier() = default;
433
434 LogicalResult verify() {
435 // Rather than immediately returning on failure, we check all verifier steps and aggregate to
436 // provide as many errors are possible in a single verifier run.
437 LogicalResult aggregateResult = success();
438 if (failed(verifyTargetAttributes())) {
439 aggregateResult = failure();
440 }
441 if (failed(verifyInputs())) {
442 aggregateResult = failure();
443 }
444 if (failed(verifyOutputs())) {
445 aggregateResult = failure();
446 }
447 if (failed(verifyAffineMapParams())) {
448 aggregateResult = failure();
449 }
450 return aggregateResult;
451 }
452
453protected:
454 CallOp *callOp;
455 CalleeKind tgtKind;
456
457 virtual LogicalResult verifyTargetAttributes() = 0;
458 virtual LogicalResult verifyInputs() = 0;
459 virtual LogicalResult verifyOutputs() = 0;
460 virtual LogicalResult verifyAffineMapParams() = 0;
461
463 LogicalResult verifyTargetAttributesMatch(FuncDefOp target) {
464 LogicalResult aggregateRes = success();
465 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
466 auto emitAttrErr = [&](StringLiteral attrName) {
467 aggregateRes = callOp->emitOpError()
468 << "target '@" << target.getName() << "' has '" << attrName
469 << "' attribute, which is not specified by the caller '@" << caller.getName()
470 << '\'';
471 };
472
473 if (target.hasAllowConstraintAttr() && !caller.hasAllowConstraintAttr()) {
474 emitAttrErr(AllowConstraintAttr::name);
475 }
476 if (target.hasAllowWitnessAttr() && !caller.hasAllowWitnessAttr()) {
477 emitAttrErr(AllowWitnessAttr::name);
478 }
479 }
480 return aggregateRes;
481 }
482
483 LogicalResult verifyNoAffineMapInstantiations() {
484 if (!isNullOrEmpty(callOp->getMapOpGroupSizesAttr())) {
485 // Tested in call_with_affinemap_fail.llzk
486 return callOp->emitOpError().append(
487 "can only have affine map instantiations when targeting a \"@", FUNC_NAME_COMPUTE,
488 "\" function"
489 );
490 }
491 // ASSERT: the check above is sufficient due to VerifySizesForMultiAffineOps trait.
492 assert(isNullOrEmpty(callOp->getNumDimsPerMapAttr()));
493 assert(callOp->getMapOperands().empty());
494 return success();
495 }
496};
497
/// CallOp verifier used when the callee resolves to a concrete FuncDefOp. It compares the call's
/// operand/result types with the target's FunctionType, checks attribute compatibility with the
/// caller, and validates affine map instantiations.
498struct KnownTargetVerifier : public CallOpVerifier {
499 KnownTargetVerifier(CallOp *c, SymbolLookupResult<FuncDefOp> &&tgtRes)
500 : CallOpVerifier(c, tgtRes.get().getSymName()), tgt(*tgtRes), tgtType(tgt.getFunctionType()),
501 includeSymNames(tgtRes.getIncludeSymNames()) {}
502
503 LogicalResult verifyTargetAttributes() override {
504 return CallOpVerifier::verifyTargetAttributesMatch(tgt);
505 }
506
507 LogicalResult verifyInputs() override {
508 return verifyTypesMatch(callOp->getArgOperands().getTypes(), tgtType.getInputs(), "operand");
509 }
510
511 LogicalResult verifyOutputs() override {
512 return verifyTypesMatch(callOp->getResultTypes(), tgtType.getResults(), "result");
513 }
514
// Compute/product targets inside a struct may instantiate affine maps used as struct parameters
// of the returned StructType; all other targets must have none.
// NOTE(review): for a FUNC_NAME_PRODUCT target this calls callOp->getSingleResultTypeOfCompute(),
// whose definition asserts calleeIsCompute(). Unless calleeIsCompute() also accepts product
// callees, this trips that assertion in debug builds — confirm.
515 LogicalResult verifyAffineMapParams() override {
516 if ((CalleeKind::Compute == tgtKind || CalleeKind::Product == tgtKind) &&
517 isInStruct(tgt.getOperation())) {
518 // Return type should be a single StructType. If that is not the case here, just bail without
519 // producing an error. The combination of this KnownTargetVerifier resolving the callee to a
520 // specific FuncDefOp and verifyFuncTypeCompute() ensuring all FUNC_NAME_COMPUTE FuncOps have
521 // a single StructType return value will produce a more relevant error message in that case.
522 if (StructType retTy = callOp->getSingleResultTypeOfCompute()) {
523 if (ArrayAttr params = retTy.getParams()) {
524 // Collect the struct parameters that are defined via AffineMapAttr
525 SmallVector<AffineMapAttr> mapAttrs;
526 for (Attribute a : params) {
527 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(a)) {
528 mapAttrs.push_back(m);
529 }
530 }
// NOTE(review): the line opening this call (original line 531) is missing from this listing.
// Presumably it returns the result of a helper that checks the call's map operands and per-map
// dimension counts against `mapAttrs` — confirm.
532 callOp->getMapOperands(), callOp->getNumDimsPerMap(), mapAttrs, *callOp
533 );
534 }
535 }
536 return success();
537 } else {
538 // Global functions and constrain functions cannot have affine map instantiations.
539 return verifyNoAffineMapInstantiations();
540 }
541 }
542
543private:
// Compare the call's types with the target's. First compare the counts; on a mismatch, attach a
// note at the callee's location. Then compare element-wise with typesUnify() using
// includeSymNames, and report the first mismatching position.
544 template <typename T>
545 LogicalResult
546 verifyTypesMatch(ValueTypeRange<T> callOpTypes, ArrayRef<Type> tgtTypes, const char *aspect) {
547 if (tgtTypes.size() != callOpTypes.size()) {
548 return callOp->emitOpError()
549 .append("incorrect number of ", aspect, "s for callee, expected ", tgtTypes.size())
550 .attachNote(tgt.getLoc())
551 .append("callee defined here");
552 }
553 for (unsigned i = 0, e = tgtTypes.size(); i != e; ++i) {
554 if (!typesUnify(callOpTypes[i], tgtTypes[i], includeSymNames)) {
555 return callOp->emitOpError().append(
556 aspect, " type mismatch: expected type ", tgtTypes[i], ", but found ", callOpTypes[i],
557 " for ", aspect, " number ", i
558 );
559 }
560 }
561 return success();
562 }
563
// The resolved target function and its signature.
564 FuncDefOp tgt;
565 FunctionType tgtType;
// Taken from the lookup result (getIncludeSymNames()) and passed to typesUnify().
566 std::vector<llvm::StringRef> includeSymNames;
567};
568
571LogicalResult checkSelfTypeUnknownTarget(
572 StringAttr expectedParamName, Type actualType, CallOp *origin, const char *aspect
573) {
574 if (!llvm::isa<TypeVarType>(actualType) ||
575 llvm::cast<TypeVarType>(actualType).getRefName() != expectedParamName) {
576 // Tested in function_restrictions_fail.llzk:
577 // Non-tvar for constrain input via "call_target_constrain_without_self_non_struct"
578 // Non-tvar for compute output via "call_target_compute_wrong_type_ret"
579 // Wrong tvar for constrain input via "call_target_constrain_without_self_wrong_tvar_param"
580 // Wrong tvar for compute output via "call_target_compute_wrong_tvar_param_ret"
581 return origin->emitOpError().append(
582 "target \"@", origin->getCallee().getLeafReference().getValue(), "\" expected ", aspect,
583 " type '!", TypeVarType::name, "<@", expectedParamName.getValue(), ">' but found ",
584 actualType
585 );
586 }
587 return success();
588}
589
599struct UnknownTargetVerifier : public CallOpVerifier {
600 UnknownTargetVerifier(CallOp *c, SymbolRefAttr callee)
601 : CallOpVerifier(c, callee.getLeafReference().getValue()), calleeAttr(callee) {}
602
603 LogicalResult verifyTargetAttributes() override {
604 // Based on the precondition of this verifier, the target must be either a
605 // struct compute, constrain, or product function.
606 LogicalResult aggregateRes = success();
607 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
608 auto emitAttrErr = [&](StringLiteral attrName) {
609 aggregateRes = callOp->emitOpError()
610 << "target '" << calleeAttr << "' has '" << attrName
611 << "' attribute, which is not specified by the caller '@" << caller.getName()
612 << '\'';
613 };
614
615 switch (tgtKind) {
616 case CalleeKind::Constrain:
617 if (!caller.hasAllowConstraintAttr()) {
618 emitAttrErr(AllowConstraintAttr::name);
619 }
620 break;
621 case CalleeKind::Compute:
622 if (!caller.hasAllowWitnessAttr()) {
623 emitAttrErr(AllowWitnessAttr::name);
624 }
625 break;
626 case CalleeKind::Product:
627 if (!caller.hasAllowWitnessAttr()) {
628 emitAttrErr(AllowWitnessAttr::name);
629 }
630 if (!caller.hasAllowConstraintAttr()) {
631 emitAttrErr(AllowConstraintAttr::name);
632 }
633 break;
634 default:
635 break;
636 }
637 }
638 return aggregateRes;
639 }
640
641 LogicalResult verifyInputs() override {
642 if (CalleeKind::Compute == tgtKind || CalleeKind::Product == tgtKind) {
643 // Without known target, no additional checks can be done.
644 } else if (CalleeKind::Constrain == tgtKind) {
645 // Without known target, this can only check that the first input is VarType using the same
646 // struct parameter as the base of the callee (later replaced with the target struct's type).
647 Operation::operand_type_range inputTypes = callOp->getArgOperands().getTypes();
648 if (inputTypes.size() < 1) {
649 // Tested in function_restrictions_fail.llzk
650 return callOp->emitOpError()
651 << "target \"@" << FUNC_NAME_CONSTRAIN << "\" must have at least one input type";
652 }
653 return checkSelfTypeUnknownTarget(
654 calleeAttr.getRootReference(), inputTypes.front(), callOp, "first input"
655 );
656 }
657 return success();
658 }
659
660 LogicalResult verifyOutputs() override {
661 if (CalleeKind::Compute == tgtKind || CalleeKind::Product == tgtKind) {
662 // Without known target, this can only check that the function returns VarType using the same
663 // struct parameter as the base of the callee (later replaced with the target struct's type).
664 Operation::result_type_range resTypes = callOp->getResultTypes();
665 if (resTypes.size() != 1) {
666 // Tested in function_restrictions_fail.llzk
667 return callOp->emitOpError().append(
668 "target \"@", FUNC_NAME_COMPUTE, "\" must have exactly one return type"
669 );
670 }
671 return checkSelfTypeUnknownTarget(
672 calleeAttr.getRootReference(), resTypes.front(), callOp, "return"
673 );
674 } else if (CalleeKind::Constrain == tgtKind) {
675 // Without known target, this can only check that the function has no return
676 if (callOp->getNumResults() != 0) {
677 // Tested in function_restrictions_fail.llzk
678 return callOp->emitOpError()
679 << "target \"@" << FUNC_NAME_CONSTRAIN << "\" must have no return type";
680 }
681 }
682 return success();
683 }
684
685 LogicalResult verifyAffineMapParams() override {
686 if (CalleeKind::Compute == tgtKind || CalleeKind::Product == tgtKind) {
687 // Without known target, no additional checks can be done.
688 } else if (CalleeKind::Constrain == tgtKind) {
689 // Without known target, this can only check that there are no affine map instantiations.
690 return verifyNoAffineMapInstantiations();
691 }
692 return success();
693 }
694
695private:
696 SymbolRefAttr calleeAttr;
697};
698
699} // namespace
700
701LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &tables) {
702 // First, verify symbol resolution in all input and output types.
703 if (failed(verifyTypeResolution(tables, *this, getCalleeType()))) {
704 return failure(); // verifyTypeResolution() already emits a sufficient error message
705 }
706
707 // Check that the callee attribute was specified.
708 SymbolRefAttr calleeAttr = getCalleeAttr();
709 if (!calleeAttr) {
710 return emitOpError("requires a 'callee' symbol reference attribute");
711 }
712
713 // If the callee references a parameter of the struct where this call appears, perform the subset
714 // of checks that can be done even though the target is unknown.
715 if (calleeAttr.getNestedReferences().size() == 1) {
716 FailureOr<StructDefOp> parent = getParentOfType<StructDefOp>(*this);
717 if (succeeded(parent) && parent->hasParamNamed(calleeAttr.getRootReference())) {
718 return UnknownTargetVerifier(this, calleeAttr).verify();
719 }
720 }
721
722 // Otherwise, callee must be specified via full path from the root module. Perform the full set of
723 // checks against the known target function.
724 auto tgtOpt = lookupTopLevelSymbol<FuncDefOp>(tables, calleeAttr, *this);
725 if (failed(tgtOpt)) {
726 return this->emitError() << "expected '" << FuncDefOp::getOperationName() << "' named \""
727 << calleeAttr << '"';
728 }
729 return KnownTargetVerifier(this, std::move(*tgtOpt)).verify();
730}
731
732FunctionType CallOp::getCalleeType() {
733 return FunctionType::get(getContext(), getArgOperands().getTypes(), getResultTypes());
734}
735
736namespace {
737
738bool calleeIsStructFunctionImpl(
739 const char *funcName, SymbolRefAttr callee, llvm::function_ref<StructType()> getType
740) {
741 if (callee.getLeafReference() == funcName) {
742 if (StructType t = getType()) {
743 // If the name ref within the StructType matches the `callee` prefix (i.e., sans the function
744 // name itself), then the `callee` target must be within a StructDefOp because validation
745 // checks elsewhere ensure that every StructType references a StructDefOp (i.e., the `callee`
746 // function is not simply a global function nested within a ModuleOp)
747 return t.getNameRef() == getPrefixAsSymbolRefAttr(callee);
748 }
749 }
750 return false;
751}
752
753} // namespace
754
// CallOp::calleeIsStructCompute(). The signature line (original line 755) is missing from this
// listing. The lambda is only invoked once the leaf name matches FUNC_NAME_COMPUTE, which keeps
// the precondition of getSingleResultTypeOfCompute() satisfied.
756 return calleeIsStructFunctionImpl(FUNC_NAME_COMPUTE, getCallee(), [this]() {
757 return this->getSingleResultTypeOfCompute();
758 });
759}
760
// CallOp::calleeIsStructConstrain(). The signature line (original line 761) is missing from this
// listing. Uses the type of argument operand 0 as the candidate StructType.
762 return calleeIsStructFunctionImpl(FUNC_NAME_CONSTRAIN, getCallee(), [this]() {
763 return getAtIndex<StructType>(this->getArgOperands().getTypes(), 0);
764 });
765}
766
// CallOp::getSelfValueFromCompute(). The signature line (original line 767) is missing from this
// listing. Returns the call's first result.
768 assert(calleeIsStructCompute());
769 return getResults().front();
770}
771
// CallOp::getSelfValueFromConstrain(). The signature line (original line 772) is missing from
// this listing. Returns the call's first argument operand.
773 assert(calleeIsStructConstrain());
774 return getArgOperands().front();
775}
776
777FailureOr<SymbolLookupResult<FuncDefOp>> CallOp::getCalleeTarget(SymbolTableCollection &tables) {
778 Operation *thisOp = this->getOperation();
779 auto root = getRootModule(thisOp);
780 assert(succeeded(root));
781 return llzk::lookupSymbolIn<FuncDefOp>(tables, getCallee(), root->getOperation(), thisOp);
782}
783
// CallOp::getSingleResultTypeOfCompute(). The signature line (original line 784) is missing from
// this listing. Returns the call's sole result type as a StructType via getIfSingleton.
// Presumably this yields a null StructType when there is not exactly one result or the result is
// not a StructType — confirm against getIfSingleton.
785 assert(calleeIsCompute() && "violated implementation pre-condition");
786 return getIfSingleton<StructType>(getResultTypes());
787}
788
790CallInterfaceCallable CallOp::getCallableForCallee() { return getCalleeAttr(); }
791
793void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
794 setCalleeAttr(llvm::cast<SymbolRefAttr>(callee));
795}
796
797SmallVector<ValueRange> CallOp::toVectorOfValueRange(OperandRangeRange input) {
798 llvm::SmallVector<ValueRange, 4> output;
799 output.reserve(input.size());
800 for (OperandRange r : input) {
801 output.push_back(r);
802 }
803 return output;
804}
805
806} // namespace llzk::function
MlirStringRef name
Definition Poly.cpp:48
This file defines methods for symbol lookup across LLZK operations and included files.
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:761
::mlir::CallInterfaceCallable getCallableForCallee()
Return the callee of this operation.
Definition Ops.cpp:790
::mlir::FunctionType getCalleeType()
Definition Ops.cpp:732
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the callee is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:784
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.h.inc:267
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:701
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:755
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:467
bool calleeIsCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE (this does not check if the callee func...
Definition Ops.h.inc:329
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::SymbolRefAttr callee, ::mlir::ValueRange argOperands={})
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
Definition Ops.cpp.inc:472
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:241
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:772
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:245
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
Definition Ops.cpp:797
FoldAdaptor::Properties Properties
Definition Ops.h.inc:192
void setCalleeAttr(::mlir::SymbolRefAttr attr)
Definition Ops.h.inc:282
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:767
void setCalleeFromCallable(::mlir::CallInterfaceCallable callee)
Set the callee for this operation.
Definition Ops.cpp:793
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
Definition Ops.cpp:777
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
Definition Ops.cpp:195
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:306
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:333
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:952
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={}, ::llvm::ArrayRef<::mlir::DictionaryAttr > argAttrs={})
::mlir::StringAttr getFunctionTypeAttrName()
Definition Ops.h.inc:547
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:352
void print(::mlir::OpAsmPrinter &p)
Definition Ops.cpp:106
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:758
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
Definition Ops.h.inc:717
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the name is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:357
::mlir::StringAttr getResAttrsAttrName()
Definition Ops.h.inc:555
::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result)
Definition Ops.cpp:95
void cloneInto(FuncDefOp dest, ::mlir::IRMapping &mapper)
Clone the internal blocks and attributes from this function into dest.
Definition Ops.cpp:115
bool nameIsProduct()
Return true iff the function name is FUNC_NAME_PRODUCT (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:766
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
Definition Ops.h.inc:762
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
Definition Ops.h.inc:772
bool isInStruct()
Return true iff the function is within a StructDefOp.
Definition Ops.h.inc:769
::llvm::ArrayRef<::mlir::Type > getResultTypes()
Returns the result types of this function.
Definition Ops.h.inc:740
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:571
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
Definition Ops.cpp:187
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
::mlir::SymbolRefAttr getFullyQualifiedName(bool requireParent=true)
Return the full name for this function from the root module, including all surrounding symbol table n...
Definition Ops.cpp:323
bool hasArgPublicAttr(unsigned index)
Return true iff the argument at the given index has pub attribute.
Definition Ops.cpp:203
::llvm::LogicalResult verify()
Definition Ops.cpp:213
::mlir::Region & getBody()
Definition Ops.h.inc:595
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
Definition Ops.h.inc:709
::mlir::StringAttr getArgAttrsAttrName()
Definition Ops.h.inc:539
::llvm::LogicalResult verify()
Definition Ops.cpp:366
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:873
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:27
OpClass::Properties & buildInstantiationAttrs(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
OpClass::Properties & buildInstantiationAttrsEmpty(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
bool isInStruct(Operation *op)
Definition Ops.cpp:41
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
Definition Ops.cpp:105
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:27
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:28
TypeClass getIfSingleton(mlir::TypeRange types)
Definition TypeHelper.h:255
FailureOr< ModuleOp > getRootModule(Operation *from)
bool isNullOrEmpty(mlir::ArrayAttr a)
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:29
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
Definition TypeHelper.h:259
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:45
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
OpClass delegate_to_build(mlir::Location location, Args &&...args)
bool hasAffineMapAttr(Type type)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
Definition TypeHelper.h:107
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)