19#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
20#include <mlir/IR/Value.h>
22#include <llvm/Support/Debug.h>
25#include <unordered_set>
27#define DEBUG_TYPE "llzk-cdg"
32using namespace component;
33using namespace constrain;
34using namespace function;
43 llvm::dbgs() <<
"ConstrainRefAnalysis::visitCallControlFlowTransfer: " << call <<
'\n'
46 ensure(succeeded(fnOpRes),
"could not resolve called function");
49 llvm::dbgs().indent(4) <<
"parent op is ";
50 if (
auto s = call->getParentOfType<
StructDefOp>()) {
51 llvm::dbgs() << s.getName();
52 }
else if (
auto p = call->getParentOfType<
FuncDefOp>()) {
53 llvm::dbgs() << p.getName();
55 llvm::dbgs() <<
"<UNKNOWN PARENT TYPE>";
63 if (action == dataflow::CallControlFlowAction::EnterCallee) {
76 else if (action == dataflow::CallControlFlowAction::ExitCallee) {
80 if (
auto *prev = call->getPrevNode()) {
85 ensure(beforeCall,
"could not get prior lattice");
88 std::unordered_map<ConstrainRef, ConstrainRefLatticeValue, ConstrainRef::Hash> translation;
90 ensure(mlir::succeeded(funcOpRes),
"could not lookup called function");
91 auto funcOp = funcOpRes->get();
93 auto callOp = mlir::dyn_cast<CallOp>(call.getOperation());
94 ensure(callOp,
"call is not a llzk::CallOp");
96 for (
unsigned i = 0; i < funcOp.getNumArguments(); i++) {
99 translation[key] = val;
104 mlir::ChangeResult updated = after->
join(*beforeCall);
105 for (
unsigned i = 0; i < callOp.getNumResults(); i++) {
107 auto [translatedVal, _] = retVal.
translate(translation);
108 updated |= after->
setValue(callOp->getResult(i), translatedVal);
110 propagateIfChanged(after, updated);
116 else if (action == mlir::dataflow::CallControlFlowAction::ExternalCallee) {
129 LLVM_DEBUG(llvm::dbgs() <<
"ConstrainRefAnalysis::visitOperation: " << *op <<
'\n');
132 for (mlir::OpOperand &operand : op->getOpOperands()) {
133 operandVals[operand.get()] = before.
getOrDefault(operand.get());
140 propagateIfChanged(after, after->
setValues(operandVals));
143 if (
auto fieldRead = mlir::dyn_cast<FieldReadOp>(op)) {
145 assert(operandVals.size() == 1);
146 assert(fieldRead->getNumResults() == 1);
148 auto fieldOpRes = fieldRead.getFieldDefOp(tables);
149 ensure(mlir::succeeded(fieldOpRes),
"could not find field read");
151 auto res = fieldRead->getResult(0);
152 const auto &ops = operandVals.at(fieldRead->getOpOperand(0).get());
153 auto [fieldVals, _] = ops.referenceField(fieldOpRes.value());
155 propagateIfChanged(after, after->
setValue(res, fieldVals));
156 }
else if (mlir::isa<ReadArrayOp>(op)) {
158 }
else if (
auto createArray = mlir::dyn_cast<CreateArrayOp>(op)) {
164 for (
unsigned i = 0; i < createArray.getNumOperands(); i++) {
165 auto currentOp = createArray.getOperand(i);
166 auto &opVals = operandVals[currentOp];
167 (void)newArrayVal.getElemFlatIdx(i).setValue(opVals);
170 assert(createArray->getNumResults() == 1);
171 auto res = createArray->getResult(0);
173 propagateIfChanged(after, after->
setValue(res, newArrayVal));
174 }
else if (
auto extractArray = mlir::dyn_cast<ExtractArrayOp>(op)) {
180 propagateIfChanged(after,
fallbackOpUpdate(op, operandVals, before, after));
189 auto updated = mlir::ChangeResult::NoChange;
190 for (
auto res : op->getResults()) {
193 for (
auto &[_, opVal] : operandVals) {
194 (void)cur.update(opVal);
196 updated |= after->
setValue(res, cur);
208 ensure(mlir::isa<ReadArrayOp, ExtractArrayOp>(op),
"wrong type of op provided!");
211 assert(op->getNumResults() == 1);
212 auto res = op->getResult(0);
214 auto array = op->getOperand(0);
215 auto it = operandVals.find(
array);
216 ensure(it != operandVals.end(),
"improperly constructed operandVals map");
217 auto currVals = it->second;
219 std::vector<ConstrainRefIndex> indices;
221 for (
size_t i = 1; i < op->getNumOperands(); i++) {
222 auto currentOp = op->getOperand(i);
223 auto idxIt = operandVals.find(currentOp);
224 ensure(idxIt != operandVals.end(),
"improperly constructed operandVals map");
225 auto &idxVals = idxIt->second;
231 if (idxVals.isSingleValue() && idxVals.getSingleValue().isConstant()) {
233 indices.push_back(idx);
236 auto arrayType = mlir::dyn_cast<ArrayType>(
array.getType());
237 auto lower = mlir::APInt::getZero(64);
238 mlir::APInt upper(64, arrayType.getDimSize(i - 1));
240 indices.push_back(idxRange);
244 auto [newVals, _] = currVals.extract(indices);
246 if (mlir::isa<ReadArrayOp>(op)) {
247 ensure(newVals.isScalar(),
"array read must produce a scalar value");
255 propagateIfChanged(after, after->
setValue(res, newVals));
261 mlir::ModuleOp m,
StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am
264 if (cdg.computeConstraints(solver, am).failed()) {
265 return mlir::failure();
277 std::set<std::set<ConstrainRef>> sortedSets;
278 for (
auto it = signalSets.begin(); it != signalSets.end(); it++) {
279 if (!it->isLeader()) {
283 std::set<ConstrainRef> sortedMembers;
284 for (
auto mit = signalSets.member_begin(it); mit != signalSets.member_end(); mit++) {
285 sortedMembers.insert(*mit);
290 if (sortedMembers.size() > 1) {
291 sortedSets.insert(sortedMembers);
295 for (
auto &[ref, constSet] : constantSets) {
296 if (constSet.empty()) {
299 std::set<ConstrainRef> sortedMembers(constSet.begin(), constSet.end());
300 sortedMembers.insert(ref);
301 sortedSets.insert(sortedMembers);
304 os <<
"ConstraintDependencyGraph { ";
306 for (
auto it = sortedSets.begin(); it != sortedSets.end();) {
308 for (
auto mit = it->begin(); mit != it->end();) {
311 if (mit != it->end()) {
317 if (it == sortedSets.end()) {
327mlir::LogicalResult ConstraintDependencyGraph::computeConstraints(
328 mlir::DataFlowSolver &solver, mlir::AnalysisManager &am
334 "malformed struct " + mlir::Twine(structDef.getName()) +
" must define a constrain function"
346 this->walkConstrainOp(solver, emitOp);
349 constrainFnOp.walk([
this, &solver](EmitContainmentOp emitOp) {
350 this->walkConstrainOp(solver, emitOp);
360 constrainFnOp.walk([
this, &solver, &am](CallOp fnCall)
mutable {
362 ensure(mlir::succeeded(res),
"could not resolve constrain call");
364 auto fn = res->get();
365 if (!fn.isStructConstrain()) {
369 auto calledStruct = fn.getOperation()->getParentOfType<StructDefOp>();
372 auto lattice = solver.lookupState<ConstrainRefLattice>(fnCall.getOperation());
373 ensure(lattice,
"could not find lattice for call operation");
376 for (
unsigned i = 0; i < fn.getNumArguments(); i++) {
377 auto prefix = ConstrainRef(fn.getArgument(i));
378 auto val = lattice->getOrDefault(fnCall.getOperand(i));
379 translations.push_back({prefix, val});
381 auto &childAnalysis =
382 am.getChildAnalysis<ConstraintDependencyGraphStructAnalysis>(calledStruct);
383 if (!childAnalysis.constructed()) {
385 mlir::succeeded(childAnalysis.runAnalysis(solver, am)),
386 "could not construct CDG for child struct"
389 auto translatedCDG = childAnalysis.getResult().translate(translations);
393 auto &tSets = translatedCDG.signalSets;
394 for (
auto lit = tSets.begin(); lit != tSets.end(); lit++) {
395 if (!lit->isLeader()) {
398 auto leader = lit->getData();
399 for (
auto mit = tSets.member_begin(lit); mit != tSets.member_end(); mit++) {
400 signalSets.unionSets(leader, *mit);
404 for (
auto &[ref, constSet] : translatedCDG.constantSets) {
405 constantSets[ref].insert(constSet.begin(), constSet.end());
409 return mlir::success();
412void ConstraintDependencyGraph::walkConstrainOp(
413 mlir::DataFlowSolver &solver, mlir::Operation *emitOp
415 std::vector<ConstrainRef> signalUsages, constUsages;
416 auto lattice = solver.lookupState<ConstrainRefLattice>(emitOp);
417 ensure(lattice,
"failed to get lattice for emit operation");
419 for (
auto operand : emitOp->getOperands()) {
420 auto latticeVal = lattice->getOrDefault(operand);
421 for (
auto &ref : latticeVal.foldToScalar()) {
422 if (ref.isConstant()) {
423 constUsages.push_back(ref);
425 signalUsages.push_back(ref);
431 if (!signalUsages.empty()) {
432 auto it = signalUsages.begin();
433 auto leader = signalSets.getOrInsertLeaderValue(*it);
434 for (it++; it != signalUsages.end(); it++) {
435 signalSets.unionSets(leader, *it);
439 for (
auto &sig : signalUsages) {
440 constantSets[sig].insert(constUsages.begin(), constUsages.end());
448 ) -> mlir::FailureOr<std::vector<ConstrainRef>> {
449 std::vector<ConstrainRef> refs;
450 for (
auto &[prefix, vals] : translation) {
451 if (!elem.isValidPrefix(prefix)) {
455 if (vals.isArray()) {
457 auto suffix = elem.getSuffix(prefix);
459 mlir::succeeded(suffix),
"failure is nonsensical, we already checked for valid prefix"
462 auto [resolvedVals, _] = vals.extract(suffix.value());
463 auto folded = resolvedVals.foldToScalar();
464 refs.insert(refs.end(), folded.begin(), folded.end());
466 for (
auto &replacement : vals.getScalarValue()) {
467 auto translated = elem.translate(prefix, replacement);
468 if (mlir::succeeded(translated)) {
469 refs.push_back(translated.value());
475 return mlir::failure();
480 for (
auto leaderIt = signalSets.begin(); leaderIt != signalSets.end(); leaderIt++) {
481 if (!leaderIt->isLeader()) {
485 std::vector<ConstrainRef> translatedSignals, translatedConsts;
486 for (
auto mit = signalSets.member_begin(leaderIt); mit != signalSets.member_end(); mit++) {
488 if (mlir::failed(member)) {
491 for (
auto &ref : *member) {
492 if (ref.isConstant()) {
493 translatedConsts.push_back(ref);
495 translatedSignals.push_back(ref);
499 if (
auto it = constantSets.find(*mit); it != constantSets.end()) {
500 auto &origConstSet = it->second;
501 translatedConsts.insert(translatedConsts.end(), origConstSet.begin(), origConstSet.end());
505 if (translatedSignals.empty()) {
510 auto it = translatedSignals.begin();
512 res.signalSets.insert(leader);
513 for (it++; it != translatedSignals.end(); it++) {
514 res.signalSets.insert(*it);
515 res.signalSets.unionSets(leader, *it);
519 for (
auto &ref : translatedSignals) {
520 res.constantSets[ref].insert(translatedConsts.begin(), translatedConsts.end());
528 auto currRef = mlir::FailureOr<ConstrainRef>(ref);
529 while (mlir::succeeded(currRef)) {
531 for (
auto it = signalSets.findLeader(*currRef); it != signalSets.member_end(); it++) {
532 if (currRef.value() != *it) {
537 auto constIt = constantSets.find(*currRef);
538 if (constIt != constantSets.end()) {
539 res.insert(constIt->second.begin(), constIt->second.end());
542 currRef = currRef->getParentPrefix();
550 mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager
554 if (mlir::failed(result)) {
555 return mlir::failure();
558 return mlir::success();
This file implements (LLZK-tailored) dense data-flow analysis using the data-flow analysis framework.
void visitCallControlFlowTransfer(mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const ConstrainRefLattice &before, ConstrainRefLattice *after) override
Hook for customizing the behavior of lattice propagation along the call control flow edges.
mlir::ChangeResult fallbackOpUpdate(mlir::Operation *op, const ConstrainRefLattice::ValueMap &operandVals, const ConstrainRefLattice &before, ConstrainRefLattice *after)
void arraySubdivisionOpUpdate(mlir::Operation *op, const ConstrainRefLattice::ValueMap &operandVals, const ConstrainRefLattice &before, ConstrainRefLattice *after)
void setToEntryState(ConstrainRefLattice *lattice) override
Set the dense lattice at control flow entry point and propagate an update if it changed.
void visitOperation(mlir::Operation *op, const ConstrainRefLattice &before, ConstrainRefLattice *after) override
Propagate constrain reference lattice values from operands to results.
Defines an index into an LLZK object.
A value at a given point of the ConstrainRefLattice.
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > translate(const TranslationMap &translation) const
For the refs contained in this value, translate them given the translation map and return the transfo...
A lattice for use in dense analysis.
mlir::ChangeResult join(const AbstractDenseLattice &rhs) override
Maximum upper bound.
ConstrainRefLatticeValue getOrDefault(mlir::Value v) const
ConstrainRefLatticeValue getReturnValue(unsigned i) const
mlir::DenseMap< mlir::Value, ConstrainRefLatticeValue > ValueMap
mlir::ChangeResult setValues(const ValueMap &rhs)
mlir::ChangeResult setValue(mlir::Value v, const ConstrainRefLatticeValue &rhs)
Defines a reference to a llzk object within a constrain function call.
mlir::LogicalResult runAnalysis(mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager, NoContext &_) override
Perform the analysis and construct the Result output.
static mlir::FailureOr< ConstraintDependencyGraph > compute(mlir::ModuleOp mod, component::StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am)
Compute a ConstraintDependencyGraph (CDG)
void print(mlir::raw_ostream &os) const
Print the CDG to the specified output stream.
ConstraintDependencyGraph translate(ConstrainRefRemappings translation) const
Translate the ConstrainRefs in this CDG to that of a different context.
ConstrainRefSet getConstrainingValues(const ConstrainRef &ref) const
Get the values that are connected to the given ref via emitted constraints.
ConstraintDependencyGraph(const ConstraintDependencyGraph &other)
void dump() const
Dumps the CDG to stderr.
component::StructDefOp getStruct() const
mlir::ModuleOp getModule() const
void setResult(ConstraintDependencyGraph &&r)
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs)
Join a lattice with another and propagate an update if it changed.
ConstrainRefLattice * getLattice(mlir::ProgramPoint point) override
mlir::dataflow::AbstractDenseLattice AbstractDenseLattice
mlir::dataflow::CallControlFlowAction CallControlFlowAction
std::vector< std::pair< ConstrainRef, ConstrainRefLatticeValue > > ConstrainRefRemappings
void ensure(bool condition, llvm::Twine errMsg)
mlir::FailureOr< SymbolLookupResult< T > > resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers.