LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
AffineHelper.cpp
Go to the documentation of this file.
1//===-- AffineHelper.cpp ---------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
12#include <numeric>
13
14using namespace mlir;
15
17
18namespace {
19
20ParseResult parseDimAndSymbolListImpl(
21 OpAsmParser &parser, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapOperands,
22 int32_t &numDims
23) {
24 // Parse the required dimension operands.
25 if (parser.parseOperandList(mapOperands, OpAsmParser::Delimiter::Paren)) {
26 return failure();
27 }
28 // Store number of dimensions for validation by caller.
29 numDims = mapOperands.size();
30
31 // Parse the optional symbol operands.
32 return parser.parseOperandList(mapOperands, OpAsmParser::Delimiter::OptionalSquare);
33}
34
35void printDimAndSymbolListImpl(
36 OpAsmPrinter &printer, Operation *op, OperandRange mapOperands, size_t numDims
37) {
38 printer << '(' << mapOperands.take_front(numDims) << ')';
39 if (mapOperands.size() > numDims) {
40 printer << '[' << mapOperands.drop_front(numDims) << ']';
41 }
42}
43} // namespace
44
46 OpAsmParser &parser, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapOperands,
47 IntegerAttr &numDims
48) {
49 int32_t numDimsRes = -1;
50 ParseResult res = parseDimAndSymbolListImpl(parser, mapOperands, numDimsRes);
51 numDims = parser.getBuilder().getIndexAttr(numDimsRes);
52 return res;
53}
54
56 OpAsmPrinter &printer, Operation *op, OperandRange mapOperands, IntegerAttr numDims
57) {
58 printDimAndSymbolListImpl(printer, op, mapOperands, numDims.getInt());
59}
60
62 OpAsmParser &parser,
63 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &multiMapOperands,
64 DenseI32ArrayAttr &numDimsPerMap
65) {
66 SmallVector<int32_t> numDimsPerMapRes;
67 auto parseEach = [&]() -> ParseResult {
68 SmallVector<OpAsmParser::UnresolvedOperand> nextMapOps;
69 int32_t nextMapDims = -1;
70 ParseResult res = parseDimAndSymbolListImpl(parser, nextMapOps, nextMapDims);
71 numDimsPerMapRes.push_back(nextMapDims);
72 multiMapOperands.push_back(nextMapOps);
73 return res;
74 };
75 ParseResult res = parser.parseCommaSeparatedList(AsmParser::Delimiter::None, parseEach);
76
77 numDimsPerMap = parser.getBuilder().getDenseI32ArrayAttr(numDimsPerMapRes);
78 return res;
79}
80
82 OpAsmPrinter &printer, Operation *op, OperandRangeRange multiMapOperands,
83 DenseI32ArrayAttr numDimsPerMap
84) {
85 size_t count = numDimsPerMap.size();
86 assert(multiMapOperands.size() == count);
87 llvm::interleaveComma(llvm::seq<size_t>(0, count), printer.getStream(), [&](size_t i) {
88 printDimAndSymbolListImpl(printer, op, multiMapOperands[i], numDimsPerMap[i]);
89 });
90}
91
92ParseResult
93parseAttrDictWithWarnings(OpAsmParser &parser, NamedAttrList &extraAttrs, OperationState &state) {
94 // Replicate what ODS generates w/o the custom<AttrDictWithWarnings> directive
95 llvm::SMLoc loc = parser.getCurrentLocation();
96 if (parser.parseOptionalAttrDict(extraAttrs)) {
97 return failure();
98 }
99 if (failed(state.name.verifyInherentAttrs(extraAttrs, [&]() {
100 return parser.emitError(loc) << "'" << state.name.getStringRef() << "' op ";
101 }))) {
102 return failure();
103 }
104 // Ignore, with warnings, any attributes that are specified and shouldn't be
105 for (StringAttr skipName : state.name.getAttributeNames()) {
106 if (extraAttrs.erase(skipName)) {
107 auto msg =
108 "Ignoring attribute '" + Twine(skipName) + "' because it must be computed automatically.";
109 mlir::emitWarning(parser.getEncodedSourceLoc(loc), msg).report();
110 }
111 }
112 // There is no failure from this last check, only warnings
113 return success();
114}
115
116namespace {
117inline InFlightDiagnostic msgInstantiationGroupAttrMismatch(
118 Operation *op, size_t mapOpGroupSizesCount, size_t mapOperandsSize
119) {
120 return op->emitOpError().append(
121 "map instantiation group count (", mapOperandsSize,
122 ") does not match with length of 'mapOpGroupSizes' attribute (", mapOpGroupSizesCount, ")"
123 );
124}
125} // namespace
126
128 Operation *op, int32_t segmentSize, ArrayRef<int32_t> mapOpGroupSizes,
129 OperandRangeRange mapOperands, ArrayRef<int32_t> numDimsPerMap
130) {
131 // Ensure the `mapOpGroupSizes` and `operandSegmentSizes` attributes agree.
132 // NOTE: the ODS generates verifyValueSizeAttr() which ensures 'mapOpGroupSizes' has no negative
133 // elements and its sum is equal to the operand group size (which is similar to this check).
134 // If segmentSize < 0 the check is validated regardless of the difference.
135 int32_t totalMapOpGroupSizes = std::reduce(mapOpGroupSizes.begin(), mapOpGroupSizes.end());
136 if (totalMapOpGroupSizes != segmentSize && segmentSize >= 0) {
137 // Since `mapOpGroupSizes` and `segmentSize` are computed this should never happen.
138 return op->emitOpError().append(
139 "number of operands for affine map instantiation (", totalMapOpGroupSizes,
140 ") does not match with the total size (", segmentSize,
141 ") specified in attribute 'operandSegmentSizes'"
142 );
143 }
144
145 // Ensure the size of `mapOperands` and its two list attributes are the same.
146 // This will be true if the op was constructed via parseMultiDimAndSymbolList()
147 // but when constructed via the build() API, it can be inconsistent.
148 size_t count = mapOpGroupSizes.size();
149 if (mapOperands.size() != count) {
150 return msgInstantiationGroupAttrMismatch(op, count, mapOperands.size());
151 }
152 if (numDimsPerMap.size() != count) {
153 // Tested in CallOpTests.cpp
154 return op->emitOpError().append(
155 "length of 'numDimsPerMap' attribute (", numDimsPerMap.size(),
156 ") does not match with length of 'mapOpGroupSizes' attribute (", count, ")"
157 );
158 }
159
160 // Verify the following:
161 // 1. 'mapOperands' element sizes match 'mapOpGroupSizes' values
162 // 2. each 'numDimsPerMap' is <= corresponding 'mapOpGroupSizes'
163 LogicalResult aggregateResult = success();
164 for (size_t i = 0; i < count; ++i) {
165 auto currMapOpGroupSize = mapOpGroupSizes[i];
166 if (std::cmp_not_equal(mapOperands[i].size(), currMapOpGroupSize)) {
167 // Since `mapOpGroupSizes` is computed this should never happen.
168 aggregateResult = op->emitOpError().append(
169 "map instantiation group ", i, " operand count (", mapOperands[i].size(),
170 ") does not match group ", i, " size in 'mapOpGroupSizes' attribute (",
171 currMapOpGroupSize, ")"
172 );
173 } else if (std::cmp_greater(numDimsPerMap[i], currMapOpGroupSize)) {
174 // Tested in CallOpTests.cpp
175 aggregateResult = op->emitOpError().append(
176 "map instantiation group ", i, " dimension count (", numDimsPerMap[i], ") exceeds group ",
177 i, " size in 'mapOpGroupSizes' attribute (", currMapOpGroupSize, ")"
178 );
179 }
180 }
181 return aggregateResult;
182}
183
185 OperandRangeRange mapOps, ArrayRef<int32_t> numDimsPerMap, ArrayRef<AffineMapAttr> mapAttrs,
186 Operation *origin
187) {
188 size_t count = numDimsPerMap.size();
189 if (mapOps.size() != count) {
190 return msgInstantiationGroupAttrMismatch(origin, count, mapOps.size());
191 }
192
193 // Ensure there is one OperandRange for each AffineMapAttr
194 if (mapAttrs.size() != count) {
195 // Tested in array_build_fail.llzk, call_with_affinemap_fail.llzk, CallOpTests.cpp, and
196 // CreateArrayOpTests.cpp
197 return origin->emitOpError().append(
198 "map instantiation group count (", count,
199 ") does not match the number of affine map instantiations (", mapAttrs.size(),
200 ") required by the type"
201 );
202 }
203
204 // Ensure the affine map identifier counts match the instantiation.
205 // Rather than immediately returning on failure, we check all dimensions and aggregate to provide
206 // as many errors are possible in a single verifier run.
207 LogicalResult aggregateResult = success();
208 for (size_t i = 0; i < count; ++i) {
209 AffineMap map = mapAttrs[i].getAffineMap();
210 if (std::cmp_not_equal(map.getNumDims(), numDimsPerMap[i])) {
211 // Tested in array_build_fail.llzk and call_with_affinemap_fail.llzk
212 aggregateResult = origin->emitOpError().append(
213 "instantiation of map ", i, " expected ", map.getNumDims(), " but found ",
214 numDimsPerMap[i], " dimension values in ()"
215 );
216 } else if (std::cmp_not_equal(map.getNumInputs(), mapOps[i].size())) {
217 // Tested in array_build_fail.llzk and call_with_affinemap_fail.llzk
218 aggregateResult = origin->emitOpError().append(
219 "instantiation of map ", i, " expected ", map.getNumSymbols(), " but found ",
220 (mapOps[i].size() - numDimsPerMap[i]), " symbol values in []"
221 );
222 }
223 }
224 return aggregateResult;
225}
226
227} // namespace llzk::affineMapHelpers
Group together all implementation related to AffineMap type parameters.
LogicalResult verifySizesForMultiAffineOps(Operation *op, int32_t segmentSize, ArrayRef< int32_t > mapOpGroupSizes, OperandRangeRange mapOperands, ArrayRef< int32_t > numDimsPerMap)
ParseResult parseMultiDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &multiMapOperands, DenseI32ArrayAttr &numDimsPerMap)
void printDimAndSymbolList(OpAsmPrinter &printer, Operation *op, OperandRange mapOperands, IntegerAttr numDims)
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapOperands, IntegerAttr &numDims)
ParseResult parseAttrDictWithWarnings(OpAsmParser &parser, NamedAttrList &extraAttrs, OperationState &state)
void printMultiDimAndSymbolList(OpAsmPrinter &printer, Operation *op, OperandRangeRange multiMapOperands, DenseI32ArrayAttr numDimsPerMap)