45 tables, origin, ArrayRef<ArrayRef<Type>> {funcType.getInputs(), funcType.getResults()}
55 Location location, StringRef
name, FunctionType type, ArrayRef<NamedAttribute> attrs
61 Location location, StringRef
name, FunctionType type, Operation::dialect_attr_range attrs
63 SmallVector<NamedAttribute, 8> attrRef(attrs);
64 return create(location,
name, type, llvm::ArrayRef(attrRef));
68 Location location, StringRef
name, FunctionType type, ArrayRef<NamedAttribute> attrs,
69 ArrayRef<DictionaryAttr> argAttrs
72 func.setAllArgAttrs(argAttrs);
77 OpBuilder &builder, OperationState &state, StringRef
name, FunctionType type,
78 ArrayRef<NamedAttribute> attrs, ArrayRef<DictionaryAttr> argAttrs
80 state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(
name));
82 state.attributes.append(attrs.begin(), attrs.end());
85 if (argAttrs.empty()) {
88 assert(type.getNumInputs() == argAttrs.size());
89 function_interface_impl::addArgAndResultAttrs(
96 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
97 function_interface_impl::VariadicFlag,
98 std::string &) {
return builder.getFunctionType(argTypes, results); };
100 return function_interface_impl::parseFunctionOp(
107 function_interface_impl::printFunctionOp(
117 llvm::MapVector<StringAttr, Attribute> newAttrMap;
118 for (
const auto &attr : dest->getAttrs()) {
119 newAttrMap.insert({attr.getName(), attr.getValue()});
121 for (
const auto &attr : (*this)->getAttrs()) {
122 newAttrMap.insert({attr.getName(), attr.getValue()});
126 llvm::to_vector(llvm::map_range(newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
127 return NamedAttribute(attrPair.first, attrPair.second);
129 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
142 FuncDefOp newFunc = llvm::cast<FuncDefOp>(getOperation()->cloneWithoutRegions());
150 unsigned oldNumArgs = oldType.getNumInputs();
151 SmallVector<Type, 4> newInputs;
152 newInputs.reserve(oldNumArgs);
153 for (
unsigned i = 0; i != oldNumArgs; ++i) {
154 if (!mapper.contains(getArgument(i))) {
155 newInputs.push_back(oldType.getInput(i));
161 if (newInputs.size() != oldNumArgs) {
162 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, oldType.getResults()));
164 if (ArrayAttr argAttrs = getAllArgAttrs()) {
165 SmallVector<Attribute> newArgAttrs;
166 newArgAttrs.reserve(newInputs.size());
167 for (
unsigned i = 0; i != oldNumArgs; ++i) {
168 if (!mapper.contains(getArgument(i))) {
169 newArgAttrs.push_back(argAttrs[i]);
172 newFunc.setAllArgAttrs(newArgAttrs);
184 return clone(mapper);
189 getOperation()->setAttr(AllowConstraintAttr::name, UnitAttr::get(getContext()));
191 getOperation()->removeAttr(AllowConstraintAttr::name);
197 getOperation()->setAttr(AllowWitnessAttr::name, UnitAttr::get(getContext()));
199 getOperation()->removeAttr(AllowWitnessAttr::name);
204 if (index < this->getNumArguments()) {
205 DictionaryAttr res = function_interface_impl::getArgAttrDict(*
this, index);
206 return res ? res.contains(PublicAttr::name) :
false;
220 for (Type t : type.getInputs()) {
225 return emitErrorFunc().append(
226 "\"@", getName(),
"\" parameters cannot contain affine map attributes but found ", t
230 for (Type t : type.getResults()) {
238 WalkResult res = this->walk<WalkOrder::PreOrder>([
this](ModuleOp nestedMod) {
240 "cannot be nested within '", getOperation()->getName(),
"' operations"
242 return WalkResult::interrupt();
244 if (res.wasInterrupted()) {
256 llvm::ArrayRef<Type> resTypes = funcType.getResults();
258 if (resTypes.size() != 1) {
259 return origin.emitOpError().append(
263 if (failed(
checkSelfType(tables, parent, resTypes.front(), origin,
"return"))) {
274verifyFuncTypeProduct(
FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
276 return verifyFuncTypeCompute(origin, tables, parent);
280verifyFuncTypeConstrain(
FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
283 if (funcType.getResults().size() != 0) {
284 return origin.emitOpError() <<
"\"@" <<
FUNC_NAME_CONSTRAIN <<
"\" must have no return type";
288 llvm::ArrayRef<Type> inputTypes = funcType.getInputs();
289 if (inputTypes.size() < 1) {
291 <<
"\" must have at least one input type";
293 if (failed(
checkSelfType(tables, parent, inputTypes.front(), origin,
"first input"))) {
309 if (succeeded(parentStructOpt)) {
312 return verifyFuncTypeCompute(*
this, tables, parentStructOpt.value());
314 return verifyFuncTypeConstrain(*
this, tables, parentStructOpt.value());
316 return verifyFuncTypeProduct(*
this, tables, parentStructOpt.value());
325 if (!requireParent && getOperation()->getParentOp() ==
nullptr) {
326 return SymbolRefAttr::get(getOperation());
329 assert(succeeded(res));
337 assert(!body.empty() &&
"compute() function body is empty");
338 Block &block = body.back();
341 Operation *terminator = block.getTerminator();
342 assert(terminator &&
"compute() function has no terminator");
343 auto retOp = llvm::dyn_cast<ReturnOp>(terminator);
346 << terminator->getName() <<
"'\n";
347 llvm_unreachable(
"compute() function must end with ReturnOp");
349 return retOp.getOperands().front();
354 return getArguments().front();
367 auto function = getParentOp<FuncDefOp>();
370 const auto results =
function.getFunctionType().getResults();
371 if (getNumOperands() != results.size()) {
372 return emitOpError(
"has ") << getNumOperands() <<
" operands, but enclosing function (@"
373 <<
function.getName() <<
") returns " << results.size();
376 for (
unsigned i = 0, e = results.size(); i != e; ++i) {
377 if (!
typesUnify(getOperand(i).getType(), results[i])) {
378 return emitError() <<
"type of return operand " << i <<
" (" << getOperand(i).getType()
379 <<
") doesn't match function result type (" << results[i] <<
")"
380 <<
" in function @" <<
function.getName();
392 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
393 ValueRange argOperands
395 odsState.addTypes(resultTypes);
396 odsState.addOperands(argOperands);
398 odsBuilder, odsState,
static_cast<int32_t
>(argOperands.size())
400 props.setCallee(callee);
404 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
405 ArrayRef<ValueRange> mapOperands, DenseI32ArrayAttr numDimsPerMap, ValueRange argOperands
407 odsState.addTypes(resultTypes);
408 odsState.addOperands(argOperands);
410 odsBuilder, odsState, mapOperands, numDimsPerMap, argOperands.size()
412 props.setCallee(callee);
416enum class CalleeKind : std::uint8_t { Compute, Constrain, Product, Other };
418CalleeKind calleeNameToKind(StringRef tgtName) {
420 return CalleeKind::Compute;
422 return CalleeKind::Constrain;
424 return CalleeKind::Product;
426 return CalleeKind::Other;
430struct CallOpVerifier {
431 CallOpVerifier(CallOp *c, StringRef tgtName) : callOp(c), tgtKind(calleeNameToKind(tgtName)) {}
432 virtual ~CallOpVerifier() =
default;
434 LogicalResult verify() {
437 LogicalResult aggregateResult = success();
438 if (failed(verifyTargetAttributes())) {
439 aggregateResult = failure();
441 if (failed(verifyInputs())) {
442 aggregateResult = failure();
444 if (failed(verifyOutputs())) {
445 aggregateResult = failure();
447 if (failed(verifyAffineMapParams())) {
448 aggregateResult = failure();
450 return aggregateResult;
457 virtual LogicalResult verifyTargetAttributes() = 0;
458 virtual LogicalResult verifyInputs() = 0;
459 virtual LogicalResult verifyOutputs() = 0;
460 virtual LogicalResult verifyAffineMapParams() = 0;
463 LogicalResult verifyTargetAttributesMatch(FuncDefOp target) {
464 LogicalResult aggregateRes = success();
465 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
466 auto emitAttrErr = [&](StringLiteral attrName) {
467 aggregateRes = callOp->emitOpError()
468 <<
"target '@" << target.getName() <<
"' has '" << attrName
469 <<
"' attribute, which is not specified by the caller '@" << caller.getName()
474 emitAttrErr(AllowConstraintAttr::name);
477 emitAttrErr(AllowWitnessAttr::name);
483 LogicalResult verifyNoAffineMapInstantiations() {
486 return callOp->emitOpError().append(
487 "can only have affine map instantiations when targeting a \"@",
FUNC_NAME_COMPUTE,
493 assert(callOp->getMapOperands().empty());
498struct KnownTargetVerifier :
public CallOpVerifier {
499 KnownTargetVerifier(CallOp *c, SymbolLookupResult<FuncDefOp> &&tgtRes)
500 : CallOpVerifier(c, tgtRes.get().getSymName()), tgt(*tgtRes), tgtType(tgt.getFunctionType()),
501 includeSymNames(tgtRes.getIncludeSymNames()) {}
503 LogicalResult verifyTargetAttributes()
override {
504 return CallOpVerifier::verifyTargetAttributesMatch(tgt);
507 LogicalResult verifyInputs()
override {
508 return verifyTypesMatch(callOp->
getArgOperands().getTypes(), tgtType.getInputs(),
"operand");
511 LogicalResult verifyOutputs()
override {
512 return verifyTypesMatch(callOp->getResultTypes(), tgtType.getResults(),
"result");
515 LogicalResult verifyAffineMapParams()
override {
516 if ((CalleeKind::Compute == tgtKind || CalleeKind::Product == tgtKind) &&
523 if (ArrayAttr params = retTy.getParams()) {
525 SmallVector<AffineMapAttr> mapAttrs;
526 for (Attribute a : params) {
527 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(a)) {
528 mapAttrs.push_back(m);
539 return verifyNoAffineMapInstantiations();
544 template <
typename T>
546 verifyTypesMatch(ValueTypeRange<T> callOpTypes, ArrayRef<Type> tgtTypes,
const char *aspect) {
547 if (tgtTypes.size() != callOpTypes.size()) {
548 return callOp->emitOpError()
549 .append(
"incorrect number of ", aspect,
"s for callee, expected ", tgtTypes.size())
550 .attachNote(tgt.getLoc())
551 .append(
"callee defined here");
553 for (
unsigned i = 0, e = tgtTypes.size(); i != e; ++i) {
554 if (!
typesUnify(callOpTypes[i], tgtTypes[i], includeSymNames)) {
555 return callOp->emitOpError().append(
556 aspect,
" type mismatch: expected type ", tgtTypes[i],
", but found ", callOpTypes[i],
557 " for ", aspect,
" number ", i
565 FunctionType tgtType;
566 std::vector<llvm::StringRef> includeSymNames;
571LogicalResult checkSelfTypeUnknownTarget(
572 StringAttr expectedParamName, Type actualType,
CallOp *origin,
const char *aspect
574 if (!llvm::isa<TypeVarType>(actualType) ||
575 llvm::cast<TypeVarType>(actualType).getRefName() != expectedParamName) {
581 return origin->emitOpError().append(
582 "target \"@", origin->
getCallee().getLeafReference().getValue(),
"\" expected ", aspect,
583 " type '!",
TypeVarType::name,
"<@", expectedParamName.getValue(),
">' but found ",
599struct UnknownTargetVerifier :
public CallOpVerifier {
600 UnknownTargetVerifier(CallOp *c, SymbolRefAttr callee)
601 : CallOpVerifier(c, callee.getLeafReference().getValue()), calleeAttr(callee) {}
603 LogicalResult verifyTargetAttributes()
override {
606 LogicalResult aggregateRes = success();
607 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
608 auto emitAttrErr = [&](StringLiteral attrName) {
609 aggregateRes = callOp->emitOpError()
610 <<
"target '" << calleeAttr <<
"' has '" << attrName
611 <<
"' attribute, which is not specified by the caller '@" << caller.getName()
616 case CalleeKind::Constrain:
617 if (!caller.hasAllowConstraintAttr()) {
618 emitAttrErr(AllowConstraintAttr::name);
621 case CalleeKind::Compute:
622 if (!caller.hasAllowWitnessAttr()) {
623 emitAttrErr(AllowWitnessAttr::name);
626 case CalleeKind::Product:
627 if (!caller.hasAllowWitnessAttr()) {
628 emitAttrErr(AllowWitnessAttr::name);
630 if (!caller.hasAllowConstraintAttr()) {
631 emitAttrErr(AllowConstraintAttr::name);
641 LogicalResult verifyInputs()
override {
642 if (CalleeKind::Compute == tgtKind || CalleeKind::Product == tgtKind) {
644 }
else if (CalleeKind::Constrain == tgtKind) {
647 Operation::operand_type_range inputTypes = callOp->
getArgOperands().getTypes();
648 if (inputTypes.size() < 1) {
650 return callOp->emitOpError()
653 return checkSelfTypeUnknownTarget(
654 calleeAttr.getRootReference(), inputTypes.front(), callOp,
"first input"
660 LogicalResult verifyOutputs()
override {
661 if (CalleeKind::Compute == tgtKind || CalleeKind::Product == tgtKind) {
664 Operation::result_type_range resTypes = callOp->getResultTypes();
665 if (resTypes.size() != 1) {
667 return callOp->emitOpError().append(
671 return checkSelfTypeUnknownTarget(
672 calleeAttr.getRootReference(), resTypes.front(), callOp,
"return"
674 }
else if (CalleeKind::Constrain == tgtKind) {
676 if (callOp->getNumResults() != 0) {
678 return callOp->emitOpError()
685 LogicalResult verifyAffineMapParams()
override {
686 if (CalleeKind::Compute == tgtKind || CalleeKind::Product == tgtKind) {
688 }
else if (CalleeKind::Constrain == tgtKind) {
690 return verifyNoAffineMapInstantiations();
696 SymbolRefAttr calleeAttr;
710 return emitOpError(
"requires a 'callee' symbol reference attribute");
715 if (calleeAttr.getNestedReferences().size() == 1) {
717 if (succeeded(parent) && parent->hasParamNamed(calleeAttr.getRootReference())) {
718 return UnknownTargetVerifier(
this, calleeAttr).verify();
725 if (failed(tgtOpt)) {
727 << calleeAttr <<
'"';
729 return KnownTargetVerifier(
this, std::move(*tgtOpt)).verify();
733 return FunctionType::get(getContext(),
getArgOperands().getTypes(), getResultTypes());
738bool calleeIsStructFunctionImpl(
739 const char *funcName, SymbolRefAttr callee, llvm::function_ref<
StructType()> getType
741 if (callee.getLeafReference() == funcName) {
769 return getResults().front();
778 Operation *thisOp = this->getOperation();
780 assert(succeeded(root));
798 llvm::SmallVector<ValueRange, 4> output;
799 output.reserve(input.size());
800 for (OperandRange r : input) {