44 tables, origin, ArrayRef<ArrayRef<Type>> {funcType.getInputs(), funcType.getResults()}
54 Location location, StringRef
name, FunctionType type, ArrayRef<NamedAttribute> attrs
60 Location location, StringRef
name, FunctionType type, Operation::dialect_attr_range attrs
62 SmallVector<NamedAttribute, 8> attrRef(attrs);
63 return create(location,
name, type, llvm::ArrayRef(attrRef));
67 Location location, StringRef
name, FunctionType type, ArrayRef<NamedAttribute> attrs,
68 ArrayRef<DictionaryAttr> argAttrs
71 func.setAllArgAttrs(argAttrs);
76 OpBuilder &builder, OperationState &state, StringRef
name, FunctionType type,
77 ArrayRef<NamedAttribute> attrs, ArrayRef<DictionaryAttr> argAttrs
79 state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(
name));
81 state.attributes.append(attrs.begin(), attrs.end());
84 if (argAttrs.empty()) {
87 assert(type.getNumInputs() == argAttrs.size());
88 function_interface_impl::addArgAndResultAttrs(
95 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
96 function_interface_impl::VariadicFlag,
97 std::string &) {
return builder.getFunctionType(argTypes, results); };
99 return function_interface_impl::parseFunctionOp(
106 function_interface_impl::printFunctionOp(
116 llvm::MapVector<StringAttr, Attribute> newAttrMap;
117 for (
const auto &attr : dest->getAttrs()) {
118 newAttrMap.insert({attr.getName(), attr.getValue()});
120 for (
const auto &attr : (*this)->getAttrs()) {
121 newAttrMap.insert({attr.getName(), attr.getValue()});
125 llvm::to_vector(llvm::map_range(newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
126 return NamedAttribute(attrPair.first, attrPair.second);
128 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
141 FuncDefOp newFunc = llvm::cast<FuncDefOp>(getOperation()->cloneWithoutRegions());
149 unsigned oldNumArgs = oldType.getNumInputs();
150 SmallVector<Type, 4> newInputs;
151 newInputs.reserve(oldNumArgs);
152 for (
unsigned i = 0; i != oldNumArgs; ++i) {
153 if (!mapper.contains(getArgument(i))) {
154 newInputs.push_back(oldType.getInput(i));
160 if (newInputs.size() != oldNumArgs) {
161 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, oldType.getResults()));
163 if (ArrayAttr argAttrs = getAllArgAttrs()) {
164 SmallVector<Attribute> newArgAttrs;
165 newArgAttrs.reserve(newInputs.size());
166 for (
unsigned i = 0; i != oldNumArgs; ++i) {
167 if (!mapper.contains(getArgument(i))) {
168 newArgAttrs.push_back(argAttrs[i]);
171 newFunc.setAllArgAttrs(newArgAttrs);
183 return clone(mapper);
188 getOperation()->setAttr(AllowConstraintAttr::name, UnitAttr::get(getContext()));
190 getOperation()->removeAttr(AllowConstraintAttr::name);
196 getOperation()->setAttr(AllowWitnessAttr::name, UnitAttr::get(getContext()));
198 getOperation()->removeAttr(AllowWitnessAttr::name);
203 if (index < this->getNumArguments()) {
204 DictionaryAttr res = function_interface_impl::getArgAttrDict(*
this, index);
205 return res ? res.contains(PublicAttr::name) :
false;
218 llvm::ArrayRef<Type> inTypes = type.getInputs();
219 for (
auto ptr = inTypes.begin(); ptr < inTypes.end(); ptr++) {
224 emitErrorFunc().append(
225 "\"@", getName(),
"\" parameters cannot contain affine map attributes but found ", *ptr
230 llvm::ArrayRef<Type> resTypes = type.getResults();
231 for (
auto ptr = resTypes.begin(); ptr < resTypes.end(); ptr++) {
244 llvm::ArrayRef<Type> resTypes = funcType.getResults();
246 if (resTypes.size() != 1) {
247 return origin.emitOpError().append(
251 if (failed(
checkSelfType(tables, parent, resTypes.front(), origin,
"return"))) {
262verifyFuncTypeConstrain(
FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
265 if (funcType.getResults().size() != 0) {
266 return origin.emitOpError() <<
"\"@" <<
FUNC_NAME_CONSTRAIN <<
"\" must have no return type";
270 llvm::ArrayRef<Type> inputTypes = funcType.getInputs();
271 if (inputTypes.size() < 1) {
273 <<
"\" must have at least one input type";
275 if (failed(
checkSelfType(tables, parent, inputTypes.front(), origin,
"first input"))) {
291 if (succeeded(parentStructOpt)) {
294 return verifyFuncTypeCompute(*
this, tables, parentStructOpt.value());
296 return verifyFuncTypeConstrain(*
this, tables, parentStructOpt.value());
305 assert(succeeded(res));
319 auto function = getParentOp<FuncDefOp>();
322 const auto results =
function.getFunctionType().getResults();
323 if (getNumOperands() != results.size()) {
324 return emitOpError(
"has ") << getNumOperands() <<
" operands, but enclosing function (@"
325 <<
function.getName() <<
") returns " << results.size();
328 for (
unsigned i = 0, e = results.size(); i != e; ++i) {
329 if (!
typesUnify(getOperand(i).getType(), results[i])) {
330 return emitError() <<
"type of return operand " << i <<
" (" << getOperand(i).getType()
331 <<
") doesn't match function result type (" << results[i] <<
")"
332 <<
" in function @" <<
function.getName();
344 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
345 ValueRange argOperands
347 odsState.addTypes(resultTypes);
348 odsState.addOperands(argOperands);
350 odsBuilder, odsState,
static_cast<int32_t
>(argOperands.size())
352 props.setCallee(callee);
356 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
357 ArrayRef<ValueRange> mapOperands, DenseI32ArrayAttr numDimsPerMap, ValueRange argOperands
359 odsState.addTypes(resultTypes);
360 odsState.addOperands(argOperands);
362 odsBuilder, odsState, mapOperands, numDimsPerMap, argOperands.size()
364 props.setCallee(callee);
368enum class CalleeKind { Compute, Constrain, Other };
370CalleeKind calleeNameToKind(StringRef tgtName) {
372 return CalleeKind::Compute;
374 return CalleeKind::Constrain;
376 return CalleeKind::Other;
380struct CallOpVerifier {
381 CallOpVerifier(CallOp *c, StringRef tgtName) : callOp(c), tgtKind(calleeNameToKind(tgtName)) {}
382 virtual ~CallOpVerifier() =
default;
384 LogicalResult verify() {
387 LogicalResult aggregateResult = success();
388 if (failed(verifyTargetAttributes())) {
389 aggregateResult = failure();
391 if (failed(verifyInputs())) {
392 aggregateResult = failure();
394 if (failed(verifyOutputs())) {
395 aggregateResult = failure();
397 if (failed(verifyAffineMapParams())) {
398 aggregateResult = failure();
400 return aggregateResult;
407 virtual LogicalResult verifyTargetAttributes() = 0;
408 virtual LogicalResult verifyInputs() = 0;
409 virtual LogicalResult verifyOutputs() = 0;
410 virtual LogicalResult verifyAffineMapParams() = 0;
413 LogicalResult verifyTargetAttributesMatch(FuncDefOp target) {
414 LogicalResult aggregateRes = success();
415 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
416 auto emitAttrErr = [&](StringLiteral attrName) {
417 aggregateRes = callOp->emitOpError()
418 <<
"target '@" << target.getName() <<
"' has '" << attrName
419 <<
"' attribute, which is not specified by the caller '@" << caller.getName()
424 emitAttrErr(AllowConstraintAttr::name);
427 emitAttrErr(AllowWitnessAttr::name);
433 LogicalResult verifyNoAffineMapInstantiations() {
436 return callOp->emitOpError().append(
437 "can only have affine map instantiations when targeting a \"@",
FUNC_NAME_COMPUTE,
443 assert(callOp->getMapOperands().empty());
448struct KnownTargetVerifier :
public CallOpVerifier {
449 KnownTargetVerifier(CallOp *c, SymbolLookupResult<FuncDefOp> &&tgtRes)
450 : CallOpVerifier(c, tgtRes.get().getSymName()), tgt(*tgtRes), tgtType(tgt.getFunctionType()),
451 includeSymNames(tgtRes.getIncludeSymNames()) {}
453 LogicalResult verifyTargetAttributes()
override {
454 return CallOpVerifier::verifyTargetAttributesMatch(tgt);
457 LogicalResult verifyInputs()
override {
458 return verifyTypesMatch(callOp->
getArgOperands().getTypes(), tgtType.getInputs(),
"operand");
461 LogicalResult verifyOutputs()
override {
462 return verifyTypesMatch(callOp->getResultTypes(), tgtType.getResults(),
"result");
465 LogicalResult verifyAffineMapParams()
override {
466 if (CalleeKind::Compute == tgtKind &&
isInStruct(tgt.getOperation())) {
472 if (ArrayAttr params = retTy.getParams()) {
474 SmallVector<AffineMapAttr> mapAttrs;
475 for (Attribute a : params) {
476 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(a)) {
477 mapAttrs.push_back(m);
488 return verifyNoAffineMapInstantiations();
493 template <
typename T>
495 verifyTypesMatch(ValueTypeRange<T> callOpTypes, ArrayRef<Type> tgtTypes,
const char *aspect) {
496 if (tgtTypes.size() != callOpTypes.size()) {
497 return callOp->emitOpError()
498 .append(
"incorrect number of ", aspect,
"s for callee, expected ", tgtTypes.size())
499 .attachNote(tgt.getLoc())
500 .append(
"callee defined here");
502 for (
unsigned i = 0, e = tgtTypes.size(); i != e; ++i) {
503 if (!
typesUnify(callOpTypes[i], tgtTypes[i], includeSymNames)) {
504 return callOp->emitOpError().append(
505 aspect,
" type mismatch: expected type ", tgtTypes[i],
", but found ", callOpTypes[i],
506 " for ", aspect,
" number ", i
514 FunctionType tgtType;
515 std::vector<llvm::StringRef> includeSymNames;
520LogicalResult checkSelfTypeUnknownTarget(
521 StringAttr expectedParamName, Type actualType,
CallOp *origin,
const char *aspect
523 if (!llvm::isa<TypeVarType>(actualType) ||
524 llvm::cast<TypeVarType>(actualType).getRefName() != expectedParamName) {
530 return origin->emitOpError().append(
531 "target \"@", origin->
getCallee().getLeafReference().getValue(),
"\" expected ", aspect,
532 " type '!",
TypeVarType::name,
"<@", expectedParamName.getValue(),
">' but found ",
548struct UnknownTargetVerifier :
public CallOpVerifier {
549 UnknownTargetVerifier(CallOp *c, SymbolRefAttr callee)
550 : CallOpVerifier(c, callee.getLeafReference().getValue()), calleeAttr(callee) {}
552 LogicalResult verifyTargetAttributes()
override {
555 LogicalResult aggregateRes = success();
556 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
557 auto emitAttrErr = [&](StringLiteral attrName) {
558 aggregateRes = callOp->emitOpError()
559 <<
"target '" << calleeAttr <<
"' has '" << attrName
560 <<
"' attribute, which is not specified by the caller '@" << caller.getName()
564 if (tgtKind == CalleeKind::Constrain && !caller.hasAllowConstraintAttr()) {
565 emitAttrErr(AllowConstraintAttr::name);
567 if (tgtKind == CalleeKind::Compute && !caller.hasAllowWitnessAttr()) {
568 emitAttrErr(AllowWitnessAttr::name);
574 LogicalResult verifyInputs()
override {
575 if (CalleeKind::Compute == tgtKind) {
577 }
else if (CalleeKind::Constrain == tgtKind) {
580 Operation::operand_type_range inputTypes = callOp->
getArgOperands().getTypes();
581 if (inputTypes.size() < 1) {
583 return callOp->emitOpError()
586 return checkSelfTypeUnknownTarget(
587 calleeAttr.getRootReference(), inputTypes.front(), callOp,
"first input"
593 LogicalResult verifyOutputs()
override {
594 if (CalleeKind::Compute == tgtKind) {
597 Operation::result_type_range resTypes = callOp->getResultTypes();
598 if (resTypes.size() != 1) {
600 return callOp->emitOpError().append(
604 return checkSelfTypeUnknownTarget(
605 calleeAttr.getRootReference(), resTypes.front(), callOp,
"return"
607 }
else if (CalleeKind::Constrain == tgtKind) {
609 if (callOp->getNumResults() != 0) {
611 return callOp->emitOpError()
618 LogicalResult verifyAffineMapParams()
override {
619 if (CalleeKind::Compute == tgtKind) {
621 }
else if (CalleeKind::Constrain == tgtKind) {
623 return verifyNoAffineMapInstantiations();
629 SymbolRefAttr calleeAttr;
643 return emitOpError(
"requires a 'callee' symbol reference attribute");
648 if (calleeAttr.getNestedReferences().size() == 1) {
650 if (succeeded(parent) && parent->hasParamNamed(calleeAttr.getRootReference())) {
651 return UnknownTargetVerifier(
this, calleeAttr).verify();
658 if (failed(tgtOpt)) {
660 << calleeAttr <<
'"';
662 return KnownTargetVerifier(
this, std::move(*tgtOpt)).verify();
666 return FunctionType::get(getContext(),
getArgOperands().getTypes(), getResultTypes());
671bool calleeIsStructFunctionImpl(
672 const char *funcName, SymbolRefAttr callee, llvm::function_ref<
StructType()> getType
674 if (callee.getLeafReference() == funcName) {