58 tables, origin, ArrayRef<ArrayRef<Type>> {funcType.getInputs(), funcType.getResults()}
68 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs
74 Location location, StringRef name, FunctionType type, Operation::dialect_attr_range attrs
76 SmallVector<NamedAttribute, 8> attrRef(attrs);
77 return create(location, name, type, llvm::ArrayRef(attrRef));
81 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs,
82 ArrayRef<DictionaryAttr> argAttrs
84 FuncDefOp func =
create(location, name, type, attrs);
85 func.setAllArgAttrs(argAttrs);
90 OpBuilder &builder, OperationState &state, StringRef name, FunctionType type,
91 ArrayRef<NamedAttribute> attrs, ArrayRef<DictionaryAttr> argAttrs
93 state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name));
95 state.attributes.append(attrs.begin(), attrs.end());
98 if (argAttrs.empty()) {
101 assert(type.getNumInputs() == argAttrs.size());
102 function_interface_impl::addArgAndResultAttrs(
109 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
110 function_interface_impl::VariadicFlag,
111 std::string &) {
return builder.getFunctionType(argTypes, results); };
113 return function_interface_impl::parseFunctionOp(
120 function_interface_impl::printFunctionOp(
130 llvm::MapVector<StringAttr, Attribute> newAttrMap;
131 for (
const auto &attr : dest->getAttrs()) {
132 newAttrMap.insert({attr.getName(), attr.getValue()});
134 for (
const auto &attr : (*this)->getAttrs()) {
135 newAttrMap.insert({attr.getName(), attr.getValue()});
139 llvm::to_vector(llvm::map_range(newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
140 return NamedAttribute(attrPair.first, attrPair.second);
142 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
155 FuncDefOp newFunc = llvm::cast<FuncDefOp>(getOperation()->cloneWithoutRegions());
163 unsigned oldNumArgs = oldType.getNumInputs();
164 SmallVector<Type, 4> newInputs;
165 newInputs.reserve(oldNumArgs);
166 for (
unsigned i = 0; i != oldNumArgs; ++i) {
167 if (!mapper.contains(getArgument(i))) {
168 newInputs.push_back(oldType.getInput(i));
174 if (newInputs.size() != oldNumArgs) {
175 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, oldType.getResults()));
177 if (ArrayAttr argAttrs = getAllArgAttrs()) {
178 SmallVector<Attribute> newArgAttrs;
179 newArgAttrs.reserve(newInputs.size());
180 for (
unsigned i = 0; i != oldNumArgs; ++i) {
181 if (!mapper.contains(getArgument(i))) {
182 newArgAttrs.push_back(argAttrs[i]);
185 newFunc.setAllArgAttrs(newArgAttrs);
197 return clone(mapper);
202 getOperation()->setAttr(AllowConstraintAttr::name, UnitAttr::get(getContext()));
204 getOperation()->removeAttr(AllowConstraintAttr::name);
210 getOperation()->setAttr(AllowWitnessAttr::name, UnitAttr::get(getContext()));
212 getOperation()->removeAttr(AllowWitnessAttr::name);
217 if (index < this->getNumArguments()) {
218 DictionaryAttr res = function_interface_impl::getArgAttrDict(*
this, index);
219 return res ? res.contains(PublicAttr::name) :
false;
233 for (Type t : type.getInputs()) {
238 return emitErrorFunc().append(
239 "\"@", getName(),
"\" parameters cannot contain affine map attributes but found ", t
243 for (Type t : type.getResults()) {
251 WalkResult res = this->walk<WalkOrder::PreOrder>([
this](ModuleOp nestedMod) {
253 "cannot be nested within '", getOperation()->getName(),
"' operations"
255 return WalkResult::interrupt();
257 if (res.wasInterrupted()) {
269 llvm::ArrayRef<Type> resTypes = funcType.getResults();
271 if (resTypes.size() != 1) {
272 return origin.emitOpError().append(
276 if (failed(
checkSelfType(tables, parent, resTypes.front(), origin,
"return"))) {
287verifyFuncTypeProduct(
FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
289 return verifyFuncTypeCompute(origin, tables, parent);
293verifyFuncTypeConstrain(
FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
296 if (funcType.getResults().size() != 0) {
297 return origin.emitOpError() <<
"\"@" <<
FUNC_NAME_CONSTRAIN <<
"\" must have no return type";
301 llvm::ArrayRef<Type> inputTypes = funcType.getInputs();
302 if (inputTypes.size() < 1) {
304 <<
"\" must have at least one input type";
306 if (failed(
checkSelfType(tables, parent, inputTypes.front(), origin,
"first input"))) {
322 if (succeeded(parentStructOpt)) {
325 return verifyFuncTypeCompute(*
this, tables, parentStructOpt.value());
327 return verifyFuncTypeConstrain(*
this, tables, parentStructOpt.value());
329 return verifyFuncTypeProduct(*
this, tables, parentStructOpt.value());
338 if (!requireParent && getOperation()->getParentOp() ==
nullptr) {
339 return SymbolRefAttr::get(getOperation());
342 assert(succeeded(res));
350 assert(!body.empty() &&
"compute() function body is empty");
351 Block &block = body.back();
354 Operation *terminator = block.getTerminator();
355 assert(terminator &&
"compute() function has no terminator");
356 auto retOp = llvm::dyn_cast<ReturnOp>(terminator);
359 << terminator->getName() <<
"'\n";
360 llvm_unreachable(
"compute() function must end with ReturnOp");
362 return retOp.getOperands().front();
367 return getArguments().front();
380 auto function = getParentOp<FuncDefOp>();
383 const auto results =
function.getFunctionType().getResults();
384 if (getNumOperands() != results.size()) {
385 return emitOpError(
"has ") << getNumOperands() <<
" operands, but enclosing function (@"
386 <<
function.getName() <<
") returns " << results.size();
389 for (
unsigned i = 0, e = results.size(); i != e; ++i) {
390 if (!
typesUnify(getOperand(i).getType(), results[i])) {
391 return emitError() <<
"type of return operand " << i <<
" (" << getOperand(i).getType()
392 <<
") doesn't match function result type (" << results[i] <<
")"
393 <<
" in function @" <<
function.getName();
405 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
406 ValueRange argOperands
408 odsState.addTypes(resultTypes);
409 odsState.addOperands(argOperands);
411 odsBuilder, odsState,
static_cast<int32_t
>(argOperands.size())
413 props.setCallee(callee);
417 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
418 ArrayRef<ValueRange> mapOperands, DenseI32ArrayAttr numDimsPerMap, ValueRange argOperands
420 odsState.addTypes(resultTypes);
421 odsState.addOperands(argOperands);
423 odsBuilder, odsState, mapOperands, numDimsPerMap, argOperands.size()
425 props.setCallee(callee);
430struct CallOpVerifier {
431 CallOpVerifier(
CallOp *c, StringRef tgtName) : callOp(c), tgtKind(
fnNameToKind(tgtName)) {}
432 virtual ~CallOpVerifier() =
default;
434 LogicalResult verify() {
437 LogicalResult aggregateResult = success();
438 if (failed(verifyTargetAttributes())) {
439 aggregateResult = failure();
441 if (failed(verifyInputs())) {
442 aggregateResult = failure();
444 if (failed(verifyOutputs())) {
445 aggregateResult = failure();
447 if (failed(verifyAffineMapParams())) {
448 aggregateResult = failure();
450 return aggregateResult;
457 virtual LogicalResult verifyTargetAttributes() = 0;
458 virtual LogicalResult verifyInputs() = 0;
459 virtual LogicalResult verifyOutputs() = 0;
460 virtual LogicalResult verifyAffineMapParams() = 0;
463 LogicalResult verifyTargetAttributesMatch(FuncDefOp target) {
464 LogicalResult aggregateRes = success();
465 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
466 auto emitAttrErr = [&](StringLiteral attrName) {
467 aggregateRes = callOp->emitOpError()
468 <<
"target '@" << target.getName() <<
"' has '" << attrName
469 <<
"' attribute, which is not specified by the caller '@" << caller.getName()
474 emitAttrErr(AllowConstraintAttr::name);
477 emitAttrErr(AllowWitnessAttr::name);
483 LogicalResult verifyNoAffineMapInstantiations() {
486 return callOp->emitOpError().append(
487 "can only have affine map instantiations when targeting a \"@",
FUNC_NAME_COMPUTE,
493 assert(callOp->getMapOperands().empty());
498struct KnownTargetVerifier :
public CallOpVerifier {
499 KnownTargetVerifier(CallOp *c, SymbolLookupResult<FuncDefOp> &&tgtRes)
500 : CallOpVerifier(c, tgtRes.get().getSymName()), tgt(*tgtRes), tgtType(tgt.getFunctionType()),
501 includeSymNames(tgtRes.getIncludeSymNames()) {}
503 LogicalResult verifyTargetAttributes()
override {
504 return CallOpVerifier::verifyTargetAttributesMatch(tgt);
507 LogicalResult verifyInputs()
override {
508 return verifyTypesMatch(callOp->
getArgOperands().getTypes(), tgtType.getInputs(),
"operand");
511 LogicalResult verifyOutputs()
override {
512 return verifyTypesMatch(callOp->getResultTypes(), tgtType.getResults(),
"result");
515 LogicalResult verifyAffineMapParams()
override {
523 if (ArrayAttr params = retTy.getParams()) {
525 SmallVector<AffineMapAttr> mapAttrs;
526 for (Attribute a : params) {
527 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(a)) {
528 mapAttrs.push_back(m);
539 return verifyNoAffineMapInstantiations();
544 template <
typename T>
546 verifyTypesMatch(ValueTypeRange<T> callOpTypes, ArrayRef<Type> tgtTypes,
const char *aspect) {
547 if (tgtTypes.size() != callOpTypes.size()) {
548 return callOp->emitOpError()
549 .append(
"incorrect number of ", aspect,
"s for callee, expected ", tgtTypes.size())
550 .attachNote(tgt.getLoc())
551 .append(
"callee defined here");
553 for (
unsigned i = 0, e = tgtTypes.size(); i != e; ++i) {
554 if (!
typesUnify(callOpTypes[i], tgtTypes[i], includeSymNames)) {
555 return callOp->emitOpError().append(
556 aspect,
" type mismatch: expected type ", tgtTypes[i],
", but found ", callOpTypes[i],
557 " for ", aspect,
" number ", i
565 FunctionType tgtType;
566 std::vector<llvm::StringRef> includeSymNames;
571LogicalResult checkSelfTypeUnknownTarget(
572 StringAttr expectedParamName, Type actualType,
CallOp *origin,
const char *aspect
574 if (!llvm::isa<TypeVarType>(actualType) ||
575 llvm::cast<TypeVarType>(actualType).getRefName() != expectedParamName) {
581 return origin->emitOpError().append(
582 "target \"@", origin->
getCallee().getLeafReference().getValue(),
"\" expected ", aspect,
583 " type '!",
TypeVarType::name,
"<@", expectedParamName.getValue(),
">' but found ",
599struct UnknownTargetVerifier :
public CallOpVerifier {
600 UnknownTargetVerifier(CallOp *c, SymbolRefAttr callee)
601 : CallOpVerifier(c, callee.getLeafReference().getValue()), calleeAttr(callee) {}
603 LogicalResult verifyTargetAttributes()
override {
606 LogicalResult aggregateRes = success();
607 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
608 auto emitAttrErr = [&](StringLiteral attrName) {
609 aggregateRes = callOp->emitOpError()
610 <<
"target '" << calleeAttr <<
"' has '" << attrName
611 <<
"' attribute, which is not specified by the caller '@" << caller.getName()
617 if (!caller.hasAllowConstraintAttr()) {
618 emitAttrErr(AllowConstraintAttr::name);
622 if (!caller.hasAllowWitnessAttr()) {
623 emitAttrErr(AllowWitnessAttr::name);
627 if (!caller.hasAllowWitnessAttr()) {
628 emitAttrErr(AllowWitnessAttr::name);
630 if (!caller.hasAllowConstraintAttr()) {
631 emitAttrErr(AllowConstraintAttr::name);
641 LogicalResult verifyInputs()
override {
647 Operation::operand_type_range inputTypes = callOp->
getArgOperands().getTypes();
648 if (inputTypes.size() < 1) {
650 return callOp->emitOpError()
653 return checkSelfTypeUnknownTarget(
654 calleeAttr.getRootReference(), inputTypes.front(), callOp,
"first input"
660 LogicalResult verifyOutputs()
override {
664 Operation::result_type_range resTypes = callOp->getResultTypes();
665 if (resTypes.size() != 1) {
667 return callOp->emitOpError().append(
671 return checkSelfTypeUnknownTarget(
672 calleeAttr.getRootReference(), resTypes.front(), callOp,
"return"
676 if (callOp->getNumResults() != 0) {
678 return callOp->emitOpError()
685 LogicalResult verifyAffineMapParams()
override {
690 return verifyNoAffineMapInstantiations();
696 SymbolRefAttr calleeAttr;
710 return emitOpError(
"requires a 'callee' symbol reference attribute");
715 if (calleeAttr.getNestedReferences().size() == 1) {
717 if (succeeded(parent) && parent->hasParamNamed(calleeAttr.getRootReference())) {
718 return UnknownTargetVerifier(
this, calleeAttr).verify();
725 if (failed(tgtOpt)) {
727 << calleeAttr <<
'"';
729 return KnownTargetVerifier(
this, std::move(*tgtOpt)).verify();
733 return FunctionType::get(getContext(),
getArgOperands().getTypes(), getResultTypes());
738bool calleeIsStructFunctionImpl(
739 const char *funcName, SymbolRefAttr callee, llvm::function_ref<
StructType()> getType
741 if (callee.getLeafReference() == funcName) {
769 return getResults().front();
778 Operation *thisOp = this->getOperation();
780 assert(succeeded(root));
803 llvm::SmallVector<ValueRange, 4> output;
804 output.reserve(input.size());
805 for (OperandRange r : input) {