LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
AffineHelper.cpp
Go to the documentation of this file.
1//===-- AffineHelper.cpp ---------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
12#include <numeric>
13
14using namespace mlir;
15
17
18namespace {
19
20ParseResult parseDimAndSymbolListImpl(
21 OpAsmParser &parser, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapOperands,
22 int32_t &numDims
23) {
24 // Parse the required dimension operands.
25 if (parser.parseOperandList(mapOperands, OpAsmParser::Delimiter::Paren)) {
26 return failure();
27 }
28 // Store number of dimensions for validation by caller.
29 numDims = mapOperands.size();
30
31 // Parse the optional symbol operands.
32 return parser.parseOperandList(mapOperands, OpAsmParser::Delimiter::OptionalSquare);
33}
34
35void printDimAndSymbolListImpl(OpAsmPrinter &printer, OperandRange mapOperands, size_t numDims) {
36 printer << '(' << mapOperands.take_front(numDims) << ')';
37 if (mapOperands.size() > numDims) {
38 printer << '[' << mapOperands.drop_front(numDims) << ']';
39 }
40}
41} // namespace
42
44 OpAsmParser &parser, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapOperands,
45 IntegerAttr &numDims
46) {
47 int32_t numDimsRes = -1;
48 ParseResult res = parseDimAndSymbolListImpl(parser, mapOperands, numDimsRes);
49 numDims = parser.getBuilder().getIndexAttr(numDimsRes);
50 return res;
51}
52
54 OpAsmPrinter &printer, Operation *, OperandRange mapOperands, IntegerAttr numDims
55) {
56 printDimAndSymbolListImpl(printer, mapOperands, numDims.getInt());
57}
58
60 OpAsmParser &parser,
61 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &multiMapOperands,
62 DenseI32ArrayAttr &numDimsPerMap
63) {
64 SmallVector<int32_t> numDimsPerMapRes;
65 auto parseEach = [&]() -> ParseResult {
66 SmallVector<OpAsmParser::UnresolvedOperand> nextMapOps;
67 int32_t nextMapDims = -1;
68 ParseResult res = parseDimAndSymbolListImpl(parser, nextMapOps, nextMapDims);
69 numDimsPerMapRes.push_back(nextMapDims);
70 multiMapOperands.push_back(nextMapOps);
71 return res;
72 };
73 ParseResult res = parser.parseCommaSeparatedList(AsmParser::Delimiter::None, parseEach);
74
75 numDimsPerMap = parser.getBuilder().getDenseI32ArrayAttr(numDimsPerMapRes);
76 return res;
77}
78
80 OpAsmPrinter &printer, Operation *, OperandRangeRange multiMapOperands,
81 DenseI32ArrayAttr numDimsPerMap
82) {
83 size_t count = numDimsPerMap.size();
84 assert(multiMapOperands.size() == count);
85 llvm::interleaveComma(llvm::seq<size_t>(0, count), printer.getStream(), [&](size_t i) {
86 printDimAndSymbolListImpl(printer, multiMapOperands[i], numDimsPerMap[i]);
87 });
88}
89
90ParseResult
91parseAttrDictWithWarnings(OpAsmParser &parser, NamedAttrList &extraAttrs, OperationState &state) {
92 // Replicate what ODS generates w/o the custom<AttrDictWithWarnings> directive
93 llvm::SMLoc loc = parser.getCurrentLocation();
94 if (parser.parseOptionalAttrDict(extraAttrs)) {
95 return failure();
96 }
97 if (failed(state.name.verifyInherentAttrs(extraAttrs, [&]() {
98 return parser.emitError(loc) << '\'' << state.name.getStringRef() << "' op ";
99 }))) {
100 return failure();
101 }
102 // Ignore, with warnings, any attributes that are specified and shouldn't be
103 for (StringAttr skipName : state.name.getAttributeNames()) {
104 if (extraAttrs.erase(skipName)) {
105 auto msg =
106 "Ignoring attribute '" + Twine(skipName) + "' because it must be computed automatically.";
107 mlir::emitWarning(parser.getEncodedSourceLoc(loc), msg).report();
108 }
109 }
110 // There is no failure from this last check, only warnings
111 return success();
112}
113
114namespace {
115inline InFlightDiagnostic msgInstantiationGroupAttrMismatch(
116 Operation *op, size_t mapOpGroupSizesCount, size_t mapOperandsSize
117) {
118 return op->emitOpError().append(
119 "map instantiation group count (", mapOperandsSize,
120 ") does not match with length of 'mapOpGroupSizes' attribute (", mapOpGroupSizesCount, ")"
121 );
122}
123} // namespace
124
126 Operation *op, int32_t segmentSize, ArrayRef<int32_t> mapOpGroupSizes,
127 OperandRangeRange mapOperands, ArrayRef<int32_t> numDimsPerMap
128) {
129 // Ensure the `mapOpGroupSizes` and `operandSegmentSizes` attributes agree.
130 // NOTE: the ODS generates verifyValueSizeAttr() which ensures 'mapOpGroupSizes' has no negative
131 // elements and its sum is equal to the operand group size (which is similar to this check).
132 // If segmentSize < 0 the check is validated regardless of the difference.
133 int32_t totalMapOpGroupSizes = std::reduce(mapOpGroupSizes.begin(), mapOpGroupSizes.end());
134 if (totalMapOpGroupSizes != segmentSize && segmentSize >= 0) {
135 // Since `mapOpGroupSizes` and `segmentSize` are computed this should never happen.
136 return op->emitOpError().append(
137 "number of operands for affine map instantiation (", totalMapOpGroupSizes,
138 ") does not match with the total size (", segmentSize,
139 ") specified in attribute 'operandSegmentSizes'"
140 );
141 }
142
143 // Ensure the size of `mapOperands` and its two list attributes are the same.
144 // This will be true if the op was constructed via parseMultiDimAndSymbolList()
145 // but when constructed via the build() API, it can be inconsistent.
146 size_t count = mapOpGroupSizes.size();
147 if (mapOperands.size() != count) {
148 return msgInstantiationGroupAttrMismatch(op, count, mapOperands.size());
149 }
150 if (numDimsPerMap.size() != count) {
151 // Tested in CallOpTests.cpp
152 return op->emitOpError().append(
153 "length of 'numDimsPerMap' attribute (", numDimsPerMap.size(),
154 ") does not match with length of 'mapOpGroupSizes' attribute (", count, ")"
155 );
156 }
157
158 // Verify the following:
159 // 1. 'mapOperands' element sizes match 'mapOpGroupSizes' values
160 // 2. each 'numDimsPerMap' is <= corresponding 'mapOpGroupSizes'
161 LogicalResult aggregateResult = success();
162 for (size_t i = 0; i < count; ++i) {
163 auto currMapOpGroupSize = mapOpGroupSizes[i];
164 if (std::cmp_not_equal(mapOperands[i].size(), currMapOpGroupSize)) {
165 // Since `mapOpGroupSizes` is computed this should never happen.
166 aggregateResult = op->emitOpError().append(
167 "map instantiation group ", i, " operand count (", mapOperands[i].size(),
168 ") does not match group ", i, " size in 'mapOpGroupSizes' attribute (",
169 currMapOpGroupSize, ")"
170 );
171 } else if (std::cmp_greater(numDimsPerMap[i], currMapOpGroupSize)) {
172 // Tested in CallOpTests.cpp
173 aggregateResult = op->emitOpError().append(
174 "map instantiation group ", i, " dimension count (", numDimsPerMap[i], ") exceeds group ",
175 i, " size in 'mapOpGroupSizes' attribute (", currMapOpGroupSize, ")"
176 );
177 }
178 }
179 return aggregateResult;
180}
181
183 OperandRangeRange mapOps, ArrayRef<int32_t> numDimsPerMap, ArrayRef<AffineMapAttr> mapAttrs,
184 Operation *origin
185) {
186 size_t count = numDimsPerMap.size();
187 if (mapOps.size() != count) {
188 return msgInstantiationGroupAttrMismatch(origin, count, mapOps.size());
189 }
190
191 // Ensure there is one OperandRange for each AffineMapAttr
192 if (mapAttrs.size() != count) {
193 // Tested in array_build_fail.llzk, call_with_affinemap_fail.llzk, CallOpTests.cpp, and
194 // CreateArrayOpTests.cpp
195 return origin->emitOpError().append(
196 "map instantiation group count (", count,
197 ") does not match the number of affine map instantiations (", mapAttrs.size(),
198 ") required by the type"
199 );
200 }
201
202 // Ensure the affine map identifier counts match the instantiation.
203 // Rather than immediately returning on failure, we check all dimensions and aggregate to provide
204 // as many errors are possible in a single verifier run.
205 LogicalResult aggregateResult = success();
206 for (size_t i = 0; i < count; ++i) {
207 AffineMap map = mapAttrs[i].getAffineMap();
208 if (std::cmp_not_equal(map.getNumDims(), numDimsPerMap[i])) {
209 // Tested in array_build_fail.llzk and call_with_affinemap_fail.llzk
210 aggregateResult = origin->emitOpError().append(
211 "instantiation of map ", i, " expected ", map.getNumDims(), " but found ",
212 numDimsPerMap[i], " dimension values in ()"
213 );
214 } else if (std::cmp_not_equal(map.getNumInputs(), mapOps[i].size())) {
215 // Tested in array_build_fail.llzk and call_with_affinemap_fail.llzk
216 aggregateResult = origin->emitOpError().append(
217 "instantiation of map ", i, " expected ", map.getNumSymbols(), " but found ",
218 (mapOps[i].size() - numDimsPerMap[i]), " symbol values in []"
219 );
220 }
221 }
222 return aggregateResult;
223}
224
225} // namespace llzk::affineMapHelpers
Groups together all of the implementation related to AffineMap type parameters.
LogicalResult verifySizesForMultiAffineOps(Operation *op, int32_t segmentSize, ArrayRef< int32_t > mapOpGroupSizes, OperandRangeRange mapOperands, ArrayRef< int32_t > numDimsPerMap)
ParseResult parseMultiDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &multiMapOperands, DenseI32ArrayAttr &numDimsPerMap)
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapOperands, IntegerAttr &numDims)
ParseResult parseAttrDictWithWarnings(OpAsmParser &parser, NamedAttrList &extraAttrs, OperationState &state)
void printMultiDimAndSymbolList(OpAsmPrinter &printer, Operation *, OperandRangeRange multiMapOperands, DenseI32ArrayAttr numDimsPerMap)
void printDimAndSymbolList(OpAsmPrinter &printer, Operation *, OperandRange mapOperands, IntegerAttr numDims)