45 tables, origin, ArrayRef<ArrayRef<Type>> {funcType.getInputs(), funcType.getResults()}
55 Location location, StringRef
name, FunctionType type, ArrayRef<NamedAttribute> attrs
61 Location location, StringRef
name, FunctionType type, Operation::dialect_attr_range attrs
63 SmallVector<NamedAttribute, 8> attrRef(attrs);
64 return create(location,
name, type, llvm::ArrayRef(attrRef));
68 Location location, StringRef
name, FunctionType type, ArrayRef<NamedAttribute> attrs,
69 ArrayRef<DictionaryAttr> argAttrs
72 func.setAllArgAttrs(argAttrs);
77 OpBuilder &builder, OperationState &state, StringRef
name, FunctionType type,
78 ArrayRef<NamedAttribute> attrs, ArrayRef<DictionaryAttr> argAttrs
80 state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(
name));
82 state.attributes.append(attrs.begin(), attrs.end());
85 if (argAttrs.empty()) {
88 assert(type.getNumInputs() == argAttrs.size());
89 function_interface_impl::addArgAndResultAttrs(
96 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
97 function_interface_impl::VariadicFlag,
98 std::string &) {
return builder.getFunctionType(argTypes, results); };
100 return function_interface_impl::parseFunctionOp(
107 function_interface_impl::printFunctionOp(
// Combine the attributes of `dest` and of this op in one insertion-ordered map.
// `dest`'s attributes go in first and MapVector::insert() leaves an existing key
// untouched, so on a name collision the value already present on `dest` is kept.
117 llvm::MapVector<StringAttr, Attribute> newAttrMap;
118 for (
const auto &attr : dest->getAttrs()) {
119 newAttrMap.insert({attr.getName(), attr.getValue()});
// Attributes of this op only contribute names that `dest` does not define yet.
121 for (
const auto &attr : (*this)->getAttrs()) {
122 newAttrMap.insert({attr.getName(), attr.getValue()});
// Convert each (name, value) map entry back into a NamedAttribute.
126 llvm::to_vector(llvm::map_range(newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
127 return NamedAttribute(attrPair.first, attrPair.second);
// Replace the whole attribute dictionary of `dest` with one built from `newAttrs`.
129 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
// Start from a copy of this operation that has no regions attached.
142 FuncDefOp newFunc = llvm::cast<FuncDefOp>(getOperation()->cloneWithoutRegions());
// Keep only the input types of arguments that `mapper` does not already map.
150 unsigned oldNumArgs = oldType.getNumInputs();
151 SmallVector<Type, 4> newInputs;
152 newInputs.reserve(oldNumArgs);
153 for (
unsigned i = 0; i != oldNumArgs; ++i) {
154 if (!mapper.contains(getArgument(i))) {
155 newInputs.push_back(oldType.getInput(i));
// At least one argument was dropped: give the clone a function type with the
// reduced input list and the original result types.
161 if (newInputs.size() != oldNumArgs) {
162 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, oldType.getResults()));
// Filter the per-argument attributes with the same predicate so that they stay
// aligned with the inputs that remain.
164 if (ArrayAttr argAttrs = getAllArgAttrs()) {
165 SmallVector<Attribute> newArgAttrs;
166 newArgAttrs.reserve(newInputs.size());
167 for (
unsigned i = 0; i != oldNumArgs; ++i) {
168 if (!mapper.contains(getArgument(i))) {
169 newArgAttrs.push_back(argAttrs[i]);
172 newFunc.setAllArgAttrs(newArgAttrs);
184 return clone(mapper);
189 getOperation()->setAttr(AllowConstraintAttr::name, UnitAttr::get(getContext()));
191 getOperation()->removeAttr(AllowConstraintAttr::name);
197 getOperation()->setAttr(AllowWitnessAttr::name, UnitAttr::get(getContext()));
199 getOperation()->removeAttr(AllowWitnessAttr::name);
204 if (index < this->getNumArguments()) {
205 DictionaryAttr res = function_interface_impl::getArgAttrDict(*
this, index);
206 return res ? res.contains(PublicAttr::name) :
false;
219 llvm::ArrayRef<Type> inTypes = type.getInputs();
220 for (
auto ptr = inTypes.begin(); ptr < inTypes.end(); ptr++) {
225 emitErrorFunc().append(
226 "\"@", getName(),
"\" parameters cannot contain affine map attributes but found ", *ptr
231 llvm::ArrayRef<Type> resTypes = type.getResults();
232 for (
auto ptr = resTypes.begin(); ptr < resTypes.end(); ptr++) {
245 llvm::ArrayRef<Type> resTypes = funcType.getResults();
247 if (resTypes.size() != 1) {
248 return origin.emitOpError().append(
252 if (failed(
checkSelfType(tables, parent, resTypes.front(), origin,
"return"))) {
263verifyFuncTypeConstrain(
FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
// The function named FUNC_NAME_CONSTRAIN must not declare any result type.
266 if (funcType.getResults().size() != 0) {
267 return origin.emitOpError() <<
"\"@" <<
FUNC_NAME_CONSTRAIN <<
"\" must have no return type";
// It must take at least one input ...
271 llvm::ArrayRef<Type> inputTypes = funcType.getInputs();
272 if (inputTypes.size() < 1) {
274 <<
"\" must have at least one input type";
// ... and the first input is validated by checkSelfType() against `parent`.
276 if (failed(
checkSelfType(tables, parent, inputTypes.front(), origin,
"first input"))) {
292 if (succeeded(parentStructOpt)) {
295 return verifyFuncTypeCompute(*
this, tables, parentStructOpt.value());
297 return verifyFuncTypeConstrain(*
this, tables, parentStructOpt.value());
306 if (!requireParent && getOperation()->getParentOp() ==
nullptr) {
307 return SymbolRefAttr::get(getOperation());
310 assert(succeeded(res));
// The value of interest is read from the terminator of the last block of `body`.
318 assert(!body.empty() &&
"compute() function body is empty");
319 Block &block = body.back();
322 Operation *terminator = block.getTerminator();
323 assert(terminator &&
"compute() function has no terminator");
// The terminator is expected to be a ReturnOp; any other op is treated as unreachable.
324 auto retOp = llvm::dyn_cast<ReturnOp>(terminator);
327 << terminator->getName() <<
"'\n";
328 llvm_unreachable(
"compute() function must end with ReturnOp");
// Result: the first operand of that return.
330 return retOp.getOperands().front();
335 return getArguments().front();
// Check the operands of this op against the result types of the enclosing FuncDefOp.
348 auto function = getParentOp<FuncDefOp>();
351 const auto results =
function.getFunctionType().getResults();
// The operand count must equal the number of declared function results.
352 if (getNumOperands() != results.size()) {
353 return emitOpError(
"has ") << getNumOperands() <<
" operands, but enclosing function (@"
354 <<
function.getName() <<
") returns " << results.size();
// Each operand type must unify (typesUnify) with the result type at the same
// position; the first mismatch is reported and ends the check.
357 for (
unsigned i = 0, e = results.size(); i != e; ++i) {
358 if (!
typesUnify(getOperand(i).getType(), results[i])) {
359 return emitError() <<
"type of return operand " << i <<
" (" << getOperand(i).getType()
360 <<
") doesn't match function result type (" << results[i] <<
")"
361 <<
" in function @" <<
function.getName();
373 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
374 ValueRange argOperands
376 odsState.addTypes(resultTypes);
377 odsState.addOperands(argOperands);
379 odsBuilder, odsState,
static_cast<int32_t
>(argOperands.size())
381 props.setCallee(callee);
385 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
386 ArrayRef<ValueRange> mapOperands, DenseI32ArrayAttr numDimsPerMap, ValueRange argOperands
388 odsState.addTypes(resultTypes);
389 odsState.addOperands(argOperands);
391 odsBuilder, odsState, mapOperands, numDimsPerMap, argOperands.size()
393 props.setCallee(callee);
// Kind of function a call targets; the call verifiers branch on this value.
397enum class CalleeKind { Compute, Constrain, Other };
// Maps the name of a call target to one of the three CalleeKind values.
399CalleeKind calleeNameToKind(StringRef tgtName) {
401 return CalleeKind::Compute;
403 return CalleeKind::Constrain;
405 return CalleeKind::Other;
// Common base of the call verifiers. It stores the CallOp under verification and
// the CalleeKind derived from the target's name; subclasses provide the four
// individual checks that verify() runs.
409struct CallOpVerifier {
410 CallOpVerifier(CallOp *c, StringRef tgtName) : callOp(c), tgtKind(calleeNameToKind(tgtName)) {}
411 virtual ~CallOpVerifier() =
default;
// Runs all four checks. A failing check does not stop the later ones; the result
// is failure if any check failed.
413 LogicalResult verify() {
416 LogicalResult aggregateResult = success();
417 if (failed(verifyTargetAttributes())) {
418 aggregateResult = failure();
420 if (failed(verifyInputs())) {
421 aggregateResult = failure();
423 if (failed(verifyOutputs())) {
424 aggregateResult = failure();
426 if (failed(verifyAffineMapParams())) {
427 aggregateResult = failure();
429 return aggregateResult;
// The individual checks, implemented by each concrete verifier.
436 virtual LogicalResult verifyTargetAttributes() = 0;
437 virtual LogicalResult verifyInputs() = 0;
438 virtual LogicalResult verifyOutputs() = 0;
439 virtual LogicalResult verifyAffineMapParams() = 0;
// Helper for subclasses that have the target FuncDefOp at hand. The check only
// applies when the call is nested inside a FuncDefOp (the caller).
442 LogicalResult verifyTargetAttributesMatch(FuncDefOp target) {
443 LogicalResult aggregateRes = success();
444 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
// Emits an op error naming an attribute that the target has but the caller does
// not specify, and stores the outcome in `aggregateRes`.
445 auto emitAttrErr = [&](StringLiteral attrName) {
446 aggregateRes = callOp->emitOpError()
447 <<
"target '@" << target.getName() <<
"' has '" << attrName
448 <<
"' attribute, which is not specified by the caller '@" << caller.getName()
453 emitAttrErr(AllowConstraintAttr::name);
456 emitAttrErr(AllowWitnessAttr::name);
// Helper for targets that may not carry affine map instantiations; the diagnostic
// states that these are only allowed when targeting FUNC_NAME_COMPUTE.
462 LogicalResult verifyNoAffineMapInstantiations() {
465 return callOp->emitOpError().append(
466 "can only have affine map instantiations when targeting a \"@",
FUNC_NAME_COMPUTE,
472 assert(callOp->getMapOperands().empty());
// Verifier for a call whose target FuncDefOp was found by symbol lookup. It keeps
// the target op, its function type and the include symbol names reported by the
// lookup result.
477struct KnownTargetVerifier :
public CallOpVerifier {
478 KnownTargetVerifier(CallOp *c, SymbolLookupResult<FuncDefOp> &&tgtRes)
479 : CallOpVerifier(c, tgtRes.get().getSymName()), tgt(*tgtRes), tgtType(tgt.getFunctionType()),
480 includeSymNames(tgtRes.getIncludeSymNames()) {}
// Delegates to the base-class helper, passing the resolved target.
482 LogicalResult verifyTargetAttributes()
override {
483 return CallOpVerifier::verifyTargetAttributesMatch(tgt);
// Argument operand types of the call are compared with the target's input types.
486 LogicalResult verifyInputs()
override {
487 return verifyTypesMatch(callOp->
getArgOperands().getTypes(), tgtType.getInputs(),
"operand");
// Result types of the call are compared with the target's result types.
490 LogicalResult verifyOutputs()
override {
491 return verifyTypesMatch(callOp->getResultTypes(), tgtType.getResults(),
"result");
494 LogicalResult verifyAffineMapParams()
override {
// Special handling applies to a Compute target that lives inside a struct.
495 if (CalleeKind::Compute == tgtKind &&
isInStruct(tgt.getOperation())) {
// Gather the AffineMapAttr entries found among the parameters of `retTy`.
501 if (ArrayAttr params = retTy.getParams()) {
503 SmallVector<AffineMapAttr> mapAttrs;
504 for (Attribute a : params) {
505 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(a)) {
506 mapAttrs.push_back(m);
517 return verifyNoAffineMapInstantiations();
// Compares the types used by the call with the types declared by the target:
// first the counts, then each position via typesUnify() with `includeSymNames`.
// `aspect` ("operand" or "result" at the call sites above) labels the diagnostic.
522 template <
typename T>
524 verifyTypesMatch(ValueTypeRange<T> callOpTypes, ArrayRef<Type> tgtTypes,
const char *aspect) {
525 if (tgtTypes.size() != callOpTypes.size()) {
526 return callOp->emitOpError()
527 .append(
"incorrect number of ", aspect,
"s for callee, expected ", tgtTypes.size())
528 .attachNote(tgt.getLoc())
529 .append(
"callee defined here");
// The first position whose types do not unify is reported and ends the check.
531 for (
unsigned i = 0, e = tgtTypes.size(); i != e; ++i) {
532 if (!
typesUnify(callOpTypes[i], tgtTypes[i], includeSymNames)) {
533 return callOp->emitOpError().append(
534 aspect,
" type mismatch: expected type ", tgtTypes[i],
", but found ", callOpTypes[i],
535 " for ", aspect,
" number ", i
// Function type of the target, captured in the constructor.
543 FunctionType tgtType;
// Symbol names taken from the lookup result; passed to typesUnify().
544 std::vector<llvm::StringRef> includeSymNames;
// Checks `actualType` against the type variable named `expectedParamName` for a
// call (`origin`); `aspect` is the wording used for the checked position in the
// diagnostic.
549LogicalResult checkSelfTypeUnknownTarget(
550 StringAttr expectedParamName, Type actualType,
CallOp *origin,
const char *aspect
// Condition: `actualType` is not a TypeVarType, or it refers to a different name.
552 if (!llvm::isa<TypeVarType>(actualType) ||
553 llvm::cast<TypeVarType>(actualType).getRefName() != expectedParamName) {
// Diagnostic naming the leaf of the callee reference and the expected type.
559 return origin->emitOpError().append(
560 "target \"@", origin->
getCallee().getLeafReference().getValue(),
"\" expected ", aspect,
561 " type '!",
TypeVarType::name,
"<@", expectedParamName.getValue(),
">' but found ",
// Verifier that works from the callee symbol reference alone (no target FuncDefOp
// is held). The kind of target comes from the leaf name of the reference, and the
// root of the reference is the name the self-type checks compare against.
577struct UnknownTargetVerifier :
public CallOpVerifier {
578 UnknownTargetVerifier(CallOp *c, SymbolRefAttr callee)
579 : CallOpVerifier(c, callee.getLeafReference().getValue()), calleeAttr(callee) {}
// Only checked when the call is nested inside a FuncDefOp (the caller).
581 LogicalResult verifyTargetAttributes()
override {
584 LogicalResult aggregateRes = success();
585 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
// Emits an op error naming the attribute the caller is missing and stores the
// outcome in `aggregateRes`.
586 auto emitAttrErr = [&](StringLiteral attrName) {
587 aggregateRes = callOp->emitOpError()
588 <<
"target '" << calleeAttr <<
"' has '" << attrName
589 <<
"' attribute, which is not specified by the caller '@" << caller.getName()
// A Constrain target requires the caller to have the allow-constraint attribute.
593 if (tgtKind == CalleeKind::Constrain && !caller.hasAllowConstraintAttr()) {
594 emitAttrErr(AllowConstraintAttr::name);
// A Compute target requires the caller to have the allow-witness attribute.
596 if (tgtKind == CalleeKind::Compute && !caller.hasAllowWitnessAttr()) {
597 emitAttrErr(AllowWitnessAttr::name);
// For a Constrain target the call needs at least one argument, and the first
// argument type is checked against the root reference of the callee.
603 LogicalResult verifyInputs()
override {
604 if (CalleeKind::Compute == tgtKind) {
606 }
else if (CalleeKind::Constrain == tgtKind) {
609 Operation::operand_type_range inputTypes = callOp->
getArgOperands().getTypes();
610 if (inputTypes.size() < 1) {
612 return callOp->emitOpError()
615 return checkSelfTypeUnknownTarget(
616 calleeAttr.getRootReference(), inputTypes.front(), callOp,
"first input"
// For a Compute target the call must have exactly one result, whose type is
// checked against the root reference; for a Constrain target a non-zero result
// count is an error.
622 LogicalResult verifyOutputs()
override {
623 if (CalleeKind::Compute == tgtKind) {
626 Operation::result_type_range resTypes = callOp->getResultTypes();
627 if (resTypes.size() != 1) {
629 return callOp->emitOpError().append(
633 return checkSelfTypeUnknownTarget(
634 calleeAttr.getRootReference(), resTypes.front(), callOp,
"return"
636 }
else if (CalleeKind::Constrain == tgtKind) {
638 if (callOp->getNumResults() != 0) {
640 return callOp->emitOpError()
// A Constrain target is checked with verifyNoAffineMapInstantiations().
647 LogicalResult verifyAffineMapParams()
override {
648 if (CalleeKind::Compute == tgtKind) {
650 }
else if (CalleeKind::Constrain == tgtKind) {
652 return verifyNoAffineMapInstantiations();
// Callee symbol reference of the call being verified.
658 SymbolRefAttr calleeAttr;
672 return emitOpError(
"requires a 'callee' symbol reference attribute");
677 if (calleeAttr.getNestedReferences().size() == 1) {
679 if (succeeded(parent) && parent->hasParamNamed(calleeAttr.getRootReference())) {
680 return UnknownTargetVerifier(
this, calleeAttr).verify();
687 if (failed(tgtOpt)) {
689 << calleeAttr <<
'"';
691 return KnownTargetVerifier(
this, std::move(*tgtOpt)).verify();
695 return FunctionType::get(getContext(),
getArgOperands().getTypes(), getResultTypes());
700bool calleeIsStructFunctionImpl(
701 const char *funcName, SymbolRefAttr callee, llvm::function_ref<
StructType()> getType
703 if (callee.getLeafReference() == funcName) {
731 return getResults().front();
740 Operation *thisOp = this->getOperation();
742 assert(succeeded(root));
760 llvm::SmallVector<ValueRange, 4> output;
761 output.reserve(input.size());
762 for (OperandRange r : input) {