LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
ConstraintDependencyGraph.cpp
Go to the documentation of this file.
1//===-- ConstraintDependencyGraph.cpp ---------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
16#include "llzk/Util/Hash.h"
18
19#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
20#include <mlir/IR/Value.h>
21
22#include <llvm/Support/Debug.h>
23
24#include <numeric>
25#include <unordered_set>
26
27#define DEBUG_TYPE "llzk-cdg"
28
29namespace llzk {
30
31using namespace array;
32using namespace component;
33using namespace constrain;
34using namespace function;
35
36/* ConstrainRefAnalysis */
37
39 mlir::CallOpInterface call, dataflow::CallControlFlowAction action,
40 const ConstrainRefLattice &before, ConstrainRefLattice *after
41) {
42 LLVM_DEBUG(
43 llvm::dbgs() << "ConstrainRefAnalysis::visitCallControlFlowTransfer: " << call << '\n'
44 );
45 auto fnOpRes = resolveCallable<FuncDefOp>(tables, call);
46 ensure(succeeded(fnOpRes), "could not resolve called function");
47
48 LLVM_DEBUG({
49 llvm::dbgs().indent(4) << "parent op is ";
50 if (auto s = call->getParentOfType<StructDefOp>()) {
51 llvm::dbgs() << s.getName();
52 } else if (auto p = call->getParentOfType<FuncDefOp>()) {
53 llvm::dbgs() << p.getName();
54 } else {
55 llvm::dbgs() << "<UNKNOWN PARENT TYPE>";
56 }
57 llvm::dbgs() << '\n';
58 });
59
63 if (action == dataflow::CallControlFlowAction::EnterCallee) {
64 // We skip updating the incoming lattice for function calls,
65 // as ConstrainRefs are relative to the containing function/struct, so we don't need to pollute
66 // the callee with the callers values.
67 // This also avoids a non-convergence scenario, as calling a
68 // function from other contexts can cause the lattice values to oscillate and constantly
69 // change (thus looping infinitely).
70
71 setToEntryState(after);
72 }
76 else if (action == dataflow::CallControlFlowAction::ExitCallee) {
77 // Get the argument values of the lattice by getting the state as it would
78 // have been for the callsite.
79 dataflow::AbstractDenseLattice *beforeCall = nullptr;
80 if (auto *prev = call->getPrevNode()) {
81 beforeCall = getLattice(prev);
82 } else {
83 beforeCall = getLattice(call->getBlock());
84 }
85 ensure(beforeCall, "could not get prior lattice");
86
87 // Translate argument values based on the operands given at the call site.
88 std::unordered_map<ConstrainRef, ConstrainRefLatticeValue, ConstrainRef::Hash> translation;
89 auto funcOpRes = resolveCallable<FuncDefOp>(tables, call);
90 ensure(mlir::succeeded(funcOpRes), "could not lookup called function");
91 auto funcOp = funcOpRes->get();
92
93 auto callOp = mlir::dyn_cast<CallOp>(call.getOperation());
94 ensure(callOp, "call is not a llzk::CallOp");
95
96 for (unsigned i = 0; i < funcOp.getNumArguments(); i++) {
97 auto key = ConstrainRef(funcOp.getArgument(i));
98 auto val = before.getOrDefault(callOp.getOperand(i));
99 translation[key] = val;
100 }
101
102 // The lattice at the return is the lattice before the call + translated
103 // return values.
104 mlir::ChangeResult updated = after->join(*beforeCall);
105 for (unsigned i = 0; i < callOp.getNumResults(); i++) {
106 auto retVal = before.getReturnValue(i);
107 auto [translatedVal, _] = retVal.translate(translation);
108 updated |= after->setValue(callOp->getResult(i), translatedVal);
109 }
110 propagateIfChanged(after, updated);
111 }
116 else if (action == mlir::dataflow::CallControlFlowAction::ExternalCallee) {
117 // For external calls, we propagate what information we already have from
118 // before the call to after the call, since the external call won't invalidate
119 // any of that information. It also, conservatively, makes no assumptions about
120 // external calls and their computation, so CDG edges will not be computed over
121 // input arguments to external functions.
122 join(after, before);
123 }
124}
125
127 mlir::Operation *op, const ConstrainRefLattice &before, ConstrainRefLattice *after
128) {
129 LLVM_DEBUG(llvm::dbgs() << "ConstrainRefAnalysis::visitOperation: " << *op << '\n');
130 // Collect the references that are made by the operands to `op`.
132 for (auto &operand : op->getOpOperands()) {
133 operandVals[operand.get()] = before.getOrDefault(operand.get());
134 }
135
136 // Propagate existing state.
137 join(after, before);
138
139 // We will now join the the operand refs based on the type of operand.
140 if (auto fieldRead = mlir::dyn_cast<FieldReadOp>(op)) {
141 // In the readf case, the operand is indexed into by the read's fielddefop.
142 assert(operandVals.size() == 1);
143 assert(fieldRead->getNumResults() == 1);
144
145 auto fieldOpRes = fieldRead.getFieldDefOp(tables);
146 ensure(mlir::succeeded(fieldOpRes), "could not find field read");
147
148 auto res = fieldRead->getResult(0);
149 const auto &ops = operandVals.at(fieldRead->getOpOperand(0).get());
150 auto [fieldVals, _] = ops.referenceField(fieldOpRes.value());
151
152 propagateIfChanged(after, after->setValue(res, fieldVals));
153 } else if (mlir::isa<ReadArrayOp>(op)) {
154 arraySubdivisionOpUpdate(op, operandVals, before, after);
155 } else if (auto createArray = mlir::dyn_cast<CreateArrayOp>(op)) {
156 // Create an array using the operand values, if they exist.
157 // Currently, the new array must either be fully initialized or uninitialized.
158
159 auto newArrayVal = ConstrainRefLatticeValue(createArray.getType().getShape());
160 // If the array is initialized, iterate through all operands and initialize the array value.
161 for (unsigned i = 0; i < createArray.getNumOperands(); i++) {
162 auto currentOp = createArray.getOperand(i);
163 auto &opVals = operandVals[currentOp];
164 (void)newArrayVal.getElemFlatIdx(i).setValue(opVals);
165 }
166
167 assert(createArray->getNumResults() == 1);
168 auto res = createArray->getResult(0);
169
170 propagateIfChanged(after, after->setValue(res, newArrayVal));
171 } else if (auto extractArray = mlir::dyn_cast<ExtractArrayOp>(op)) {
172 arraySubdivisionOpUpdate(op, operandVals, before, after);
173 } else {
174 // Standard union of operands into the results value.
175 // TODO: Could perform constant computation/propagation here for, e.g., arithmetic
176 // over constants, but such analysis may be better suited for a dedicated pass.
177 propagateIfChanged(after, fallbackOpUpdate(op, operandVals, before, after));
178 }
179}
180
181// Perform a standard union of operands into the results value.
183 mlir::Operation *op, const ConstrainRefLattice::ValueMap &operandVals,
184 const ConstrainRefLattice &before, ConstrainRefLattice *after
185) {
186 auto updated = mlir::ChangeResult::NoChange;
187 for (auto res : op->getResults()) {
188 auto cur = before.getOrDefault(res);
189
190 for (auto &[_, opVal] : operandVals) {
191 (void)cur.update(opVal);
192 }
193 updated |= after->setValue(res, cur);
194 }
195 return updated;
196}
197
198// Perform the update for either a readarr op or an extractarr op, which
199// operate very similarly: index into the first operand using a variable number
200// of provided indices.
202 mlir::Operation *op, const ConstrainRefLattice::ValueMap &operandVals,
203 const ConstrainRefLattice &before, ConstrainRefLattice *after
204) {
205 ensure(mlir::isa<ReadArrayOp, ExtractArrayOp>(op), "wrong type of op provided!");
206
207 // We index the first operand by all remaining indices.
208 assert(op->getNumResults() == 1);
209 auto res = op->getResult(0);
210
211 auto array = op->getOperand(0);
212 auto it = operandVals.find(array);
213 ensure(it != operandVals.end(), "improperly constructed operandVals map");
214 auto currVals = it->second;
215
216 std::vector<ConstrainRefIndex> indices;
217
218 for (size_t i = 1; i < op->getNumOperands(); i++) {
219 auto currentOp = op->getOperand(i);
220 auto idxIt = operandVals.find(currentOp);
221 ensure(idxIt != operandVals.end(), "improperly constructed operandVals map");
222 auto &idxVals = idxIt->second;
223
224 if (idxVals.isSingleValue() && idxVals.getSingleValue().isConstantIndex()) {
225 ConstrainRefIndex idx(idxVals.getSingleValue().getConstantIndexValue());
226 indices.push_back(idx);
227 } else {
228 // Otherwise, assume any range is valid.
229 auto arrayType = mlir::dyn_cast<ArrayType>(array.getType());
230 auto lower = mlir::APInt::getZero(64);
231 mlir::APInt upper(64, arrayType.getDimSize(i - 1));
232 auto idxRange = ConstrainRefIndex(lower, upper);
233 indices.push_back(idxRange);
234 }
235 }
236
237 auto [newVals, _] = currVals.extract(indices);
238
239 propagateIfChanged(after, after->setValue(res, newVals));
240}
241
242/* ConstraintDependencyGraph */
243
244mlir::FailureOr<ConstraintDependencyGraph> ConstraintDependencyGraph::compute(
245 mlir::ModuleOp m, StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am
246) {
248 if (cdg.computeConstraints(solver, am).failed()) {
249 return mlir::failure();
250 }
251 return cdg;
252}
253
254void ConstraintDependencyGraph::dump() const { print(llvm::errs()); }
255
257void ConstraintDependencyGraph::print(llvm::raw_ostream &os) const {
258 // the EquivalenceClasses::iterator is sorted, but the EquivalenceClasses::member_iterator is
259 // not guaranteed to be sorted. So, we will sort members before printing them.
260 // We also want to add the constant values into the printing.
261 std::set<std::set<ConstrainRef>> sortedSets;
262 for (auto it = signalSets.begin(); it != signalSets.end(); it++) {
263 if (!it->isLeader()) {
264 continue;
265 }
266
267 std::set<ConstrainRef> sortedMembers;
268 for (auto mit = signalSets.member_begin(it); mit != signalSets.member_end(); mit++) {
269 sortedMembers.insert(*mit);
270 }
271
272 // We only want to print sets with a size > 1, because size == 1 means the
273 // signal is not in a constraint.
274 if (sortedMembers.size() > 1) {
275 sortedSets.insert(sortedMembers);
276 }
277 }
278 // Add the constants in separately.
279 for (auto &[ref, constSet] : constantSets) {
280 if (constSet.empty()) {
281 continue;
282 }
283 std::set<ConstrainRef> sortedMembers(constSet.begin(), constSet.end());
284 sortedMembers.insert(ref);
285 sortedSets.insert(sortedMembers);
286 }
287
288 os << "ConstraintDependencyGraph { ";
289
290 for (auto it = sortedSets.begin(); it != sortedSets.end();) {
291 os << "\n { ";
292 for (auto mit = it->begin(); mit != it->end();) {
293 os << *mit;
294 mit++;
295 if (mit != it->end()) {
296 os << ", ";
297 }
298 }
299
300 it++;
301 if (it == sortedSets.end()) {
302 os << " }\n";
303 } else {
304 os << " },";
305 }
306 }
307
308 os << "}\n";
309}
310
311mlir::LogicalResult ConstraintDependencyGraph::computeConstraints(
312 mlir::DataFlowSolver &solver, mlir::AnalysisManager &am
313) {
314 // Fetch the constrain function. This is a required feature for all LLZK structs.
315 auto constrainFnOp = structDef.getConstrainFuncOp();
316 ensure(
317 constrainFnOp,
318 "malformed struct " + mlir::Twine(structDef.getName()) + " must define a constrain function"
319 );
320
326
327 // - Union all constraints from the analysis
328 // This requires iterating over all of the emit operations
329 constrainFnOp.walk([this, &solver](EmitEqualityOp emitOp) {
330 this->walkConstrainOp(solver, emitOp);
331 });
332
333 constrainFnOp.walk([this, &solver](EmitContainmentOp emitOp) {
334 this->walkConstrainOp(solver, emitOp);
335 });
336
344 constrainFnOp.walk([this, &solver, &am](CallOp fnCall) mutable {
345 auto res = resolveCallable<FuncDefOp>(tables, fnCall);
346 ensure(mlir::succeeded(res), "could not resolve constrain call");
347
348 auto fn = res->get();
349 if (!fn.isStructConstrain()) {
350 return;
351 }
352 // Nested
353 auto calledStruct = fn.getOperation()->getParentOfType<StructDefOp>();
354 ConstrainRefRemappings translations;
355
356 auto lattice = solver.lookupState<ConstrainRefLattice>(fnCall.getOperation());
357 ensure(lattice, "could not find lattice for call operation");
358
359 // Map fn parameters to args in the call op
360 for (unsigned i = 0; i < fn.getNumArguments(); i++) {
361 auto prefix = ConstrainRef(fn.getArgument(i));
362 auto val = lattice->getOrDefault(fnCall.getOperand(i));
363 translations.push_back({prefix, val});
364 }
365 auto &childAnalysis =
366 am.getChildAnalysis<ConstraintDependencyGraphStructAnalysis>(calledStruct);
367 if (!childAnalysis.constructed()) {
368 ensure(
369 mlir::succeeded(childAnalysis.runAnalysis(solver, am)),
370 "could not construct CDG for child struct"
371 );
372 }
373 auto translatedCDG = childAnalysis.getResult().translate(translations);
374
375 // Now, union sets based on the translation
376 // We should be able to just merge what is in the translatedCDG to the current CDG
377 auto &tSets = translatedCDG.signalSets;
378 for (auto lit = tSets.begin(); lit != tSets.end(); lit++) {
379 if (!lit->isLeader()) {
380 continue;
381 }
382 auto leader = lit->getData();
383 for (auto mit = tSets.member_begin(lit); mit != tSets.member_end(); mit++) {
384 signalSets.unionSets(leader, *mit);
385 }
386 }
387 // And update the constant sets
388 for (auto &[ref, constSet] : translatedCDG.constantSets) {
389 constantSets[ref].insert(constSet.begin(), constSet.end());
390 }
391 });
392
393 return mlir::success();
394}
395
396void ConstraintDependencyGraph::walkConstrainOp(
397 mlir::DataFlowSolver &solver, mlir::Operation *emitOp
398) {
399 std::vector<ConstrainRef> signalUsages, constUsages;
400 auto lattice = solver.lookupState<ConstrainRefLattice>(emitOp);
401 ensure(lattice, "failed to get lattice for emit operation");
402
403 for (auto operand : emitOp->getOperands()) {
404 auto latticeVal = lattice->getOrDefault(operand);
405 for (auto &ref : latticeVal.foldToScalar()) {
406 if (ref.isConstant()) {
407 constUsages.push_back(ref);
408 } else {
409 signalUsages.push_back(ref);
410 }
411 }
412 }
413
414 // Compute a transitive closure over the signals.
415 if (!signalUsages.empty()) {
416 auto it = signalUsages.begin();
417 auto leader = signalSets.getOrInsertLeaderValue(*it);
418 for (it++; it != signalUsages.end(); it++) {
419 signalSets.unionSets(leader, *it);
420 }
421 }
422 // Also update constant references for each value.
423 for (auto &sig : signalUsages) {
424 constantSets[sig].insert(constUsages.begin(), constUsages.end());
425 }
426}
427
429) const {
430 ConstraintDependencyGraph res(mod, structDef);
431 auto translate = [&translation](const ConstrainRef &elem
432 ) -> mlir::FailureOr<std::vector<ConstrainRef>> {
433 std::vector<ConstrainRef> refs;
434 for (auto &[prefix, vals] : translation) {
435 if (!elem.isValidPrefix(prefix)) {
436 continue;
437 }
438
439 if (vals.isArray()) {
440 // Try to index into the array
441 auto suffix = elem.getSuffix(prefix);
442 ensure(
443 mlir::succeeded(suffix), "failure is nonsensical, we already checked for valid prefix"
444 );
445
446 auto [resolvedVals, _] = vals.extract(suffix.value());
447 auto folded = resolvedVals.foldToScalar();
448 refs.insert(refs.end(), folded.begin(), folded.end());
449 } else {
450 for (auto &replacement : vals.getScalarValue()) {
451 auto translated = elem.translate(prefix, replacement);
452 if (mlir::succeeded(translated)) {
453 refs.push_back(translated.value());
454 }
455 }
456 }
457 }
458 if (refs.empty()) {
459 return mlir::failure();
460 }
461 return refs;
462 };
463
464 for (auto leaderIt = signalSets.begin(); leaderIt != signalSets.end(); leaderIt++) {
465 if (!leaderIt->isLeader()) {
466 continue;
467 }
468 // translate everything in this set first
469 std::vector<ConstrainRef> translatedSignals, translatedConsts;
470 for (auto mit = signalSets.member_begin(leaderIt); mit != signalSets.member_end(); mit++) {
471 auto member = translate(*mit);
472 if (mlir::failed(member)) {
473 continue;
474 }
475 for (auto &ref : *member) {
476 if (ref.isConstant()) {
477 translatedConsts.push_back(ref);
478 } else {
479 translatedSignals.push_back(ref);
480 }
481 }
482 // Also add the constants from the original CDG
483 if (auto it = constantSets.find(*mit); it != constantSets.end()) {
484 auto &origConstSet = it->second;
485 translatedConsts.insert(translatedConsts.end(), origConstSet.begin(), origConstSet.end());
486 }
487 }
488
489 if (translatedSignals.empty()) {
490 continue;
491 }
492
493 // Now we can insert the translated signals
494 auto it = translatedSignals.begin();
495 auto leader = *it;
496 res.signalSets.insert(leader);
497 for (it++; it != translatedSignals.end(); it++) {
498 res.signalSets.insert(*it);
499 res.signalSets.unionSets(leader, *it);
500 }
501
502 // And update the constant references
503 for (auto &ref : translatedSignals) {
504 res.constantSets[ref].insert(translatedConsts.begin(), translatedConsts.end());
505 }
506 }
507 return res;
508}
509
511 ConstrainRefSet res;
512 auto currRef = mlir::FailureOr<ConstrainRef>(ref);
513 while (mlir::succeeded(currRef)) {
514 // Add signals
515 for (auto it = signalSets.findLeader(*currRef); it != signalSets.member_end(); it++) {
516 if (currRef.value() != *it) {
517 res.insert(*it);
518 }
519 }
520 // Add constants
521 auto constIt = constantSets.find(*currRef);
522 if (constIt != constantSets.end()) {
523 res.insert(constIt->second.begin(), constIt->second.end());
524 }
525 // Go to parent
526 currRef = currRef->getParentPrefix();
527 }
528 return res;
529}
530
531/* ConstraintDependencyGraphStructAnalysis */
532
534 mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager
535) {
536 auto result =
537 ConstraintDependencyGraph::compute(getModule(), getStruct(), solver, moduleAnalysisManager);
538 if (mlir::failed(result)) {
539 return mlir::failure();
540 }
541 setResult(std::move(*result));
542 return mlir::success();
543}
544
545} // namespace llzk
This file implements (LLZK-tailored) dense data-flow analysis using the data-flow analysis framework.
void visitCallControlFlowTransfer(mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const ConstrainRefLattice &before, ConstrainRefLattice *after) override
Hook for customizing the behavior of lattice propagation along the call control flow edges.
mlir::ChangeResult fallbackOpUpdate(mlir::Operation *op, const ConstrainRefLattice::ValueMap &operandVals, const ConstrainRefLattice &before, ConstrainRefLattice *after)
void arraySubdivisionOpUpdate(mlir::Operation *op, const ConstrainRefLattice::ValueMap &operandVals, const ConstrainRefLattice &before, ConstrainRefLattice *after)
void setToEntryState(ConstrainRefLattice *lattice) override
Set the dense lattice at control flow entry point and propagate an update if it changed.
void visitOperation(mlir::Operation *op, const ConstrainRefLattice &before, ConstrainRefLattice *after) override
Propagate constrain reference lattice values from operands to results.
Defines an index into an LLZK object.
A value at a given point of the ConstrainRefLattice.
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > translate(const TranslationMap &translation) const
For the refs contained in this value, translate them given the translation map and return the transfo...
A lattice for use in dense analysis.
mlir::ChangeResult join(const AbstractDenseLattice &rhs) override
Maximum upper bound.
ConstrainRefLatticeValue getOrDefault(mlir::Value v) const
ConstrainRefLatticeValue getReturnValue(unsigned i) const
mlir::DenseMap< mlir::Value, ConstrainRefLatticeValue > ValueMap
mlir::ChangeResult setValue(mlir::Value v, const ConstrainRefLatticeValue &rhs)
Defines a reference to a llzk object within a constrain function call.
mlir::LogicalResult runAnalysis(mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager, NoContext &_) override
Perform the analysis and construct the Result output.
static mlir::FailureOr< ConstraintDependencyGraph > compute(mlir::ModuleOp mod, component::StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am)
Compute a ConstraintDependencyGraph (CDG)
void print(mlir::raw_ostream &os) const
Print the CDG to the specified output stream.
ConstraintDependencyGraph translate(ConstrainRefRemappings translation) const
Translate the ConstrainRefs in this CDG to that of a different context.
ConstrainRefSet getConstrainingValues(const ConstrainRef &ref) const
Get the values that are connected to the given ref via emitted constraints.
ConstraintDependencyGraph(const ConstraintDependencyGraph &other)
void dump() const
Dumps the CDG to stderr.
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
Definition Ops.cpp:357
void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs)
Join a lattice with another and propagate an update if it changed.
ConstrainRefLattice * getLattice(mlir::ProgramPoint point) override
mlir::dataflow::AbstractDenseLattice AbstractDenseLattice
mlir::dataflow::CallControlFlowAction CallControlFlowAction
std::vector< std::pair< ConstrainRef, ConstrainRefLatticeValue > > ConstrainRefRemappings
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:32
mlir::FailureOr< SymbolLookupResult< T > > resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers.