LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Func and call op implementations --------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8// Adapted from the LLVM Project's lib/Dialect/Func/IR/FuncOps.cpp
9// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
10// See https://llvm.org/LICENSE.txt for license information.
11// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
12//
13//===----------------------------------------------------------------------===//
14
23
24#include <mlir/IR/IRMapping.h>
25#include <mlir/IR/OpImplementation.h>
26#include <mlir/Interfaces/FunctionImplementation.h>
27
28#include <llvm/ADT/MapVector.h>
29
30// TableGen'd implementation files
31#define GET_OP_CLASSES
33
34using namespace mlir;
35using namespace llzk::component;
36using namespace llzk::polymorphic;
37
38namespace llzk::function {
39
40namespace {
42inline LogicalResult
43verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, FunctionType funcType) {
45 tables, origin, ArrayRef<ArrayRef<Type>> {funcType.getInputs(), funcType.getResults()}
46 );
47}
48} // namespace
49
50//===----------------------------------------------------------------------===//
51// FuncDefOp
52//===----------------------------------------------------------------------===//
53
55 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs
56) {
57 return delegate_to_build<FuncDefOp>(location, name, type, attrs);
58}
59
61 Location location, StringRef name, FunctionType type, Operation::dialect_attr_range attrs
62) {
63 SmallVector<NamedAttribute, 8> attrRef(attrs);
64 return create(location, name, type, llvm::ArrayRef(attrRef));
65}
66
68 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs,
69 ArrayRef<DictionaryAttr> argAttrs
70) {
71 FuncDefOp func = create(location, name, type, attrs);
72 func.setAllArgAttrs(argAttrs);
73 return func;
74}
75
77 OpBuilder &builder, OperationState &state, StringRef name, FunctionType type,
78 ArrayRef<NamedAttribute> attrs, ArrayRef<DictionaryAttr> argAttrs
79) {
80 state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name));
81 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
82 state.attributes.append(attrs.begin(), attrs.end());
83 state.addRegion();
84
85 if (argAttrs.empty()) {
86 return;
87 }
88 assert(type.getNumInputs() == argAttrs.size());
89 function_interface_impl::addArgAndResultAttrs(
90 builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name),
91 getResAttrsAttrName(state.name)
92 );
93}
94
95ParseResult FuncDefOp::parse(OpAsmParser &parser, OperationState &result) {
96 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
97 function_interface_impl::VariadicFlag,
98 std::string &) { return builder.getFunctionType(argTypes, results); };
99
100 return function_interface_impl::parseFunctionOp(
101 parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType,
102 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)
103 );
104}
105
106void FuncDefOp::print(OpAsmPrinter &p) {
107 function_interface_impl::printFunctionOp(
108 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(),
110 );
111}
112
115void FuncDefOp::cloneInto(FuncDefOp dest, IRMapping &mapper) {
116 // Add the attributes of this function to dest.
117 llvm::MapVector<StringAttr, Attribute> newAttrMap;
118 for (const auto &attr : dest->getAttrs()) {
119 newAttrMap.insert({attr.getName(), attr.getValue()});
120 }
121 for (const auto &attr : (*this)->getAttrs()) {
122 newAttrMap.insert({attr.getName(), attr.getValue()});
123 }
124
125 auto newAttrs =
126 llvm::to_vector(llvm::map_range(newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
127 return NamedAttribute(attrPair.first, attrPair.second);
128 }));
129 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
130
131 // Clone the body.
132 getBody().cloneInto(&dest.getBody(), mapper);
133}
134
140FuncDefOp FuncDefOp::clone(IRMapping &mapper) {
141 // Create the new function.
142 FuncDefOp newFunc = llvm::cast<FuncDefOp>(getOperation()->cloneWithoutRegions());
143
144 // If the function has a body, then the user might be deleting arguments to
145 // the function by specifying them in the mapper. If so, we don't add the
146 // argument to the input type vector.
147 if (!isExternal()) {
148 FunctionType oldType = getFunctionType();
149
150 unsigned oldNumArgs = oldType.getNumInputs();
151 SmallVector<Type, 4> newInputs;
152 newInputs.reserve(oldNumArgs);
153 for (unsigned i = 0; i != oldNumArgs; ++i) {
154 if (!mapper.contains(getArgument(i))) {
155 newInputs.push_back(oldType.getInput(i));
156 }
157 }
158
161 if (newInputs.size() != oldNumArgs) {
162 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, oldType.getResults()));
163
164 if (ArrayAttr argAttrs = getAllArgAttrs()) {
165 SmallVector<Attribute> newArgAttrs;
166 newArgAttrs.reserve(newInputs.size());
167 for (unsigned i = 0; i != oldNumArgs; ++i) {
168 if (!mapper.contains(getArgument(i))) {
169 newArgAttrs.push_back(argAttrs[i]);
170 }
171 }
172 newFunc.setAllArgAttrs(newArgAttrs);
173 }
174 }
175 }
176
178 cloneInto(newFunc, mapper);
179 return newFunc;
180}
181
183 IRMapping mapper;
184 return clone(mapper);
185}
186
188 if (newValue) {
189 getOperation()->setAttr(AllowConstraintAttr::name, UnitAttr::get(getContext()));
190 } else {
191 getOperation()->removeAttr(AllowConstraintAttr::name);
192 }
193}
194
196 if (newValue) {
197 getOperation()->setAttr(AllowWitnessAttr::name, UnitAttr::get(getContext()));
198 } else {
199 getOperation()->removeAttr(AllowWitnessAttr::name);
200 }
201}
202
203bool FuncDefOp::hasArgPublicAttr(unsigned index) {
204 if (index < this->getNumArguments()) {
205 DictionaryAttr res = function_interface_impl::getArgAttrDict(*this, index);
206 return res ? res.contains(PublicAttr::name) : false;
207 } else {
208 // TODO: print error? requested attribute for non-existant argument index
209 return false;
210 }
211}
212
213LogicalResult FuncDefOp::verify() {
214 OwningEmitErrorFn emitErrorFunc = getEmitOpErrFn(this);
215 // Ensure that only valid LLZK types are used for arguments and return.
216 // @compute and @constrain functions also may not have AffineMapAttrs in their
217 // parameters.
218 FunctionType type = getFunctionType();
219 llvm::ArrayRef<Type> inTypes = type.getInputs();
220 for (auto ptr = inTypes.begin(); ptr < inTypes.end(); ptr++) {
221 if (llzk::checkValidType(emitErrorFunc, *ptr).failed()) {
222 return failure();
223 }
224 if (isInStruct() && (nameIsCompute() || nameIsConstrain()) && hasAffineMapAttr(*ptr)) {
225 emitErrorFunc().append(
226 "\"@", getName(), "\" parameters cannot contain affine map attributes but found ", *ptr
227 );
228 return failure();
229 }
230 }
231 llvm::ArrayRef<Type> resTypes = type.getResults();
232 for (auto ptr = resTypes.begin(); ptr < resTypes.end(); ptr++) {
233 if (llzk::checkValidType(emitErrorFunc, *ptr).failed()) {
234 return failure();
235 }
236 }
237 return success();
238}
239
240namespace {
241
242LogicalResult
243verifyFuncTypeCompute(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
244 FunctionType funcType = origin.getFunctionType();
245 llvm::ArrayRef<Type> resTypes = funcType.getResults();
246 // Must return type of parent struct
247 if (resTypes.size() != 1) {
248 return origin.emitOpError().append(
249 "\"@", FUNC_NAME_COMPUTE, "\" must have exactly one return type"
250 );
251 }
252 if (failed(checkSelfType(tables, parent, resTypes.front(), origin, "return"))) {
253 return failure();
254 }
255
256 // After the more specific checks (to ensure more specific error messages would be produced if
257 // necessary), do the general check that all symbol references in the types are valid. The return
258 // types were already checked so just check the input types.
259 return llzk::verifyTypeResolution(tables, origin, funcType.getInputs());
260}
261
262LogicalResult
263verifyFuncTypeConstrain(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
264 FunctionType funcType = origin.getFunctionType();
265 // Must return '()' type, i.e., have no return types
266 if (funcType.getResults().size() != 0) {
267 return origin.emitOpError() << "\"@" << FUNC_NAME_CONSTRAIN << "\" must have no return type";
268 }
269
270 // Type of the first parameter must match the parent StructDefOp of the current operation.
271 llvm::ArrayRef<Type> inputTypes = funcType.getInputs();
272 if (inputTypes.size() < 1) {
273 return origin.emitOpError() << "\"@" << FUNC_NAME_CONSTRAIN
274 << "\" must have at least one input type";
275 }
276 if (failed(checkSelfType(tables, parent, inputTypes.front(), origin, "first input"))) {
277 return failure();
278 }
279
280 // After the more specific checks (to ensure more specific error messages would be produced if
281 // necessary), do the general check that all symbol references in the types are valid. There are
282 // no return types, just check the remaining input types (the first was already checked via
283 // the checkSelfType() call above).
284 return llzk::verifyTypeResolution(tables, origin, inputTypes.drop_front());
285}
286
287} // namespace
288
289LogicalResult FuncDefOp::verifySymbolUses(SymbolTableCollection &tables) {
290 // Additional checks for the compute/constrain functions w/in a struct
291 FailureOr<StructDefOp> parentStructOpt = getParentOfType<StructDefOp>(*this);
292 if (succeeded(parentStructOpt)) {
293 // Verify return type restrictions for functions within a StructDefOp
294 if (nameIsCompute()) {
295 return verifyFuncTypeCompute(*this, tables, parentStructOpt.value());
296 } else if (nameIsConstrain()) {
297 return verifyFuncTypeConstrain(*this, tables, parentStructOpt.value());
298 }
299 }
300 // In the general case, verify symbol resolution in all input and output types.
301 return verifyTypeResolution(tables, *this, getFunctionType());
302}
303
304SymbolRefAttr FuncDefOp::getFullyQualifiedName(bool requireParent) {
305 // If the parent is not present and not required, just return the symbol name
306 if (!requireParent && getOperation()->getParentOp() == nullptr) {
307 return SymbolRefAttr::get(getOperation());
308 }
309 auto res = getPathFromRoot(*this);
310 assert(succeeded(res));
311 return res.value();
312}
313
315 assert(nameIsCompute()); // skip inStruct check to allow dangling functions
316 // Get the single block of the function body
317 Region &body = getBody();
318 assert(!body.empty() && "compute() function body is empty");
319 Block &block = body.back();
320
321 // The terminator should be the return op
322 Operation *terminator = block.getTerminator();
323 assert(terminator && "compute() function has no terminator");
324 auto retOp = llvm::dyn_cast<ReturnOp>(terminator);
325 if (!retOp) {
326 llvm::errs() << "Expected '" << ReturnOp::getOperationName() << "' but found '"
327 << terminator->getName() << "'\n";
328 llvm_unreachable("compute() function must end with ReturnOp");
329 }
330 return retOp.getOperands().front();
331}
332
334 assert(nameIsConstrain()); // skip inStruct check to allow dangling functions
335 return getArguments().front();
336}
337
339 assert(isStructCompute() && "violated implementation pre-condition");
341}
342
343//===----------------------------------------------------------------------===//
344// ReturnOp
345//===----------------------------------------------------------------------===//
346
347LogicalResult ReturnOp::verify() {
348 auto function = getParentOp<FuncDefOp>(); // parent is FuncDefOp per ODS
349
350 // The operand number and types must match the function signature.
351 const auto results = function.getFunctionType().getResults();
352 if (getNumOperands() != results.size()) {
353 return emitOpError("has ") << getNumOperands() << " operands, but enclosing function (@"
354 << function.getName() << ") returns " << results.size();
355 }
356
357 for (unsigned i = 0, e = results.size(); i != e; ++i) {
358 if (!typesUnify(getOperand(i).getType(), results[i])) {
359 return emitError() << "type of return operand " << i << " (" << getOperand(i).getType()
360 << ") doesn't match function result type (" << results[i] << ")"
361 << " in function @" << function.getName();
362 }
363 }
364
365 return success();
366}
367
368//===----------------------------------------------------------------------===//
369// CallOp
370//===----------------------------------------------------------------------===//
371
372void CallOp::build(
373 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
374 ValueRange argOperands
375) {
376 odsState.addTypes(resultTypes);
377 odsState.addOperands(argOperands);
379 odsBuilder, odsState, static_cast<int32_t>(argOperands.size())
380 );
381 props.setCallee(callee);
382}
383
384void CallOp::build(
385 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
386 ArrayRef<ValueRange> mapOperands, DenseI32ArrayAttr numDimsPerMap, ValueRange argOperands
387) {
388 odsState.addTypes(resultTypes);
389 odsState.addOperands(argOperands);
391 odsBuilder, odsState, mapOperands, numDimsPerMap, argOperands.size()
392 );
393 props.setCallee(callee);
394}
395
396namespace {
397enum class CalleeKind { Compute, Constrain, Other };
398
399CalleeKind calleeNameToKind(StringRef tgtName) {
400 if (FUNC_NAME_COMPUTE == tgtName) {
401 return CalleeKind::Compute;
402 } else if (FUNC_NAME_CONSTRAIN == tgtName) {
403 return CalleeKind::Constrain;
404 } else {
405 return CalleeKind::Other;
406 }
407}
408
409struct CallOpVerifier {
410 CallOpVerifier(CallOp *c, StringRef tgtName) : callOp(c), tgtKind(calleeNameToKind(tgtName)) {}
411 virtual ~CallOpVerifier() = default;
412
413 LogicalResult verify() {
414 // Rather than immediately returning on failure, we check all verifier steps and aggregate to
415 // provide as many errors are possible in a single verifier run.
416 LogicalResult aggregateResult = success();
417 if (failed(verifyTargetAttributes())) {
418 aggregateResult = failure();
419 }
420 if (failed(verifyInputs())) {
421 aggregateResult = failure();
422 }
423 if (failed(verifyOutputs())) {
424 aggregateResult = failure();
425 }
426 if (failed(verifyAffineMapParams())) {
427 aggregateResult = failure();
428 }
429 return aggregateResult;
430 }
431
432protected:
433 CallOp *callOp;
434 CalleeKind tgtKind;
435
436 virtual LogicalResult verifyTargetAttributes() = 0;
437 virtual LogicalResult verifyInputs() = 0;
438 virtual LogicalResult verifyOutputs() = 0;
439 virtual LogicalResult verifyAffineMapParams() = 0;
440
442 LogicalResult verifyTargetAttributesMatch(FuncDefOp target) {
443 LogicalResult aggregateRes = success();
444 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
445 auto emitAttrErr = [&](StringLiteral attrName) {
446 aggregateRes = callOp->emitOpError()
447 << "target '@" << target.getName() << "' has '" << attrName
448 << "' attribute, which is not specified by the caller '@" << caller.getName()
449 << '\'';
450 };
451
452 if (target.hasAllowConstraintAttr() && !caller.hasAllowConstraintAttr()) {
453 emitAttrErr(AllowConstraintAttr::name);
454 }
455 if (target.hasAllowWitnessAttr() && !caller.hasAllowWitnessAttr()) {
456 emitAttrErr(AllowWitnessAttr::name);
457 }
458 }
459 return aggregateRes;
460 }
461
462 LogicalResult verifyNoAffineMapInstantiations() {
463 if (!isNullOrEmpty(callOp->getMapOpGroupSizesAttr())) {
464 // Tested in call_with_affinemap_fail.llzk
465 return callOp->emitOpError().append(
466 "can only have affine map instantiations when targeting a \"@", FUNC_NAME_COMPUTE,
467 "\" function"
468 );
469 }
470 // ASSERT: the check above is sufficient due to VerifySizesForMultiAffineOps trait.
471 assert(isNullOrEmpty(callOp->getNumDimsPerMapAttr()));
472 assert(callOp->getMapOperands().empty());
473 return success();
474 }
475};
476
477struct KnownTargetVerifier : public CallOpVerifier {
478 KnownTargetVerifier(CallOp *c, SymbolLookupResult<FuncDefOp> &&tgtRes)
479 : CallOpVerifier(c, tgtRes.get().getSymName()), tgt(*tgtRes), tgtType(tgt.getFunctionType()),
480 includeSymNames(tgtRes.getIncludeSymNames()) {}
481
482 LogicalResult verifyTargetAttributes() override {
483 return CallOpVerifier::verifyTargetAttributesMatch(tgt);
484 }
485
486 LogicalResult verifyInputs() override {
487 return verifyTypesMatch(callOp->getArgOperands().getTypes(), tgtType.getInputs(), "operand");
488 }
489
490 LogicalResult verifyOutputs() override {
491 return verifyTypesMatch(callOp->getResultTypes(), tgtType.getResults(), "result");
492 }
493
494 LogicalResult verifyAffineMapParams() override {
495 if (CalleeKind::Compute == tgtKind && isInStruct(tgt.getOperation())) {
496 // Return type should be a single StructType. If that is not the case here, just bail without
497 // producing an error. The combination of this KnownTargetVerifier resolving the callee to a
498 // specific FuncDefOp and verifyFuncTypeCompute() ensuring all FUNC_NAME_COMPUTE FuncOps have
499 // a single StructType return value will produce a more relevant error message in that case.
500 if (StructType retTy = callOp->getSingleResultTypeOfCompute()) {
501 if (ArrayAttr params = retTy.getParams()) {
502 // Collect the struct parameters that are defined via AffineMapAttr
503 SmallVector<AffineMapAttr> mapAttrs;
504 for (Attribute a : params) {
505 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(a)) {
506 mapAttrs.push_back(m);
507 }
508 }
510 callOp->getMapOperands(), callOp->getNumDimsPerMap(), mapAttrs, *callOp
511 );
512 }
513 }
514 return success();
515 } else {
516 // Global functions and constrain functions cannot have affine map instantiations.
517 return verifyNoAffineMapInstantiations();
518 }
519 }
520
521private:
522 template <typename T>
523 LogicalResult
524 verifyTypesMatch(ValueTypeRange<T> callOpTypes, ArrayRef<Type> tgtTypes, const char *aspect) {
525 if (tgtTypes.size() != callOpTypes.size()) {
526 return callOp->emitOpError()
527 .append("incorrect number of ", aspect, "s for callee, expected ", tgtTypes.size())
528 .attachNote(tgt.getLoc())
529 .append("callee defined here");
530 }
531 for (unsigned i = 0, e = tgtTypes.size(); i != e; ++i) {
532 if (!typesUnify(callOpTypes[i], tgtTypes[i], includeSymNames)) {
533 return callOp->emitOpError().append(
534 aspect, " type mismatch: expected type ", tgtTypes[i], ", but found ", callOpTypes[i],
535 " for ", aspect, " number ", i
536 );
537 }
538 }
539 return success();
540 }
541
542 FuncDefOp tgt;
543 FunctionType tgtType;
544 std::vector<llvm::StringRef> includeSymNames;
545};
546
549LogicalResult checkSelfTypeUnknownTarget(
550 StringAttr expectedParamName, Type actualType, CallOp *origin, const char *aspect
551) {
552 if (!llvm::isa<TypeVarType>(actualType) ||
553 llvm::cast<TypeVarType>(actualType).getRefName() != expectedParamName) {
554 // Tested in function_restrictions_fail.llzk:
555 // Non-tvar for constrain input via "call_target_constrain_without_self_non_struct"
556 // Non-tvar for compute output via "call_target_compute_wrong_type_ret"
557 // Wrong tvar for constrain input via "call_target_constrain_without_self_wrong_tvar_param"
558 // Wrong tvar for compute output via "call_target_compute_wrong_tvar_param_ret"
559 return origin->emitOpError().append(
560 "target \"@", origin->getCallee().getLeafReference().getValue(), "\" expected ", aspect,
561 " type '!", TypeVarType::name, "<@", expectedParamName.getValue(), ">' but found ",
562 actualType
563 );
564 }
565 return success();
566}
567
577struct UnknownTargetVerifier : public CallOpVerifier {
578 UnknownTargetVerifier(CallOp *c, SymbolRefAttr callee)
579 : CallOpVerifier(c, callee.getLeafReference().getValue()), calleeAttr(callee) {}
580
581 LogicalResult verifyTargetAttributes() override {
582 // Based on the precondition of this verifier, the target must be either a
583 // struct compute or constrain function.
584 LogicalResult aggregateRes = success();
585 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
586 auto emitAttrErr = [&](StringLiteral attrName) {
587 aggregateRes = callOp->emitOpError()
588 << "target '" << calleeAttr << "' has '" << attrName
589 << "' attribute, which is not specified by the caller '@" << caller.getName()
590 << '\'';
591 };
592
593 if (tgtKind == CalleeKind::Constrain && !caller.hasAllowConstraintAttr()) {
594 emitAttrErr(AllowConstraintAttr::name);
595 }
596 if (tgtKind == CalleeKind::Compute && !caller.hasAllowWitnessAttr()) {
597 emitAttrErr(AllowWitnessAttr::name);
598 }
599 }
600 return aggregateRes;
601 }
602
603 LogicalResult verifyInputs() override {
604 if (CalleeKind::Compute == tgtKind) {
605 // Without known target, no additional checks can be done.
606 } else if (CalleeKind::Constrain == tgtKind) {
607 // Without known target, this can only check that the first input is VarType using the same
608 // struct parameter as the base of the callee (later replaced with the target struct's type).
609 Operation::operand_type_range inputTypes = callOp->getArgOperands().getTypes();
610 if (inputTypes.size() < 1) {
611 // Tested in function_restrictions_fail.llzk
612 return callOp->emitOpError()
613 << "target \"@" << FUNC_NAME_CONSTRAIN << "\" must have at least one input type";
614 }
615 return checkSelfTypeUnknownTarget(
616 calleeAttr.getRootReference(), inputTypes.front(), callOp, "first input"
617 );
618 }
619 return success();
620 }
621
622 LogicalResult verifyOutputs() override {
623 if (CalleeKind::Compute == tgtKind) {
624 // Without known target, this can only check that the function returns VarType using the same
625 // struct parameter as the base of the callee (later replaced with the target struct's type).
626 Operation::result_type_range resTypes = callOp->getResultTypes();
627 if (resTypes.size() != 1) {
628 // Tested in function_restrictions_fail.llzk
629 return callOp->emitOpError().append(
630 "target \"@", FUNC_NAME_COMPUTE, "\" must have exactly one return type"
631 );
632 }
633 return checkSelfTypeUnknownTarget(
634 calleeAttr.getRootReference(), resTypes.front(), callOp, "return"
635 );
636 } else if (CalleeKind::Constrain == tgtKind) {
637 // Without known target, this can only check that the function has no return
638 if (callOp->getNumResults() != 0) {
639 // Tested in function_restrictions_fail.llzk
640 return callOp->emitOpError()
641 << "target \"@" << FUNC_NAME_CONSTRAIN << "\" must have no return type";
642 }
643 }
644 return success();
645 }
646
647 LogicalResult verifyAffineMapParams() override {
648 if (CalleeKind::Compute == tgtKind) {
649 // Without known target, no additional checks can be done.
650 } else if (CalleeKind::Constrain == tgtKind) {
651 // Without known target, this can only check that there are no affine map instantiations.
652 return verifyNoAffineMapInstantiations();
653 }
654 return success();
655 }
656
657private:
658 SymbolRefAttr calleeAttr;
659};
660
661} // namespace
662
663LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &tables) {
664 // First, verify symbol resolution in all input and output types.
665 if (failed(verifyTypeResolution(tables, *this, getCalleeType()))) {
666 return failure(); // verifyTypeResolution() already emits a sufficient error message
667 }
668
669 // Check that the callee attribute was specified.
670 SymbolRefAttr calleeAttr = getCalleeAttr();
671 if (!calleeAttr) {
672 return emitOpError("requires a 'callee' symbol reference attribute");
673 }
674
675 // If the callee references a parameter of the struct where this call appears, perform the subset
676 // of checks that can be done even though the target is unknown.
677 if (calleeAttr.getNestedReferences().size() == 1) {
678 FailureOr<StructDefOp> parent = getParentOfType<StructDefOp>(*this);
679 if (succeeded(parent) && parent->hasParamNamed(calleeAttr.getRootReference())) {
680 return UnknownTargetVerifier(this, calleeAttr).verify();
681 }
682 }
683
684 // Otherwise, callee must be specified via full path from the root module. Perform the full set of
685 // checks against the known target function.
686 auto tgtOpt = lookupTopLevelSymbol<FuncDefOp>(tables, calleeAttr, *this);
687 if (failed(tgtOpt)) {
688 return this->emitError() << "expected '" << FuncDefOp::getOperationName() << "' named \""
689 << calleeAttr << '"';
690 }
691 return KnownTargetVerifier(this, std::move(*tgtOpt)).verify();
692}
693
694FunctionType CallOp::getCalleeType() {
695 return FunctionType::get(getContext(), getArgOperands().getTypes(), getResultTypes());
696}
697
698namespace {
699
700bool calleeIsStructFunctionImpl(
701 const char *funcName, SymbolRefAttr callee, llvm::function_ref<StructType()> getType
702) {
703 if (callee.getLeafReference() == funcName) {
704 if (StructType t = getType()) {
705 // If the name ref within the StructType matches the `callee` prefix (i.e., sans the function
706 // name itself), then the `callee` target must be within a StructDefOp because validation
707 // checks elsewhere ensure that every StructType references a StructDefOp (i.e., the `callee`
708 // function is not simply a global function nested within a ModuleOp)
709 return t.getNameRef() == getPrefixAsSymbolRefAttr(callee);
710 }
711 }
712 return false;
713}
714
715} // namespace
716
718 return calleeIsStructFunctionImpl(FUNC_NAME_COMPUTE, getCallee(), [this]() {
719 return this->getSingleResultTypeOfCompute();
720 });
721}
722
724 return calleeIsStructFunctionImpl(FUNC_NAME_CONSTRAIN, getCallee(), [this]() {
725 return getAtIndex<StructType>(this->getArgOperands().getTypes(), 0);
726 });
727}
728
730 assert(calleeIsStructCompute());
731 return getResults().front();
732}
733
735 assert(calleeIsStructConstrain());
736 return getArgOperands().front();
737}
738
739FailureOr<SymbolLookupResult<FuncDefOp>> CallOp::getCalleeTarget(SymbolTableCollection &tables) {
740 Operation *thisOp = this->getOperation();
741 auto root = getRootModule(thisOp);
742 assert(succeeded(root));
743 return llzk::lookupSymbolIn<FuncDefOp>(tables, getCallee(), root->getOperation(), thisOp);
744}
745
747 assert(calleeIsCompute() && "violated implementation pre-condition");
748 return getIfSingleton<StructType>(getResultTypes());
749}
750
752CallInterfaceCallable CallOp::getCallableForCallee() { return getCalleeAttr(); }
753
755void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
756 setCalleeAttr(callee.get<SymbolRefAttr>());
757}
758
759SmallVector<ValueRange> CallOp::toVectorOfValueRange(OperandRangeRange input) {
760 llvm::SmallVector<ValueRange, 4> output;
761 output.reserve(input.size());
762 for (OperandRange r : input) {
763 output.push_back(r);
764 }
765 return output;
766}
767
768} // namespace llzk::function
MlirStringRef name
This file defines methods for symbol lookup across LLZK operations and included files.
::mlir::OperandRangeRange getMapOperands()
Definition Ops.cpp.inc:228
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:723
::mlir::CallInterfaceCallable getCallableForCallee()
Return the callee of this operation.
Definition Ops.cpp:752
::mlir::FunctionType getCalleeType()
Definition Ops.cpp:694
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the callee is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:746
::mlir::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:663
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:717
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:526
bool calleeIsCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE (this does not check if the callee func...
Definition Ops.h.inc:286
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::SymbolRefAttr callee, ::mlir::ValueRange argOperands={})
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
Definition Ops.cpp.inc:535
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:734
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.cpp.inc:522
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
Definition Ops.cpp:759
FoldAdaptor::Properties Properties
Definition Ops.h.inc:184
void setCalleeAttr(::mlir::SymbolRefAttr attr)
Definition Ops.cpp.inc:549
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:729
::mlir::Operation::operand_range getArgOperands()
Definition Ops.cpp.inc:224
void setCalleeFromCallable(::mlir::CallInterfaceCallable callee)
Set the callee for this operation.
Definition Ops.cpp:755
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
Definition Ops.cpp:739
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
Definition Ops.cpp:195
::mlir::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:289
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:314
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:1095
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={}, ::llvm::ArrayRef<::mlir::DictionaryAttr > argAttrs={})
::mlir::StringAttr getFunctionTypeAttrName()
Definition Ops.h.inc:470
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:333
void print(::mlir::OpAsmPrinter &p)
Definition Ops.cpp:106
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:628
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
Definition Ops.h.inc:587
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the name is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:338
::mlir::StringAttr getResAttrsAttrName()
Definition Ops.h.inc:478
::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result)
Definition Ops.cpp:95
void cloneInto(FuncDefOp dest, ::mlir::IRMapping &mapper)
Clone the internal blocks and attributes from this function into dest.
Definition Ops.cpp:115
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
Definition Ops.h.inc:632
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
Definition Ops.h.inc:638
bool isInStruct()
Return true iff the function is within a StructDefOp.
Definition Ops.h.inc:635
::llvm::ArrayRef<::mlir::Type > getResultTypes()
Returns the result types of this function.
Definition Ops.h.inc:610
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:494
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
Definition Ops.cpp:187
::mlir::Region & getBody()
Definition Ops.cpp.inc:848
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
::mlir::SymbolRefAttr getFullyQualifiedName(bool requireParent=true)
Return the full name for this function from the root module, including all surrounding symbol table n...
Definition Ops.cpp:304
bool hasArgPublicAttr(unsigned index)
Return true iff the argument at the given index has pub attribute.
Definition Ops.cpp:203
::mlir::LogicalResult verify()
Definition Ops.cpp:213
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
Definition Ops.h.inc:579
::mlir::StringAttr getArgAttrsAttrName()
Definition Ops.h.inc:462
::mlir::LogicalResult verify()
Definition Ops.cpp:347
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:729
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:27
OpClass::Properties & buildInstantiationAttrs(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
OpClass::Properties & buildInstantiationAttrsEmpty(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
bool isInStruct(Operation *op)
Definition Ops.cpp:38
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp &expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
Definition Ops.cpp:92
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:27
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:28
TypeClass getIfSingleton(mlir::TypeRange types)
Definition TypeHelper.h:255
FailureOr< ModuleOp > getRootModule(Operation *from)
bool isNullOrEmpty(mlir::ArrayAttr a)
std::function< mlir::InFlightDiagnostic()> OwningEmitErrorFn
Definition ErrorHelper.h:25
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
Definition TypeHelper.h:259
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
Definition ErrorHelper.h:27
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:45
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
OpClass delegate_to_build(mlir::Location location, Args &&...args)
bool hasAffineMapAttr(Type type)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
Definition TypeHelper.h:107
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)