LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
ConstraintDependencyGraph.cpp
Go to the documentation of this file.
1//===-- ConstraintDependencyGraph.cpp ---------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
16#include "llzk/Util/Hash.h"
18
19#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
20#include <mlir/IR/Value.h>
21
22#include <llvm/Support/Debug.h>
23
24#include <numeric>
25#include <unordered_set>
26
27#define DEBUG_TYPE "llzk-cdg"
28
29namespace llzk {
30
31using namespace array;
32using namespace component;
33using namespace constrain;
34using namespace function;
35
36/* ConstrainRefAnalysis */
37
39 mlir::CallOpInterface call, dataflow::CallControlFlowAction action,
40 const ConstrainRefLattice &before, ConstrainRefLattice *after
41) {
// Call-edge transfer function for the ConstrainRef dense analysis.
//   - EnterCallee: `after` is reset via setToEntryState(); caller values are
//     deliberately not pushed into the callee (see the comment in that branch).
//   - ExitCallee: `after` is joined with the lattice in effect just before the
//     call site, and each call result is set to the callee's return value with
//     the callee's argument refs translated to the operands passed at this call.
//   - ExternalCallee: `before` is joined into `after` unchanged.
42 LLVM_DEBUG(
43 llvm::dbgs() << "ConstrainRefAnalysis::visitCallControlFlowTransfer: " << call << '\n'
44 );
45 auto fnOpRes = resolveCallable<FuncDefOp>(tables, call);
46 ensure(succeeded(fnOpRes), "could not resolve called function");
// NOTE(review): the ExitCallee branch below resolves the callee a second time
// (funcOpRes); consider reusing fnOpRes instead.
47
48 LLVM_DEBUG({
49 llvm::dbgs().indent(4) << "parent op is ";
50 if (auto s = call->getParentOfType<StructDefOp>()) {
51 llvm::dbgs() << s.getName();
52 } else if (auto p = call->getParentOfType<FuncDefOp>()) {
53 llvm::dbgs() << p.getName();
54 } else {
55 llvm::dbgs() << "<UNKNOWN PARENT TYPE>";
56 }
57 llvm::dbgs() << '\n';
58 });
59
63 if (action == dataflow::CallControlFlowAction::EnterCallee) {
64 // We skip updating the incoming lattice for function calls,
65 // as ConstrainRefs are relative to the containing function/struct, so we don't need to pollute
66 // the callee with the caller's values.
67 // This also avoids a non-convergence scenario, as calling a
68 // function from other contexts can cause the lattice values to oscillate and constantly
69 // change (thus looping infinitely).
70
71 setToEntryState(after);
72 }
76 else if (action == dataflow::CallControlFlowAction::ExitCallee) {
77 // Get the argument values of the lattice by getting the state as it would
78 // have been for the callsite.
// That state is the lattice of the op preceding the call or, if the call is the
// first op in its block, the lattice of the enclosing block.
79 dataflow::AbstractDenseLattice *beforeCall = nullptr;
80 if (auto *prev = call->getPrevNode()) {
81 beforeCall = getLattice(prev);
82 } else {
83 beforeCall = getLattice(call->getBlock());
84 }
85 ensure(beforeCall, "could not get prior lattice");
86
87 // Translate argument values based on the operands given at the call site.
// Keys are the callee's argument refs; values are the lattice values of the
// corresponding call operands.
88 std::unordered_map<ConstrainRef, ConstrainRefLatticeValue, ConstrainRef::Hash> translation;
89 auto funcOpRes = resolveCallable<FuncDefOp>(tables, call);
90 ensure(mlir::succeeded(funcOpRes), "could not lookup called function");
91 auto funcOp = funcOpRes->get();
92
// Only llzk CallOp is supported here (needed for operand/result access below).
93 auto callOp = mlir::dyn_cast<CallOp>(call.getOperation());
94 ensure(callOp, "call is not a llzk::CallOp");
95
96 for (unsigned i = 0; i < funcOp.getNumArguments(); i++) {
97 auto key = ConstrainRef(funcOp.getArgument(i));
// NOTE(review): the operand is a caller-side value, but it is looked up in
// `before`. For ExitCallee, `before` is presumably the state at the callee's
// exit (per the MLIR hook documentation). `beforeCall`, the call-site lattice,
// is only used for the join below. Confirm which lattice this lookup should use.
98 auto val = before.getOrDefault(callOp.getOperand(i));
99 translation[key] = val;
100 }
101
102 // The lattice at the return is the lattice before the call + translated
103 // return values.
104 mlir::ChangeResult updated = after->join(*beforeCall);
105 for (unsigned i = 0; i < callOp.getNumResults(); i++) {
106 auto retVal = before.getReturnValue(i);
107 auto [translatedVal, _] = retVal.translate(translation);
108 updated |= after->setValue(callOp->getResult(i), translatedVal);
109 }
110 propagateIfChanged(after, updated);
111 }
116 else if (action == mlir::dataflow::CallControlFlowAction::ExternalCallee) {
117 // For external calls, we propagate what information we already have from
118 // before the call to after the call, since the external call won't invalidate
119 // any of that information. It also, conservatively, makes no assumptions about
120 // external calls and their computation, so CDG edges will not be computed over
121 // input arguments to external functions.
122 join(after, before);
123 }
124}
125
127 mlir::Operation *op, const ConstrainRefLattice &before, ConstrainRefLattice *after
128) {
// Transfer function for a single operation. It first joins `before` into
// `after` and records every operand's lattice value in `after`. It then sets
// the result value(s) according to the op kind:
//   - FieldReadOp: the operand's value indexed by the read field's definition.
//   - ReadArrayOp / ExtractArrayOp: delegated to arraySubdivisionOpUpdate().
//   - CreateArrayOp: a new array value whose flat element i holds operand i's value.
//   - any other op: fallbackOpUpdate(), a union of all operand values.
129 LLVM_DEBUG(llvm::dbgs() << "ConstrainRefAnalysis::visitOperation: " << *op << '\n');
130 // Collect the references that are made by the operands to `op`.
// operandVals maps each operand Value to its lattice value in `before`.
132 for (mlir::OpOperand &operand : op->getOpOperands()) {
133 operandVals[operand.get()] = before.getOrDefault(operand.get());
134 }
135
136 // Propagate existing state.
137 join(after, before);
138 // Add operand values, if not already added. Ensures that the default value
139 // of a ConstrainRef (the source of the ref) is visible in the lattice.
140 propagateIfChanged(after, after->setValues(operandVals));
141
142 // We will now join the operand refs based on the type of operation.
143 if (auto fieldRead = mlir::dyn_cast<FieldReadOp>(op)) {
144 // In the readf case, the operand is indexed into by the read's fielddefop.
145 assert(operandVals.size() == 1);
146 assert(fieldRead->getNumResults() == 1);
147
148 auto fieldOpRes = fieldRead.getFieldDefOp(tables);
149 ensure(mlir::succeeded(fieldOpRes), "could not find field read");
150
151 auto res = fieldRead->getResult(0);
152 const auto &ops = operandVals.at(fieldRead->getOpOperand(0).get());
153 auto [fieldVals, _] = ops.referenceField(fieldOpRes.value());
154
155 propagateIfChanged(after, after->setValue(res, fieldVals));
156 } else if (mlir::isa<ReadArrayOp>(op)) {
157 arraySubdivisionOpUpdate(op, operandVals, before, after);
158 } else if (auto createArray = mlir::dyn_cast<CreateArrayOp>(op)) {
159 // Create an array using the operand values, if they exist.
160 // Currently, the new array must either be fully initialized or uninitialized.
161
162 auto newArrayVal = ConstrainRefLatticeValue(createArray.getType().getShape());
163 // If the array is initialized, iterate through all operands and initialize the array value.
// Operand i initializes flat element i. With no operands, the value stays as
// constructed from the shape.
164 for (unsigned i = 0; i < createArray.getNumOperands(); i++) {
165 auto currentOp = createArray.getOperand(i);
166 auto &opVals = operandVals[currentOp];
167 (void)newArrayVal.getElemFlatIdx(i).setValue(opVals);
168 }
169
170 assert(createArray->getNumResults() == 1);
171 auto res = createArray->getResult(0);
172
173 propagateIfChanged(after, after->setValue(res, newArrayVal));
174 } else if (auto extractArray = mlir::dyn_cast<ExtractArrayOp>(op)) {
// NOTE(review): `extractArray` is never used; mlir::isa<ExtractArrayOp>(op) would suffice.
175 arraySubdivisionOpUpdate(op, operandVals, before, after);
176 } else {
177 // Standard union of operands into the results value.
178 // TODO: Could perform constant computation/propagation here for, e.g., arithmetic
179 // over constants, but such analysis may be better suited for a dedicated pass.
180 propagateIfChanged(after, fallbackOpUpdate(op, operandVals, before, after));
181 }
182}
183
184// Perform a standard union of operands into the results value.
// For each result, start from its value in `before`, update() it with every
// operand value in `operandVals`, and store the outcome into `after`.
// Returns whether any result in `after` changed. This function does not
// propagate; the caller is expected to pass the result to propagateIfChanged().
// NOTE(review): operandVals is a DenseMap, so iteration order is unspecified;
// this assumes update() is order-insensitive.
186 mlir::Operation *op, const ConstrainRefLattice::ValueMap &operandVals,
187 const ConstrainRefLattice &before, ConstrainRefLattice *after
188) {
189 auto updated = mlir::ChangeResult::NoChange;
190 for (auto res : op->getResults()) {
191 auto cur = before.getOrDefault(res);
192
193 for (auto &[_, opVal] : operandVals) {
194 (void)cur.update(opVal);
195 }
196 updated |= after->setValue(res, cur);
197 }
198 return updated;
199}
200
201// Perform the update for either a readarr op or an extractarr op, which
202// operate very similarly: index into the first operand using a variable number
203// of provided indices.
// Each index operand contributes one ConstrainRefIndex:
//   - an exact index when its lattice value is a single constant;
//   - otherwise a range from 0 to the size of the corresponding array dimension.
// The extracted value is stored as the op's single result in `after`. For a
// ReadArrayOp, the extracted value must be scalar.
// NOTE(review): `before` is not read in this function.
205 mlir::Operation *op, const ConstrainRefLattice::ValueMap &operandVals,
206 const ConstrainRefLattice &before, ConstrainRefLattice *after
207) {
208 ensure(mlir::isa<ReadArrayOp, ExtractArrayOp>(op), "wrong type of op provided!");
209
210 // We index the first operand by all remaining indices.
211 assert(op->getNumResults() == 1);
212 auto res = op->getResult(0);
213
214 auto array = op->getOperand(0);
215 auto it = operandVals.find(array);
216 ensure(it != operandVals.end(), "improperly constructed operandVals map");
217 auto currVals = it->second;
218
219 std::vector<ConstrainRefIndex> indices;
220
221 for (size_t i = 1; i < op->getNumOperands(); i++) {
222 auto currentOp = op->getOperand(i);
223 auto idxIt = operandVals.find(currentOp);
224 ensure(idxIt != operandVals.end(), "improperly constructed operandVals map");
225 auto &idxVals = idxIt->second;
226
227 // Note: we accept constant indices whether they are felt- or index-typed:
228 // a felt index would need a cast to index, and if that cast were missing,
229 // semantic checking would already fail. Accepting either means we don't
230 // have to track the cast ourselves.
231 if (idxVals.isSingleValue() && idxVals.getSingleValue().isConstant()) {
232 ConstrainRefIndex idx(idxVals.getSingleValue().getConstantValue());
233 indices.push_back(idx);
234 } else {
235 // Otherwise, assume any range is valid.
// NOTE(review): the dyn_cast result is not checked before getDimSize(). This
// assumes operand 0 is always an ArrayType. It is also unconfirmed whether
// ConstrainRefIndex treats `upper` (the dimension size) as exclusive.
236 auto arrayType = mlir::dyn_cast<ArrayType>(array.getType());
237 auto lower = mlir::APInt::getZero(64);
238 mlir::APInt upper(64, arrayType.getDimSize(i - 1));
239 auto idxRange = ConstrainRefIndex(lower, upper);
240 indices.push_back(idxRange);
241 }
242 }
243
244 auto [newVals, _] = currVals.extract(indices);
245
246 if (mlir::isa<ReadArrayOp>(op)) {
247 ensure(newVals.isScalar(), "array read must produce a scalar value");
248 }
249 // An extract operation may yield a "scalar" value if not all dimensions of
250 // the source array are instantiated; for example, if extracting an array from
251 // an input arg, the current value is a "scalar" with an array type, and extracting
252 // from that yields another single value with indices. For example: extracting [0][1]
253 // from { arg1 } yields { arg1[0][1] }.
254
255 propagateIfChanged(after, after->setValue(res, newVals));
256}
257
258/* ConstraintDependencyGraph */
259
// Compute the CDG for struct `s` (in module `m`). The graph is populated by
// computeConstraints() using `solver` and `am`. Returns failure if that step
// fails, otherwise the populated CDG.
260mlir::FailureOr<ConstraintDependencyGraph> ConstraintDependencyGraph::compute(
261 mlir::ModuleOp m, StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am
262) {
264 if (cdg.computeConstraints(solver, am).failed()) {
265 return mlir::failure();
266 }
267 return cdg;
268}
269
270void ConstraintDependencyGraph::dump() const { print(llvm::errs()); }
271
273void ConstraintDependencyGraph::print(llvm::raw_ostream &os) const {
274 // the EquivalenceClasses::iterator is sorted, but the EquivalenceClasses::member_iterator is
275 // not guaranteed to be sorted. So, we will sort members before printing them.
276 // We also want to add the constant values into the printing.
277 std::set<std::set<ConstrainRef>> sortedSets;
278 for (auto it = signalSets.begin(); it != signalSets.end(); it++) {
279 if (!it->isLeader()) {
280 continue;
281 }
282
283 std::set<ConstrainRef> sortedMembers;
284 for (auto mit = signalSets.member_begin(it); mit != signalSets.member_end(); mit++) {
285 sortedMembers.insert(*mit);
286 }
287
288 // We only want to print sets with a size > 1, because size == 1 means the
289 // signal is not in a constraint.
290 if (sortedMembers.size() > 1) {
291 sortedSets.insert(sortedMembers);
292 }
293 }
294 // Add the constants in separately.
295 for (auto &[ref, constSet] : constantSets) {
296 if (constSet.empty()) {
297 continue;
298 }
299 std::set<ConstrainRef> sortedMembers(constSet.begin(), constSet.end());
300 sortedMembers.insert(ref);
301 sortedSets.insert(sortedMembers);
302 }
303
304 os << "ConstraintDependencyGraph { ";
305
306 for (auto it = sortedSets.begin(); it != sortedSets.end();) {
307 os << "\n { ";
308 for (auto mit = it->begin(); mit != it->end();) {
309 os << *mit;
310 mit++;
311 if (mit != it->end()) {
312 os << ", ";
313 }
314 }
315
316 it++;
317 if (it == sortedSets.end()) {
318 os << " }\n";
319 } else {
320 os << " },";
321 }
322 }
323
324 os << "}\n";
325}
326
// Populate this CDG from the struct's constrain function:
//   1. Every EmitEqualityOp and EmitContainmentOp is recorded via walkConstrainOp().
//   2. Every call to a struct constrain function contributes the called struct's
//      CDG. That CDG is obtained from the AnalysisManager and computed on demand
//      if needed. It is translated into this context using the lattice values of
//      the call operands, then merged into signalSets and constantSets.
327mlir::LogicalResult ConstraintDependencyGraph::computeConstraints(
328 mlir::DataFlowSolver &solver, mlir::AnalysisManager &am
329) {
330 // Fetch the constrain function. This is a required feature for all LLZK structs.
331 auto constrainFnOp = structDef.getConstrainFuncOp();
332 ensure(
333 constrainFnOp,
334 "malformed struct " + mlir::Twine(structDef.getName()) + " must define a constrain function"
335 );
336
342
343 // - Union all constraints from the analysis
344 // This requires iterating over all of the emit operations
345 constrainFnOp.walk([this, &solver](EmitEqualityOp emitOp) {
346 this->walkConstrainOp(solver, emitOp);
347 });
348
349 constrainFnOp.walk([this, &solver](EmitContainmentOp emitOp) {
350 this->walkConstrainOp(solver, emitOp);
351 });
352
// NOTE(review): `mutable` has no effect on this lambda; all captures are `this`
// or by reference.
360 constrainFnOp.walk([this, &solver, &am](CallOp fnCall) mutable {
361 auto res = resolveCallable<FuncDefOp>(tables, fnCall);
362 ensure(mlir::succeeded(res), "could not resolve constrain call");
363
364 auto fn = res->get();
// Only calls to struct constrain functions contribute constraints here.
365 if (!fn.isStructConstrain()) {
366 return;
367 }
368 // Nested constrain call: locate the struct that owns the called constrain function.
369 auto calledStruct = fn.getOperation()->getParentOfType<StructDefOp>();
370 ConstrainRefRemappings translations;
371
372 auto lattice = solver.lookupState<ConstrainRefLattice>(fnCall.getOperation());
373 ensure(lattice, "could not find lattice for call operation");
374
375 // Map fn parameters to args in the call op
376 for (unsigned i = 0; i < fn.getNumArguments(); i++) {
377 auto prefix = ConstrainRef(fn.getArgument(i));
378 auto val = lattice->getOrDefault(fnCall.getOperand(i));
379 translations.push_back({prefix, val});
380 }
// Reuse the called struct's CDG if it is already constructed; otherwise compute it now.
381 auto &childAnalysis =
382 am.getChildAnalysis<ConstraintDependencyGraphStructAnalysis>(calledStruct);
383 if (!childAnalysis.constructed()) {
384 ensure(
385 mlir::succeeded(childAnalysis.runAnalysis(solver, am)),
386 "could not construct CDG for child struct"
387 );
388 }
389 auto translatedCDG = childAnalysis.getResult().translate(translations);
390
391 // Now, union sets based on the translation
392 // We should be able to just merge what is in the translatedCDG to the current CDG
// unionSets() inserts values not yet present, so new refs are added automatically.
393 auto &tSets = translatedCDG.signalSets;
394 for (auto lit = tSets.begin(); lit != tSets.end(); lit++) {
395 if (!lit->isLeader()) {
396 continue;
397 }
398 auto leader = lit->getData();
399 for (auto mit = tSets.member_begin(lit); mit != tSets.member_end(); mit++) {
400 signalSets.unionSets(leader, *mit);
401 }
402 }
403 // And update the constant sets
404 for (auto &[ref, constSet] : translatedCDG.constantSets) {
405 constantSets[ref].insert(constSet.begin(), constSet.end());
406 }
407 });
408
409 return mlir::success();
410}
411
412void ConstraintDependencyGraph::walkConstrainOp(
413 mlir::DataFlowSolver &solver, mlir::Operation *emitOp
414) {
415 std::vector<ConstrainRef> signalUsages, constUsages;
416 auto lattice = solver.lookupState<ConstrainRefLattice>(emitOp);
417 ensure(lattice, "failed to get lattice for emit operation");
418
419 for (auto operand : emitOp->getOperands()) {
420 auto latticeVal = lattice->getOrDefault(operand);
421 for (auto &ref : latticeVal.foldToScalar()) {
422 if (ref.isConstant()) {
423 constUsages.push_back(ref);
424 } else {
425 signalUsages.push_back(ref);
426 }
427 }
428 }
429
430 // Compute a transitive closure over the signals.
431 if (!signalUsages.empty()) {
432 auto it = signalUsages.begin();
433 auto leader = signalSets.getOrInsertLeaderValue(*it);
434 for (it++; it != signalUsages.end(); it++) {
435 signalSets.unionSets(leader, *it);
436 }
437 }
438 // Also update constant references for each value.
439 for (auto &sig : signalUsages) {
440 constantSets[sig].insert(constUsages.begin(), constUsages.end());
441 }
442}
443
445) const {
// Produce a copy of this CDG with every ref rewritten through `translation`.
// `translation` is a list of (callee-context prefix, caller-context lattice value) pairs.
//   - A ref matching no prefix (the local translate() lambda fails) is dropped.
//   - Array-valued replacements are indexed by the ref's suffix past the prefix
//     and folded to scalars.
//   - Scalar replacements substitute the prefix directly via ConstrainRef::translate.
//   - The translated signals of each equivalence class stay grouped in one class.
//   - The translated constants of a class, plus the original constant sets of its
//     successfully translated members, become the constant set of every
//     translated signal in that class.
446 ConstraintDependencyGraph res(mod, structDef);
447 auto translate = [&translation](const ConstrainRef &elem
448 ) -> mlir::FailureOr<std::vector<ConstrainRef>> {
449 std::vector<ConstrainRef> refs;
450 for (auto &[prefix, vals] : translation) {
451 if (!elem.isValidPrefix(prefix)) {
452 continue;
453 }
454
455 if (vals.isArray()) {
456 // Try to index into the array
457 auto suffix = elem.getSuffix(prefix);
458 ensure(
459 mlir::succeeded(suffix), "failure is nonsensical, we already checked for valid prefix"
460 );
461
462 auto [resolvedVals, _] = vals.extract(suffix.value());
463 auto folded = resolvedVals.foldToScalar();
464 refs.insert(refs.end(), folded.begin(), folded.end());
465 } else {
466 for (auto &replacement : vals.getScalarValue()) {
467 auto translated = elem.translate(prefix, replacement);
468 if (mlir::succeeded(translated)) {
469 refs.push_back(translated.value());
470 }
471 }
472 }
473 }
// No prefix produced any replacement: report failure so the caller skips this ref.
474 if (refs.empty()) {
475 return mlir::failure();
476 }
477 return refs;
478 };
479
480 for (auto leaderIt = signalSets.begin(); leaderIt != signalSets.end(); leaderIt++) {
481 if (!leaderIt->isLeader()) {
482 continue;
483 }
484 // translate everything in this set first
485 std::vector<ConstrainRef> translatedSignals, translatedConsts;
486 for (auto mit = signalSets.member_begin(leaderIt); mit != signalSets.member_end(); mit++) {
487 auto member = translate(*mit);
488 if (mlir::failed(member)) {
489 continue;
490 }
491 for (auto &ref : *member) {
492 if (ref.isConstant()) {
493 translatedConsts.push_back(ref);
494 } else {
495 translatedSignals.push_back(ref);
496 }
497 }
498 // Also add the constants from the original CDG
// NOTE(review): these constants are copied without translation. This presumably
// relies on constant refs being context-independent; confirm.
499 if (auto it = constantSets.find(*mit); it != constantSets.end()) {
500 auto &origConstSet = it->second;
501 translatedConsts.insert(translatedConsts.end(), origConstSet.begin(), origConstSet.end());
502 }
503 }
504
// A class with no translatable signals is not carried into the result.
505 if (translatedSignals.empty()) {
506 continue;
507 }
508
509 // Now we can insert the translated signals
510 auto it = translatedSignals.begin();
511 auto leader = *it;
512 res.signalSets.insert(leader);
513 for (it++; it != translatedSignals.end(); it++) {
514 res.signalSets.insert(*it);
515 res.signalSets.unionSets(leader, *it);
516 }
517
518 // And update the constant references
519 for (auto &ref : translatedSignals) {
520 res.constantSets[ref].insert(translatedConsts.begin(), translatedConsts.end());
521 }
522 }
523 return res;
524}
525
// Gather everything constrained together with `ref` or with any of its parent
// prefixes. The walk moves up via getParentPrefix() until that fails. For each
// visited ref, it collects the other members of that ref's signal equivalence
// class plus that ref's recorded constants.
// A ref absent from signalSets contributes no signals: LLVM's findLeader()
// returns member_end() for unknown values.
// NOTE(review): only the ref currently being visited is excluded. `ref` itself
// can therefore appear in the result if it shares a class with one of its parents.
527 ConstrainRefSet res;
528 auto currRef = mlir::FailureOr<ConstrainRef>(ref);
529 while (mlir::succeeded(currRef)) {
530 // Add signals
531 for (auto it = signalSets.findLeader(*currRef); it != signalSets.member_end(); it++) {
532 if (currRef.value() != *it) {
533 res.insert(*it);
534 }
535 }
536 // Add constants
537 auto constIt = constantSets.find(*currRef);
538 if (constIt != constantSets.end()) {
539 res.insert(constIt->second.begin(), constIt->second.end());
540 }
541 // Go to parent
542 currRef = currRef->getParentPrefix();
543 }
544 return res;
545}
546
547/* ConstraintDependencyGraphStructAnalysis */
548
550 mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager
551) {
// Compute the CDG for this analysis' module and struct and store it as the
// analysis result. Returns failure if ConstraintDependencyGraph::compute fails.
552 auto result =
553 ConstraintDependencyGraph::compute(getModule(), getStruct(), solver, moduleAnalysisManager);
554 if (mlir::failed(result)) {
555 return mlir::failure();
556 }
557 setResult(std::move(*result));
558 return mlir::success();
559}
560
561} // namespace llzk
This file implements (LLZK-tailored) dense data-flow analysis using the data-flow analysis framework.
void visitCallControlFlowTransfer(mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const ConstrainRefLattice &before, ConstrainRefLattice *after) override
Hook for customizing the behavior of lattice propagation along the call control flow edges.
mlir::ChangeResult fallbackOpUpdate(mlir::Operation *op, const ConstrainRefLattice::ValueMap &operandVals, const ConstrainRefLattice &before, ConstrainRefLattice *after)
void arraySubdivisionOpUpdate(mlir::Operation *op, const ConstrainRefLattice::ValueMap &operandVals, const ConstrainRefLattice &before, ConstrainRefLattice *after)
void setToEntryState(ConstrainRefLattice *lattice) override
Set the dense lattice at control flow entry point and propagate an update if it changed.
void visitOperation(mlir::Operation *op, const ConstrainRefLattice &before, ConstrainRefLattice *after) override
Propagate constrain reference lattice values from operands to results.
Defines an index into an LLZK object.
A value at a given point of the ConstrainRefLattice.
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > translate(const TranslationMap &translation) const
For the refs contained in this value, translate them given the translation map and return the transformed value together with a ChangeResult.
A lattice for use in dense analysis.
mlir::ChangeResult join(const AbstractDenseLattice &rhs) override
Maximum upper bound.
ConstrainRefLatticeValue getOrDefault(mlir::Value v) const
ConstrainRefLatticeValue getReturnValue(unsigned i) const
mlir::DenseMap< mlir::Value, ConstrainRefLatticeValue > ValueMap
mlir::ChangeResult setValues(const ValueMap &rhs)
mlir::ChangeResult setValue(mlir::Value v, const ConstrainRefLatticeValue &rhs)
Defines a reference to an LLZK object within a constrain function call.
mlir::LogicalResult runAnalysis(mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager, NoContext &_) override
Perform the analysis and construct the Result output.
static mlir::FailureOr< ConstraintDependencyGraph > compute(mlir::ModuleOp mod, component::StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am)
Compute a ConstraintDependencyGraph (CDG) for the given struct.
void print(mlir::raw_ostream &os) const
Print the CDG to the specified output stream.
ConstraintDependencyGraph translate(ConstrainRefRemappings translation) const
Translate the ConstrainRefs in this CDG to that of a different context.
ConstrainRefSet getConstrainingValues(const ConstrainRef &ref) const
Get the values that are connected to the given ref via emitted constraints.
ConstraintDependencyGraph(const ConstraintDependencyGraph &other)
void dump() const
Dumps the CDG to stderr.
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
Definition Ops.cpp:357
void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs)
Join a lattice with another and propagate an update if it changed.
ConstrainRefLattice * getLattice(mlir::ProgramPoint point) override
mlir::dataflow::AbstractDenseLattice AbstractDenseLattice
mlir::dataflow::CallControlFlowAction CallControlFlowAction
std::vector< std::pair< ConstrainRef, ConstrainRefLatticeValue > > ConstrainRefRemappings
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:35
mlir::FailureOr< SymbolLookupResult< T > > resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers.