LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Func and call op implementations --------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8// Adapted from the LLVM Project's lib/Dialect/Func/IR/FuncOps.cpp
9// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
10// See https://llvm.org/LICENSE.txt for license information.
11// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
12//
13//===----------------------------------------------------------------------===//
14
22
23#include <mlir/IR/IRMapping.h>
24#include <mlir/IR/OpImplementation.h>
25#include <mlir/Interfaces/FunctionImplementation.h>
26
27#include <llvm/ADT/MapVector.h>
28
29// TableGen'd implementation files
30#define GET_OP_CLASSES
32
33using namespace mlir;
34using namespace llzk::component;
35using namespace llzk::polymorphic;
36
37namespace llzk::function {
38
39namespace {
// FunctionType convenience overload of verifyTypeResolution(), intended for the FunctionType-typed
// calls in FuncDefOp::verifySymbolUses and CallOp::verifySymbolUses. It bundles the input types
// and the result types of `funcType` into one ArrayRef<ArrayRef<Type>> and forwards it, together
// with `tables` and `origin`, to another resolver overload, returning that overload's result.
41inline LogicalResult
42verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, FunctionType funcType) {
// NOTE(review): the line that opens the forwarded call (original line 43) is missing from this
// listing; presumably it is `return llzk::verifyTypeResolution(` -- confirm against the source.
44 tables, origin, ArrayRef<ArrayRef<Type>> {funcType.getInputs(), funcType.getResults()}
45 );
46}
47} // namespace
48
49//===----------------------------------------------------------------------===//
50// FuncDefOp
51//===----------------------------------------------------------------------===//
52
// FuncDefOp::create(location, name, type, attrs = {}) -- the first line of the signature
// (original line 53) is missing from this listing. Creates a FuncDefOp at `location` by
// forwarding all arguments to delegate_to_build<FuncDefOp> (which presumably invokes
// FuncDefOp::build with a temporary builder/state -- confirm in its definition).
54 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs
55) {
56 return delegate_to_build<FuncDefOp>(location, name, type, attrs);
57}
58
// FuncDefOp::create overload taking a range of dialect attributes -- the first line of the
// signature (original line 59) is missing from this listing. The attribute range is copied into a
// local SmallVector so it can be passed as ArrayRef<NamedAttribute> to the ArrayRef-based
// create() overload.
60 Location location, StringRef name, FunctionType type, Operation::dialect_attr_range attrs
61) {
62 SmallVector<NamedAttribute, 8> attrRef(attrs);
63 return create(location, name, type, llvm::ArrayRef(attrRef));
64}
65
// FuncDefOp::create overload with per-argument attribute dictionaries -- the first line of the
// signature (original line 66) is missing from this listing. Creates the op through the
// ArrayRef-attribute create() overload, then attaches `argAttrs` via setAllArgAttrs().
67 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs,
68 ArrayRef<DictionaryAttr> argAttrs
69) {
70 FuncDefOp func = create(location, name, type, attrs);
71 func.setAllArgAttrs(argAttrs);
72 return func;
73}
74
// FuncDefOp::build(builder, state, name, type, attrs = {}, argAttrs = {}) -- the first line of the
// signature (original line 75) is missing from this listing. Populates `state` with the symbol
// name attribute, the function-type attribute, the caller-supplied attributes, and one empty body
// region. If `argAttrs` is non-empty it must contain exactly one dictionary per function input
// (asserted); those are recorded as argument attributes and no result attributes are set.
76 OpBuilder &builder, OperationState &state, StringRef name, FunctionType type,
77 ArrayRef<NamedAttribute> attrs, ArrayRef<DictionaryAttr> argAttrs
78) {
79 state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name));
80 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
81 state.attributes.append(attrs.begin(), attrs.end());
82 state.addRegion();
83
// No per-argument attributes requested: nothing more to record.
84 if (argAttrs.empty()) {
85 return;
86 }
// Only checked in builds with assertions enabled.
87 assert(type.getNumInputs() == argAttrs.size());
88 function_interface_impl::addArgAndResultAttrs(
89 builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name),
90 getResAttrsAttrName(state.name)
91 );
92}
93
94ParseResult FuncDefOp::parse(OpAsmParser &parser, OperationState &result) {
95 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
96 function_interface_impl::VariadicFlag,
97 std::string &) { return builder.getFunctionType(argTypes, results); };
98
99 return function_interface_impl::parseFunctionOp(
100 parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType,
101 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)
102 );
103}
104
// Prints this op using the generic MLIR function-op syntax; the signature is never printed as
// variadic.
105void FuncDefOp::print(OpAsmPrinter &p) {
106 function_interface_impl::printFunctionOp(
107 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(),
// NOTE(review): the final argument (original line 108) is missing from this listing; presumably
// `getResAttrsAttrName()`, mirroring the attribute names used by parse() -- confirm.
109 );
110}
111
114void FuncDefOp::cloneInto(FuncDefOp dest, IRMapping &mapper) {
115 // Add the attributes of this function to dest.
116 llvm::MapVector<StringAttr, Attribute> newAttrMap;
117 for (const auto &attr : dest->getAttrs()) {
118 newAttrMap.insert({attr.getName(), attr.getValue()});
119 }
120 for (const auto &attr : (*this)->getAttrs()) {
121 newAttrMap.insert({attr.getName(), attr.getValue()});
122 }
123
124 auto newAttrs =
125 llvm::to_vector(llvm::map_range(newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
126 return NamedAttribute(attrPair.first, attrPair.second);
127 }));
128 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
129
130 // Clone the body.
131 getBody().cloneInto(&dest.getBody(), mapper);
132}
133
139FuncDefOp FuncDefOp::clone(IRMapping &mapper) {
140 // Create the new function.
141 FuncDefOp newFunc = llvm::cast<FuncDefOp>(getOperation()->cloneWithoutRegions());
142
143 // If the function has a body, then the user might be deleting arguments to
144 // the function by specifying them in the mapper. If so, we don't add the
145 // argument to the input type vector.
146 if (!isExternal()) {
147 FunctionType oldType = getFunctionType();
148
149 unsigned oldNumArgs = oldType.getNumInputs();
150 SmallVector<Type, 4> newInputs;
151 newInputs.reserve(oldNumArgs);
152 for (unsigned i = 0; i != oldNumArgs; ++i) {
153 if (!mapper.contains(getArgument(i))) {
154 newInputs.push_back(oldType.getInput(i));
155 }
156 }
157
160 if (newInputs.size() != oldNumArgs) {
161 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, oldType.getResults()));
162
163 if (ArrayAttr argAttrs = getAllArgAttrs()) {
164 SmallVector<Attribute> newArgAttrs;
165 newArgAttrs.reserve(newInputs.size());
166 for (unsigned i = 0; i != oldNumArgs; ++i) {
167 if (!mapper.contains(getArgument(i))) {
168 newArgAttrs.push_back(argAttrs[i]);
169 }
170 }
171 newFunc.setAllArgAttrs(newArgAttrs);
172 }
173 }
174 }
175
177 cloneInto(newFunc, mapper);
178 return newFunc;
179}
180
// FuncDefOp::clone() -- the signature line (original line 181) is missing from this listing.
// Convenience overload: clones with a fresh, empty IRMapping, so no entry arguments are dropped
// and the copy keeps this function's signature.
182 IRMapping mapper;
183 return clone(mapper);
184}
185
// FuncDefOp::setAllowConstraintAttr(bool newValue = true) -- the signature line (original line
// 186) is missing from this listing. Adds the AllowConstraintAttr::name unit attribute to this
// function when `newValue` is true, and removes it otherwise.
187 if (newValue) {
188 getOperation()->setAttr(AllowConstraintAttr::name, UnitAttr::get(getContext()));
189 } else {
190 getOperation()->removeAttr(AllowConstraintAttr::name);
191 }
192}
193
// FuncDefOp::setAllowWitnessAttr(bool newValue = true) -- the signature line (original line 194)
// is missing from this listing. Adds the AllowWitnessAttr::name unit attribute to this function
// when `newValue` is true, and removes it otherwise.
195 if (newValue) {
196 getOperation()->setAttr(AllowWitnessAttr::name, UnitAttr::get(getContext()));
197 } else {
198 getOperation()->removeAttr(AllowWitnessAttr::name);
199 }
200}
201
202bool FuncDefOp::hasArgPublicAttr(unsigned index) {
203 if (index < this->getNumArguments()) {
204 DictionaryAttr res = function_interface_impl::getArgAttrDict(*this, index);
205 return res ? res.contains(PublicAttr::name) : false;
206 } else {
207 // TODO: print error? requested attribute for non-existant argument index
208 return false;
209 }
210}
211
212LogicalResult FuncDefOp::verify() {
213 OwningEmitErrorFn emitErrorFunc = getEmitOpErrFn(this);
214 // Ensure that only valid LLZK types are used for arguments and return.
215 // @compute and @constrain functions also may not have AffineMapAttrs in their
216 // parameters.
217 FunctionType type = getFunctionType();
218 llvm::ArrayRef<Type> inTypes = type.getInputs();
219 for (auto ptr = inTypes.begin(); ptr < inTypes.end(); ptr++) {
220 if (llzk::checkValidType(emitErrorFunc, *ptr).failed()) {
221 return failure();
222 }
223 if (isInStruct() && (nameIsCompute() || nameIsConstrain()) && hasAffineMapAttr(*ptr)) {
224 emitErrorFunc().append(
225 "\"@", getName(), "\" parameters cannot contain affine map attributes but found ", *ptr
226 );
227 return failure();
228 }
229 }
230 llvm::ArrayRef<Type> resTypes = type.getResults();
231 for (auto ptr = resTypes.begin(); ptr < resTypes.end(); ptr++) {
232 if (llzk::checkValidType(emitErrorFunc, *ptr).failed()) {
233 return failure();
234 }
235 }
236 return success();
237}
238
239namespace {
240
241LogicalResult
242verifyFuncTypeCompute(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
243 FunctionType funcType = origin.getFunctionType();
244 llvm::ArrayRef<Type> resTypes = funcType.getResults();
245 // Must return type of parent struct
246 if (resTypes.size() != 1) {
247 return origin.emitOpError().append(
248 "\"@", FUNC_NAME_COMPUTE, "\" must have exactly one return type"
249 );
250 }
251 if (failed(checkSelfType(tables, parent, resTypes.front(), origin, "return"))) {
252 return failure();
253 }
254
255 // After the more specific checks (to ensure more specific error messages would be produced if
256 // necessary), do the general check that all symbol references in the types are valid. The return
257 // types were already checked so just check the input types.
258 return llzk::verifyTypeResolution(tables, origin, funcType.getInputs());
259}
260
261LogicalResult
262verifyFuncTypeConstrain(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
263 FunctionType funcType = origin.getFunctionType();
264 // Must return '()' type, i.e., have no return types
265 if (funcType.getResults().size() != 0) {
266 return origin.emitOpError() << "\"@" << FUNC_NAME_CONSTRAIN << "\" must have no return type";
267 }
268
269 // Type of the first parameter must match the parent StructDefOp of the current operation.
270 llvm::ArrayRef<Type> inputTypes = funcType.getInputs();
271 if (inputTypes.size() < 1) {
272 return origin.emitOpError() << "\"@" << FUNC_NAME_CONSTRAIN
273 << "\" must have at least one input type";
274 }
275 if (failed(checkSelfType(tables, parent, inputTypes.front(), origin, "first input"))) {
276 return failure();
277 }
278
279 // After the more specific checks (to ensure more specific error messages would be produced if
280 // necessary), do the general check that all symbol references in the types are valid. There are
281 // no return types, just check the remaining input types (the first was already checked via
282 // the checkSelfType() call above).
283 return llzk::verifyTypeResolution(tables, origin, inputTypes.drop_front());
284}
285
286} // namespace
287
288LogicalResult FuncDefOp::verifySymbolUses(SymbolTableCollection &tables) {
289 // Additional checks for the compute/constrain functions w/in a struct
290 FailureOr<StructDefOp> parentStructOpt = getParentOfType<StructDefOp>(*this);
291 if (succeeded(parentStructOpt)) {
292 // Verify return type restrictions for functions within a StructDefOp
293 if (nameIsCompute()) {
294 return verifyFuncTypeCompute(*this, tables, parentStructOpt.value());
295 } else if (nameIsConstrain()) {
296 return verifyFuncTypeConstrain(*this, tables, parentStructOpt.value());
297 }
298 }
299 // In the general case, verify symbol resolution in all input and output types.
300 return verifyTypeResolution(tables, *this, getFunctionType());
301}
302
// FuncDefOp::getFullyQualifiedName() -- the signature line (original line 303) is missing from
// this listing. Returns the SymbolRefAttr path from the root module to this function, as computed
// by getPathFromRoot().
// NOTE(review): failure is only checked by assert(); with assertions disabled a failed lookup
// reaches res.value() on an empty FailureOr. Consider reporting a diagnostic instead.
304 auto res = getPathFromRoot(*this);
305 assert(succeeded(res));
306 return res.value();
307}
308
// FuncDefOp::getSingleResultTypeOfCompute(): assuming this function is a struct's
// FUNC_NAME_COMPUTE function (asserted below), returns its single StructType result.
// NOTE(review): the signature (original line 309) and the return statement (original line 311)
// are missing from this listing.
310 assert(isStructCompute() && "violated implementation pre-condition");
312}
313
314//===----------------------------------------------------------------------===//
315// ReturnOp
316//===----------------------------------------------------------------------===//
317
318LogicalResult ReturnOp::verify() {
319 auto function = getParentOp<FuncDefOp>(); // parent is FuncDefOp per ODS
320
321 // The operand number and types must match the function signature.
322 const auto results = function.getFunctionType().getResults();
323 if (getNumOperands() != results.size()) {
324 return emitOpError("has ") << getNumOperands() << " operands, but enclosing function (@"
325 << function.getName() << ") returns " << results.size();
326 }
327
328 for (unsigned i = 0, e = results.size(); i != e; ++i) {
329 if (!typesUnify(getOperand(i).getType(), results[i])) {
330 return emitError() << "type of return operand " << i << " (" << getOperand(i).getType()
331 << ") doesn't match function result type (" << results[i] << ")"
332 << " in function @" << function.getName();
333 }
334 }
335
336 return success();
337}
338
339//===----------------------------------------------------------------------===//
340// CallOp
341//===----------------------------------------------------------------------===//
342
// CallOp::build without affine map instantiations: records the result types and the call
// arguments, initializes the op properties through a helper (passing `argOperands.size()`,
// explicitly cast to int32_t), then sets the callee on those properties.
343void CallOp::build(
344 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
345 ValueRange argOperands
346) {
347 odsState.addTypes(resultTypes);
348 odsState.addOperands(argOperands);
// NOTE(review): original line 349, which declares `props`, is missing from this listing;
// presumably it calls the buildInstantiationAttrsEmpty<CallOp>() helper with the argument count
// as the first operand-segment size -- confirm against the source.
350 odsBuilder, odsState, static_cast<int32_t>(argOperands.size())
351 );
352 props.setCallee(callee);
353}
354
// CallOp::build with affine map instantiations: records the result types and the call arguments,
// forwards `mapOperands`, `numDimsPerMap`, and the argument count to a properties-initialization
// helper, then sets the callee on those properties.
355void CallOp::build(
356 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
357 ArrayRef<ValueRange> mapOperands, DenseI32ArrayAttr numDimsPerMap, ValueRange argOperands
358) {
359 odsState.addTypes(resultTypes);
360 odsState.addOperands(argOperands);
// NOTE(review): original line 361, which declares `props`, is missing from this listing;
// presumably it calls buildInstantiationAttrs<CallOp>(). Unlike the overload without map operands,
// `argOperands.size()` is passed below without a static_cast, so it narrows implicitly if the
// helper's parameter is int32_t -- consider making the cast explicit for consistency.
362 odsBuilder, odsState, mapOperands, numDimsPerMap, argOperands.size()
363 );
364 props.setCallee(callee);
365}
366
367namespace {
368enum class CalleeKind { Compute, Constrain, Other };
369
370CalleeKind calleeNameToKind(StringRef tgtName) {
371 if (FUNC_NAME_COMPUTE == tgtName) {
372 return CalleeKind::Compute;
373 } else if (FUNC_NAME_CONSTRAIN == tgtName) {
374 return CalleeKind::Constrain;
375 } else {
376 return CalleeKind::Other;
377 }
378}
379
380struct CallOpVerifier {
381 CallOpVerifier(CallOp *c, StringRef tgtName) : callOp(c), tgtKind(calleeNameToKind(tgtName)) {}
382 virtual ~CallOpVerifier() = default;
383
384 LogicalResult verify() {
385 // Rather than immediately returning on failure, we check all verifier steps and aggregate to
386 // provide as many errors are possible in a single verifier run.
387 LogicalResult aggregateResult = success();
388 if (failed(verifyTargetAttributes())) {
389 aggregateResult = failure();
390 }
391 if (failed(verifyInputs())) {
392 aggregateResult = failure();
393 }
394 if (failed(verifyOutputs())) {
395 aggregateResult = failure();
396 }
397 if (failed(verifyAffineMapParams())) {
398 aggregateResult = failure();
399 }
400 return aggregateResult;
401 }
402
403protected:
404 CallOp *callOp;
405 CalleeKind tgtKind;
406
407 virtual LogicalResult verifyTargetAttributes() = 0;
408 virtual LogicalResult verifyInputs() = 0;
409 virtual LogicalResult verifyOutputs() = 0;
410 virtual LogicalResult verifyAffineMapParams() = 0;
411
413 LogicalResult verifyTargetAttributesMatch(FuncDefOp target) {
414 LogicalResult aggregateRes = success();
415 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
416 auto emitAttrErr = [&](StringLiteral attrName) {
417 aggregateRes = callOp->emitOpError()
418 << "target '@" << target.getName() << "' has '" << attrName
419 << "' attribute, which is not specified by the caller '@" << caller.getName()
420 << '\'';
421 };
422
423 if (target.hasAllowConstraintAttr() && !caller.hasAllowConstraintAttr()) {
424 emitAttrErr(AllowConstraintAttr::name);
425 }
426 if (target.hasAllowWitnessAttr() && !caller.hasAllowWitnessAttr()) {
427 emitAttrErr(AllowWitnessAttr::name);
428 }
429 }
430 return aggregateRes;
431 }
432
433 LogicalResult verifyNoAffineMapInstantiations() {
434 if (!isNullOrEmpty(callOp->getMapOpGroupSizesAttr())) {
435 // Tested in call_with_affinemap_fail.llzk
436 return callOp->emitOpError().append(
437 "can only have affine map instantiations when targeting a \"@", FUNC_NAME_COMPUTE,
438 "\" function"
439 );
440 }
441 // ASSERT: the check above is sufficient due to VerifySizesForMultiAffineOps trait.
442 assert(isNullOrEmpty(callOp->getNumDimsPerMapAttr()));
443 assert(callOp->getMapOperands().empty());
444 return success();
445 }
446};
447
// Verifier used when the callee resolves to a concrete FuncDefOp (`tgt`). It checks the target's
// allow_constraint/allow_witness attributes against the caller, compares the call's operand and
// result types with the target's FunctionType, and validates affine map instantiations.
448struct KnownTargetVerifier : public CallOpVerifier {
// Consumes the lookup result: the target op and its include symbol names (used as the reverse
// prefix passed to typesUnify below) are copied out. The callee kind is derived from the target's
// symbol name.
449 KnownTargetVerifier(CallOp *c, SymbolLookupResult<FuncDefOp> &&tgtRes)
450 : CallOpVerifier(c, tgtRes.get().getSymName()), tgt(*tgtRes), tgtType(tgt.getFunctionType()),
451 includeSymNames(tgtRes.getIncludeSymNames()) {}
452
// Attributes on the target must also be present on the calling function.
453 LogicalResult verifyTargetAttributes() override {
454 return CallOpVerifier::verifyTargetAttributesMatch(tgt);
455 }
456
// Call argument types must match (unify with) the target's input types.
457 LogicalResult verifyInputs() override {
458 return verifyTypesMatch(callOp->getArgOperands().getTypes(), tgtType.getInputs(), "operand");
459 }
460
// Call result types must match (unify with) the target's result types.
461 LogicalResult verifyOutputs() override {
462 return verifyTypesMatch(callOp->getResultTypes(), tgtType.getResults(), "result");
463 }
464
// For a compute function inside a struct, the call's map operands are checked against the
// AffineMapAttr parameters of the returned struct type. Any other target must have no affine map
// instantiations.
465 LogicalResult verifyAffineMapParams() override {
466 if (CalleeKind::Compute == tgtKind && isInStruct(tgt.getOperation())) {
467 // Return type should be a single StructType. If that is not the case here, just bail without
468 // producing an error. The combination of this KnownTargetVerifier resolving the callee to a
469 // specific FuncDefOp and verifyFuncTypeCompute() ensuring all FUNC_NAME_COMPUTE FuncOps have
470 // a single StructType return value will produce a more relevant error message in that case.
471 if (StructType retTy = callOp->getSingleResultTypeOfCompute()) {
472 if (ArrayAttr params = retTy.getParams()) {
473 // Collect the struct parameters that are defined via AffineMapAttr
474 SmallVector<AffineMapAttr> mapAttrs;
475 for (Attribute a : params) {
476 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(a)) {
477 mapAttrs.push_back(m);
478 }
479 }
// NOTE(review): the line opening the call that takes these arguments (original line 480) is
// missing from this listing; presumably `return verifyAffineMapInstantiations(` -- confirm.
481 callOp->getMapOperands(), callOp->getNumDimsPerMap(), mapAttrs, *callOp
482 );
483 }
484 }
485 return success();
486 } else {
487 // Global functions and constrain functions cannot have affine map instantiations.
488 return verifyNoAffineMapInstantiations();
489 }
490 }
491
492private:
// Checks that `callOpTypes` has as many entries as `tgtTypes` (the error attaches a note at the
// callee's location otherwise) and that each call-side type unifies with the corresponding target
// type, using `includeSymNames` on the target side. `aspect` ("operand" or "result") is only used
// in diagnostics. Returns on the first mismatch.
493 template <typename T>
494 LogicalResult
495 verifyTypesMatch(ValueTypeRange<T> callOpTypes, ArrayRef<Type> tgtTypes, const char *aspect) {
496 if (tgtTypes.size() != callOpTypes.size()) {
497 return callOp->emitOpError()
498 .append("incorrect number of ", aspect, "s for callee, expected ", tgtTypes.size())
499 .attachNote(tgt.getLoc())
500 .append("callee defined here");
501 }
502 for (unsigned i = 0, e = tgtTypes.size(); i != e; ++i) {
503 if (!typesUnify(callOpTypes[i], tgtTypes[i], includeSymNames)) {
504 return callOp->emitOpError().append(
505 aspect, " type mismatch: expected type ", tgtTypes[i], ", but found ", callOpTypes[i],
506 " for ", aspect, " number ", i
507 );
508 }
509 }
510 return success();
511 }
512
// Resolved target function, its signature, and the include symbol names from the lookup.
513 FuncDefOp tgt;
514 FunctionType tgtType;
515 std::vector<llvm::StringRef> includeSymNames;
516};
517
520LogicalResult checkSelfTypeUnknownTarget(
521 StringAttr expectedParamName, Type actualType, CallOp *origin, const char *aspect
522) {
523 if (!llvm::isa<TypeVarType>(actualType) ||
524 llvm::cast<TypeVarType>(actualType).getRefName() != expectedParamName) {
525 // Tested in function_restrictions_fail.llzk:
526 // Non-tvar for constrain input via "call_target_constrain_without_self_non_struct"
527 // Non-tvar for compute output via "call_target_compute_wrong_type_ret"
528 // Wrong tvar for constrain input via "call_target_constrain_without_self_wrong_tvar_param"
529 // Wrong tvar for compute output via "call_target_compute_wrong_tvar_param_ret"
530 return origin->emitOpError().append(
531 "target \"@", origin->getCallee().getLeafReference().getValue(), "\" expected ", aspect,
532 " type '!", TypeVarType::name, "<@", expectedParamName.getValue(), ">' but found ",
533 actualType
534 );
535 }
536 return success();
537}
538
548struct UnknownTargetVerifier : public CallOpVerifier {
549 UnknownTargetVerifier(CallOp *c, SymbolRefAttr callee)
550 : CallOpVerifier(c, callee.getLeafReference().getValue()), calleeAttr(callee) {}
551
552 LogicalResult verifyTargetAttributes() override {
553 // Based on the precondition of this verifier, the target must be either a
554 // struct compute or constrain function.
555 LogicalResult aggregateRes = success();
556 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
557 auto emitAttrErr = [&](StringLiteral attrName) {
558 aggregateRes = callOp->emitOpError()
559 << "target '" << calleeAttr << "' has '" << attrName
560 << "' attribute, which is not specified by the caller '@" << caller.getName()
561 << '\'';
562 };
563
564 if (tgtKind == CalleeKind::Constrain && !caller.hasAllowConstraintAttr()) {
565 emitAttrErr(AllowConstraintAttr::name);
566 }
567 if (tgtKind == CalleeKind::Compute && !caller.hasAllowWitnessAttr()) {
568 emitAttrErr(AllowWitnessAttr::name);
569 }
570 }
571 return aggregateRes;
572 }
573
574 LogicalResult verifyInputs() override {
575 if (CalleeKind::Compute == tgtKind) {
576 // Without known target, no additional checks can be done.
577 } else if (CalleeKind::Constrain == tgtKind) {
578 // Without known target, this can only check that the first input is VarType using the same
579 // struct parameter as the base of the callee (later replaced with the target struct's type).
580 Operation::operand_type_range inputTypes = callOp->getArgOperands().getTypes();
581 if (inputTypes.size() < 1) {
582 // Tested in function_restrictions_fail.llzk
583 return callOp->emitOpError()
584 << "target \"@" << FUNC_NAME_CONSTRAIN << "\" must have at least one input type";
585 }
586 return checkSelfTypeUnknownTarget(
587 calleeAttr.getRootReference(), inputTypes.front(), callOp, "first input"
588 );
589 }
590 return success();
591 }
592
593 LogicalResult verifyOutputs() override {
594 if (CalleeKind::Compute == tgtKind) {
595 // Without known target, this can only check that the function returns VarType using the same
596 // struct parameter as the base of the callee (later replaced with the target struct's type).
597 Operation::result_type_range resTypes = callOp->getResultTypes();
598 if (resTypes.size() != 1) {
599 // Tested in function_restrictions_fail.llzk
600 return callOp->emitOpError().append(
601 "target \"@", FUNC_NAME_COMPUTE, "\" must have exactly one return type"
602 );
603 }
604 return checkSelfTypeUnknownTarget(
605 calleeAttr.getRootReference(), resTypes.front(), callOp, "return"
606 );
607 } else if (CalleeKind::Constrain == tgtKind) {
608 // Without known target, this can only check that the function has no return
609 if (callOp->getNumResults() != 0) {
610 // Tested in function_restrictions_fail.llzk
611 return callOp->emitOpError()
612 << "target \"@" << FUNC_NAME_CONSTRAIN << "\" must have no return type";
613 }
614 }
615 return success();
616 }
617
618 LogicalResult verifyAffineMapParams() override {
619 if (CalleeKind::Compute == tgtKind) {
620 // Without known target, no additional checks can be done.
621 } else if (CalleeKind::Constrain == tgtKind) {
622 // Without known target, this can only check that there are no affine map instantiations.
623 return verifyNoAffineMapInstantiations();
624 }
625 return success();
626 }
627
628private:
629 SymbolRefAttr calleeAttr;
630};
631
632} // namespace
633
634LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &tables) {
635 // First, verify symbol resolution in all input and output types.
636 if (failed(verifyTypeResolution(tables, *this, getCalleeType()))) {
637 return failure(); // verifyTypeResolution() already emits a sufficient error message
638 }
639
640 // Check that the callee attribute was specified.
641 SymbolRefAttr calleeAttr = getCalleeAttr();
642 if (!calleeAttr) {
643 return emitOpError("requires a 'callee' symbol reference attribute");
644 }
645
646 // If the callee references a parameter of the struct where this call appears, perform the subset
647 // of checks that can be done even though the target is unknown.
648 if (calleeAttr.getNestedReferences().size() == 1) {
649 FailureOr<StructDefOp> parent = getParentOfType<StructDefOp>(*this);
650 if (succeeded(parent) && parent->hasParamNamed(calleeAttr.getRootReference())) {
651 return UnknownTargetVerifier(this, calleeAttr).verify();
652 }
653 }
654
655 // Otherwise, callee must be specified via full path from the root module. Perform the full set of
656 // checks against the known target function.
657 auto tgtOpt = lookupTopLevelSymbol<FuncDefOp>(tables, calleeAttr, *this);
658 if (failed(tgtOpt)) {
659 return this->emitError() << "expected '" << FuncDefOp::getOperationName() << "' named \""
660 << calleeAttr << '"';
661 }
662 return KnownTargetVerifier(this, std::move(*tgtOpt)).verify();
663}
664
665FunctionType CallOp::getCalleeType() {
666 return FunctionType::get(getContext(), getArgOperands().getTypes(), getResultTypes());
667}
668
669namespace {
670
671bool calleeIsStructFunctionImpl(
672 const char *funcName, SymbolRefAttr callee, llvm::function_ref<StructType()> getType
673) {
674 if (callee.getLeafReference() == funcName) {
675 if (StructType t = getType()) {
676 // If the name ref within the StructType matches the `callee` prefix (i.e., sans the function
677 // name itself), then the `callee` target must be within a StructDefOp because validation
678 // checks elsewhere ensure that every StructType references a StructDefOp (i.e., the `callee`
679 // function is not simply a global function nested within a ModuleOp)
680 return t.getNameRef() == getPrefixAsSymbolRefAttr(callee);
681 }
682 }
683 return false;
684}
685
686} // namespace
687
// CallOp::calleeIsStructCompute() -- the signature line (original line 688) is missing from this
// listing. Returns true iff the callee's leaf name is FUNC_NAME_COMPUTE and the call's single
// StructType result names the callee's prefix path, i.e. the target is a struct's compute function.
// The lambda defers getSingleResultTypeOfCompute(), which asserts calleeIsCompute(), until after
// calleeIsStructFunctionImpl has matched the leaf name.
689 return calleeIsStructFunctionImpl(FUNC_NAME_COMPUTE, getCallee(), [this]() {
690 return this->getSingleResultTypeOfCompute();
691 });
692}
693
// CallOp::calleeIsStructConstrain() -- the signature line (original line 694) is missing from this
// listing. Returns true iff the callee's leaf name is FUNC_NAME_CONSTRAIN and the first argument
// operand is a StructType naming the callee's prefix path (the "self" argument of a struct's
// constrain function). getAtIndex presumably yields a null StructType when there is no first
// argument or it is not a StructType -- confirm in TypeHelper.h.
695 return calleeIsStructFunctionImpl(FUNC_NAME_CONSTRAIN, getCallee(), [this]() {
696 return getAtIndex<StructType>(this->getArgOperands().getTypes(), 0);
697 });
698}
699
// CallOp::getSingleResultTypeOfCompute() -- the signature line (original line 700) is missing from
// this listing. Requires (asserted) that the callee is named FUNC_NAME_COMPUTE, and returns the
// call's result as a StructType via getIfSingleton; presumably null unless there is exactly one
// result and it is a StructType -- confirm in TypeHelper.h.
701 assert(calleeIsCompute() && "violated implementation pre-condition");
702 return getIfSingleton<StructType>(getResultTypes());
703}
704
706CallInterfaceCallable CallOp::getCallableForCallee() { return getCalleeAttr(); }
707
709void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
710 setCalleeAttr(callee.get<SymbolRefAttr>());
711}
712
713} // namespace llzk::function
MlirStringRef name
::mlir::OperandRangeRange getMapOperands()
Definition Ops.cpp.inc:228
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:694
::mlir::CallInterfaceCallable getCallableForCallee()
Return the callee of this operation.
Definition Ops.cpp:706
::mlir::FunctionType getCalleeType()
Definition Ops.cpp:665
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the callee is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:700
::mlir::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:634
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:688
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:526
bool calleeIsCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE (this does not check if the callee func...
Definition Ops.h.inc:286
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::SymbolRefAttr callee, ::mlir::ValueRange argOperands={})
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
Definition Ops.cpp.inc:535
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.cpp.inc:522
FoldAdaptor::Properties Properties
Definition Ops.h.inc:184
void setCalleeAttr(::mlir::SymbolRefAttr attr)
Definition Ops.cpp.inc:549
::mlir::Operation::operand_range getArgOperands()
Definition Ops.cpp.inc:224
void setCalleeFromCallable(::mlir::CallInterfaceCallable callee)
Set the callee for this operation.
Definition Ops.cpp:709
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
Definition Ops.cpp:194
::mlir::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:288
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:1095
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={}, ::llvm::ArrayRef<::mlir::DictionaryAttr > argAttrs={})
::mlir::StringAttr getFunctionTypeAttrName()
Definition Ops.h.inc:454
void print(::mlir::OpAsmPrinter &p)
Definition Ops.cpp:105
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:612
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
Definition Ops.h.inc:571
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the name is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:309
::mlir::StringAttr getResAttrsAttrName()
Definition Ops.h.inc:462
::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result)
Definition Ops.cpp:94
void cloneInto(FuncDefOp dest, ::mlir::IRMapping &mapper)
Clone the internal blocks and attributes from this function into dest.
Definition Ops.cpp:114
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
Definition Ops.h.inc:616
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
Definition Ops.h.inc:622
bool isInStruct()
Return true iff the function is within a StructDefOp.
Definition Ops.h.inc:619
::llvm::ArrayRef<::mlir::Type > getResultTypes()
Returns the result types of this function.
Definition Ops.h.inc:594
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:478
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
Definition Ops.cpp:186
::mlir::Region & getBody()
Definition Ops.cpp.inc:848
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
bool hasArgPublicAttr(unsigned index)
Return true iff the argument at the given index has pub attribute.
Definition Ops.cpp:202
::mlir::LogicalResult verify()
Definition Ops.cpp:212
::mlir::SymbolRefAttr getFullyQualifiedName()
Return the full name for this function from the root module, including all surrounding symbol table n...
Definition Ops.cpp:303
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
Definition Ops.h.inc:563
::mlir::StringAttr getArgAttrsAttrName()
Definition Ops.h.inc:446
::mlir::LogicalResult verify()
Definition Ops.cpp:318
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:27
OpClass::Properties & buildInstantiationAttrs(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
OpClass::Properties & buildInstantiationAttrsEmpty(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
bool isInStruct(Operation *op)
Definition Ops.cpp:38
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp &expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
Definition Ops.cpp:92
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:27
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:28
TypeClass getIfSingleton(mlir::TypeRange types)
Definition TypeHelper.h:251
bool isNullOrEmpty(mlir::ArrayAttr a)
std::function< mlir::InFlightDiagnostic()> OwningEmitErrorFn
Definition ErrorHelper.h:25
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
Definition TypeHelper.h:255
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
Definition ErrorHelper.h:27
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:45
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
OpClass delegate_to_build(mlir::Location location, Args &&...args)
bool hasAffineMapAttr(Type type)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
Definition TypeHelper.h:107
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)