LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
SparseAnalysis.cpp
Go to the documentation of this file.
1//===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8// Adapted from mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp.
9// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
10// See https://llvm.org/LICENSE.txt for license information.
11// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
12//
13//
14//===----------------------------------------------------------------------===//
15
20
21#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
22#include <mlir/Analysis/DataFlow/SparseAnalysis.h>
23#include <mlir/Analysis/DataFlowFramework.h>
24#include <mlir/Dialect/SCF/IR/SCF.h>
25#include <mlir/IR/Attributes.h>
26#include <mlir/IR/Operation.h>
27#include <mlir/IR/Region.h>
28#include <mlir/IR/SymbolTable.h>
29#include <mlir/IR/Value.h>
30#include <mlir/IR/ValueRange.h>
31#include <mlir/Interfaces/CallInterfaces.h>
32#include <mlir/Interfaces/ControlFlowInterfaces.h>
33#include <mlir/Support/LLVM.h>
34
35#include <llvm/ADT/STLExtras.h>
36#include <llvm/Support/Casting.h>
37
38#include <cassert>
39#include <optional>
40
41using namespace mlir;
42using namespace mlir::dataflow;
43using namespace llzk::function;
44
45namespace llzk::dataflow {
46
47//===----------------------------------------------------------------------===//
48// AbstractSparseForwardDataFlowAnalysis
49//===----------------------------------------------------------------------===//
50
52 : DataFlowAnalysis(s) {
53 registerAnchorKind<CFGEdge>();
54}
55
57 // Mark the entry block arguments as having reached their pessimistic
58 // fixpoints.
59 for (Region &region : top->getRegions()) {
60 if (region.empty()) {
61 continue;
62 }
63 for (Value argument : region.front().getArguments()) {
65 }
66 }
67
68 return initializeRecursively(top);
69}
70
71LogicalResult AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
72 // Initialize the analysis by visiting every owner of an SSA value (all
73 // operations and blocks).
74 if (failed(visitOperation(op))) {
75 return failure();
76 }
77
78 for (Region &region : op->getRegions()) {
79 for (Block &block : region) {
80 getOrCreate<Executable>(getProgramPointBefore(&block))->blockContentSubscribe(this);
81 visitBlock(&block);
82 // LLZK: Renamed "op" -> "containedOp" to avoid shadowing.
83 for (Operation &containedOp : block) {
84 if (failed(initializeRecursively(&containedOp))) {
85 return failure();
86 }
87 }
88 }
89 }
90
91 return success();
92}
93
94LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint *point) {
95 if (!point->isBlockStart()) {
96 return visitOperation(point->getPrevOp());
97 }
98 visitBlock(point->getBlock());
99 return success();
100}
101
/// Drives the transfer function for a single operation `op`.
///
/// - Returns success without doing anything when `op` has a parent block
///   whose `Executable` state is not live.
/// - Collects one result lattice per result of `op`.
/// - If `op` implements `RegionBranchOpInterface`, the result lattices are
///   computed by `visitRegionSuccessors`, using the point after `op` and the
///   parent branch point as successor. Operand lattices are then neither read
///   nor subscribed to.
/// - Otherwise, this analysis subscribes to every operand lattice
///   (`useDefSubscribe`) and returns whatever `visitOperationImpl` returns.
///
/// The upstream interprocedural handling of `CallOpInterface` ops is
/// commented out below, so calls currently take the generic
/// `visitOperationImpl` path as well.
///
/// NOTE(review): source lines 103-105 are absent from this listing. They may
/// hold an additional early exit; confirm against the real file.
102LogicalResult AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
106
107 // If the containing block is not executable, bail out.
108 if (op->getBlock() != nullptr &&
109 !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive()) {
110 return success();
111 }
112
113 // Get the result lattices.
114 SmallVector<AbstractSparseLattice *> resultLattices;
115 resultLattices.reserve(op->getNumResults());
116 for (Value result : op->getResults()) {
117 AbstractSparseLattice *resultLattice = getLatticeElement(result);
118 resultLattices.push_back(resultLattice);
119 }
120
121 // The results of a region branch operation are determined by control-flow.
122 if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
123 visitRegionSuccessors(
124 getProgramPointAfter(branch), branch,
125 /*successor=*/RegionBranchPoint::parent(), resultLattices
126 );
127 return success();
128 }
129
130 // Grab the lattice elements of the operands.
131 SmallVector<const AbstractSparseLattice *> operandLattices;
132 operandLattices.reserve(op->getNumOperands());
133 for (Value operand : op->getOperands()) {
134 AbstractSparseLattice *operandLattice = getLatticeElement(operand);
135 operandLattice->useDefSubscribe(this);
136 operandLattices.push_back(operandLattice);
137 }
138
139 // LLZK TODO: Enable for interprocedural analysis.
140 /*
141 if (auto call = dyn_cast<CallOpInterface>(op)) {
143 // If the call operation is to an external function, attempt to infer the
144 // results from the call arguments.
145 auto callable = resolveCallable<FuncDefOp>(tables, call);
146 if (!getSolverConfig().isInterprocedural() ||
147 (succeeded(callable) && !callable->get().getCallableRegion())) {
148 visitExternalCallImpl(call, operandLattices, resultLattices);
149 return success();
150 }
151
152 // Otherwise, the results of a call operation are determined by the
153 // callgraph.
156 SmallVector<Operation *> predecessors;
157 callable->get().walk([&predecessors](ReturnOp ret) mutable { predecessors.push_back(ret); });
158
159 // If not all return sites are known, then conservatively assume we can't
160 // reason about the data-flow.
161 if (predecessors.empty()) {
162 setAllToEntryStates(resultLattices);
163 return success();
164 }
165 for (Operation *predecessor : predecessors) {
166 for (auto &&[operand, resLattice] : llvm::zip(predecessor->getOperands(), resultLattices)) {
167 join(resLattice, *getLatticeElementFor(getProgramPointAfter(op), operand));
168 }
169 }
170 return success();
171 }
172 */
173
174 // Invoke the operation transfer function.
175 return visitOperationImpl(op, operandLattices, resultLattices);
176}
177
/// Computes the argument lattices of `block`.
///
/// Nothing happens for blocks without arguments, or for blocks whose
/// `Executable` state is not live.
///
/// For an entry block:
/// - If the parent op implements `RegionBranchOpInterface`, the lattices are
///   delegated to `visitRegionSuccessors`.
/// - Otherwise, four arguments are passed to a non-control-flow hook: the
///   parent op, `RegionSuccessor(block->getParent())`, the lattices, and a
///   first index of 0. The hook is presumably `visitNonControlFlowArgumentsImpl`;
///   its call line (source line 243) is absent from this listing.
/// - The upstream callgraph handling for callable entry blocks is commented
///   out.
///
/// For a non-entry block:
/// - This analysis subscribes to each incoming CFG edge's `Executable` state,
///   and dead edges are skipped.
/// - For a live edge whose terminator implements `BranchOpInterface`, each
///   forwarded successor operand is joined into the matching argument lattice.
/// - An argument produced internally by the terminator (no forwarded operand)
///   is set to its entry state instead.
/// - If a live predecessor's terminator is not a `BranchOpInterface`, every
///   argument lattice is set to its entry state and the function returns.
180void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
181 // Exit early on blocks with no arguments.
182 if (block->getNumArguments() == 0) {
183 return;
184 }
185
186 // If the block is not executable, bail out.
187 if (!getOrCreate<Executable>(getProgramPointBefore(block))->isLive()) {
188 return;
189 }
190
191 // Get the argument lattices.
192 SmallVector<AbstractSparseLattice *> argLattices;
193 argLattices.reserve(block->getNumArguments());
194 for (BlockArgument argument : block->getArguments()) {
195 AbstractSparseLattice *argLattice = getLatticeElement(argument);
196 argLattices.push_back(argLattice);
197 }
198
199 // The argument lattices of entry blocks are set by region control-flow or the
200 // callgraph.
201 if (block->isEntryBlock()) {
202 // Check if this block is the entry block of a callable region.
203 // LLZK TODO: Enable for interprocedural analysis.
204 /*
205 auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
206 if (callable && callable.getCallableRegion() == block->getParent()) {
208 auto moduleOpRes = getTopRootModule(callable.getOperation());
209 ensure(succeeded(moduleOpRes), "could not get root module from callable");
210 SmallVector<Operation *> callsites;
211 moduleOpRes->walk([this, &callable, &callsites](CallOp call) mutable {
212 auto calledFnRes = resolveCallable<FuncDefOp>(tables, call);
213 if (succeeded(calledFnRes) &&
214 calledFnRes->get().getCallableRegion() == callable.getCallableRegion()) {
215 callsites.push_back(call);
216 }
217 });
218 // If not all callsites are known, conservatively mark all lattices as
219 // having reached their pessimistic fixpoints.
220 if (callsites.empty() || !getSolverConfig().isInterprocedural()) {
221 return setAllToEntryStates(argLattices);
222 }
223 for (Operation *callsite : callsites) {
224 auto call = cast<CallOpInterface>(callsite);
225 for (auto it : llvm::zip(call.getArgOperands(), argLattices)) {
226 join(
227 std::get<1>(it), *getLatticeElementFor(getProgramPointBefore(block), std::get<0>(it))
228 );
229 }
230 }
231 return;
232 }
233 */
234
235 // Check if the lattices can be determined from region control flow.
236 if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
237 return visitRegionSuccessors(
238 getProgramPointBefore(block), branch, block->getParent(), argLattices
239 );
240 }
241
242 // Otherwise, we can't reason about the data-flow.
244 block->getParentOp(), RegionSuccessor(block->getParent()), argLattices, /*firstIndex=*/0
245 );
246 }
247
248 // Iterate over the predecessors of the non-entry block.
249 for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
250 Block *predecessor = *it;
251
252 // If the edge from the predecessor block to the current block is not live,
253 // bail out.
254 auto *edgeExecutable = getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(predecessor, block));
255 edgeExecutable->blockContentSubscribe(this);
256 if (!edgeExecutable->isLive()) {
257 continue;
258 }
259
260 // Check if we can reason about the data-flow from the predecessor.
261 if (auto branch = dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
262 SuccessorOperands operands = branch.getSuccessorOperands(it.getSuccessorIndex());
263 for (auto [idx, lattice] : llvm::enumerate(argLattices)) {
264 if (Value operand = operands[idx]) {
265 join(lattice, *getLatticeElementFor(getProgramPointBefore(block), operand));
266 } else {
267 // Conservatively consider internally produced arguments as entry
268 // points.
269 setAllToEntryStates(lattice);
270 }
271 }
272 } else {
273 return setAllToEntryStates(argLattices);
274 }
275 }
276}
277
/// Joins into `lattices` the values forwarded to a region successor of
/// `branch`, as observed at `point`.
///
/// The forwarding op is derived from `point`: the parent op of the block for a
/// block-start point, otherwise the op right before `point`. If no op can be
/// derived, nothing is done. The operands are chosen as follows:
/// - If the op is `branch`, its entry successor operands for `successor` are
///   used.
/// - If it is a `RegionBranchTerminatorOpInterface`, its successor operands
///   are used.
/// - Any other op sends every lattice to its entry state.
///
/// Successor inputs are only recovered for `scf.for` / `scf.while` (their
/// region iter args). For other ops the inputs are empty, so any non-empty
/// operand list sends every lattice to its entry state.
///
/// When the number of inputs differs from `lattices.size()`, `firstIndex` is
/// taken from the first input. `lattices` and `firstIndex` are then passed to
/// a non-control-flow hook together with a `RegionSuccessor` over the
/// corresponding result/argument slice. The hook is presumably
/// `visitNonControlFlowArgumentsImpl`; the call lines (source lines 322 and
/// 331) are absent from this listing. The join starts at `firstIndex`.
///
/// NOTE(review): the two call sites in this file never reach the terminator
/// path:
/// - `visitOperation` passes the point after `branch`.
/// - `visitBlock` passes the start of a block whose parent op is `branch`.
/// In both cases the derived op is always `branch`, so region-terminator
/// (e.g. yield) operands are never joined. Upstream MLIR instead iterates
/// every known predecessor via `PredecessorState`; confirm the difference is
/// intended.
///
/// NOTE(review): `scf::WhileOp::getRegionIterArgs()` returns block arguments,
/// so `cast<OpResult>(inputs.front())` in the non-block-start path looks like
/// it would fail if a while op reached it with mismatched sizes. Also confirm
/// that `WhileOp::getEntrySuccessorOperands` accepts the `successor` passed
/// here; upstream may assert that it is the `before` region.
279void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
280 ProgramPoint *point, RegionBranchOpInterface branch, RegionBranchPoint successor,
281 ArrayRef<AbstractSparseLattice *> lattices
282) {
283 Operation *op = point->isBlockStart() ? point->getBlock()->getParentOp() : point->getPrevOp();
284
285 if (op) {
286 // Get the incoming successor operands.
287 std::optional<OperandRange> operands;
288
289 // Check if the predecessor is the parent op.
290 if (op == branch) {
291 operands = branch.getEntrySuccessorOperands(successor);
292 // Otherwise, try to deduce the operands from a region return-like op.
293 } else if (auto regionTerminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
294 operands = regionTerminator.getSuccessorOperands(successor);
295 }
296
297 if (!operands) {
298 // We can't reason about the data-flow.
299 return setAllToEntryStates(lattices);
300 }
301
302 ValueRange inputs;
303
305 if (auto forOp = dyn_cast<scf::ForOp>(op)) {
306 inputs = forOp.getRegionIterArgs();
307 } else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
308 inputs = whileOp.getRegionIterArgs();
309 }
310
311 if (inputs.size() != operands->size()) {
312 // We can't reason about the data-flow.
313 return setAllToEntryStates(lattices);
314 }
315
316 unsigned firstIndex = 0;
317 if (inputs.size() != lattices.size()) {
318 if (!point->isBlockStart()) {
319 if (!inputs.empty()) {
320 firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
321 }
323 branch, RegionSuccessor(branch->getResults().slice(firstIndex, inputs.size())),
324 lattices, firstIndex
325 );
326 } else {
327 if (!inputs.empty()) {
328 firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
329 }
330 Region *region = point->getBlock()->getParent();
332 branch,
333 RegionSuccessor(region, region->getArguments().slice(firstIndex, inputs.size())),
334 lattices, firstIndex
335 );
336 }
337 }
338
339 for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex))) {
340 join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
341 }
342 }
343}
344
348 addDependency(state, point);
349 return state;
350}
351
353 ArrayRef<AbstractSparseLattice *> lattices
354) {
355 for (AbstractSparseLattice *lattice : lattices) {
356 setToEntryState(lattice);
357 }
358}
359
362) {
363 propagateIfChanged(lhs, lhs->join(rhs));
364}
365
366} // namespace llzk::dataflow
This file implements sparse data-flow analysis using the data-flow analysis framework.
mlir::LogicalResult visit(mlir::ProgramPoint *point) override
Visit a program point.
mlir::LogicalResult initialize(mlir::Operation *top) override
Initialize the analysis by visiting every owner of an SSA value: all operations and blocks.
virtual mlir::LogicalResult visitOperationImpl(mlir::Operation *op, mlir::ArrayRef< const AbstractSparseLattice * > operandLattices, mlir::ArrayRef< AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
virtual AbstractSparseLattice * getLatticeElement(mlir::Value value)=0
Get the lattice element of a value.
AbstractSparseForwardDataFlowAnalysis(mlir::DataFlowSolver &solver)
void setAllToEntryStates(mlir::ArrayRef< AbstractSparseLattice * > lattices)
const AbstractSparseLattice * getLatticeElementFor(mlir::ProgramPoint *point, mlir::Value value)
Get a read-only lattice element for a value and add it as a dependency to a program point.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join `rhs` into the lattice element `lhs`, and propagate an update if `lhs` changed.
virtual void setToEntryState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow entry point(s).
virtual void visitNonControlFlowArgumentsImpl(mlir::Operation *op, const mlir::RegionSuccessor &successor, mlir::ArrayRef< AbstractSparseLattice * > argLattices, unsigned firstIndex)=0
Given an operation with region control-flow, the lattices of the operands, and a region successor,...
mlir::dataflow::AbstractSparseLattice AbstractSparseLattice