19#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
20#include <mlir/IR/Value.h>
22#include <llvm/Support/Debug.h>
25#include <unordered_set>
27#define DEBUG_TYPE "llzk-cdg"
32using namespace component;
33using namespace constrain;
34using namespace function;
43 llvm::dbgs() <<
"ConstrainRefAnalysis::visitCallControlFlowTransfer: " << call <<
'\n'
46 ensure(succeeded(fnOpRes),
"could not resolve called function");
49 llvm::dbgs().indent(4) <<
"parent op is ";
50 if (
auto s = call->getParentOfType<
StructDefOp>()) {
51 llvm::dbgs() << s.getName();
52 }
else if (
auto p = call->getParentOfType<
FuncDefOp>()) {
53 llvm::dbgs() << p.getName();
55 llvm::dbgs() <<
"<UNKNOWN PARENT TYPE>";
63 if (action == dataflow::CallControlFlowAction::EnterCallee) {
76 else if (action == dataflow::CallControlFlowAction::ExitCallee) {
80 if (
auto *prev = call->getPrevNode()) {
85 ensure(beforeCall,
"could not get prior lattice");
88 std::unordered_map<ConstrainRef, ConstrainRefLatticeValue, ConstrainRef::Hash> translation;
90 ensure(mlir::succeeded(funcOpRes),
"could not lookup called function");
91 auto funcOp = funcOpRes->get();
93 auto callOp = mlir::dyn_cast<CallOp>(call.getOperation());
94 ensure(callOp,
"call is not a llzk::CallOp");
96 for (
unsigned i = 0; i < funcOp.getNumArguments(); i++) {
99 translation[key] = val;
104 mlir::ChangeResult updated = after->
join(*beforeCall);
105 for (
unsigned i = 0; i < callOp.getNumResults(); i++) {
107 auto [translatedVal, _] = retVal.
translate(translation);
108 updated |= after->
setValue(callOp->getResult(i), translatedVal);
110 propagateIfChanged(after, updated);
116 else if (action == mlir::dataflow::CallControlFlowAction::ExternalCallee) {
129 LLVM_DEBUG(llvm::dbgs() <<
"ConstrainRefAnalysis::visitOperation: " << *op <<
'\n');
132 for (
auto &operand : op->getOpOperands()) {
133 operandVals[operand.get()] = before.
getOrDefault(operand.get());
140 if (
auto fieldRead = mlir::dyn_cast<FieldReadOp>(op)) {
142 assert(operandVals.size() == 1);
143 assert(fieldRead->getNumResults() == 1);
145 auto fieldOpRes = fieldRead.getFieldDefOp(tables);
146 ensure(mlir::succeeded(fieldOpRes),
"could not find field read");
148 auto res = fieldRead->getResult(0);
149 const auto &ops = operandVals.at(fieldRead->getOpOperand(0).get());
150 auto [fieldVals, _] = ops.referenceField(fieldOpRes.value());
152 propagateIfChanged(after, after->
setValue(res, fieldVals));
153 }
else if (mlir::isa<ReadArrayOp>(op)) {
155 }
else if (
auto createArray = mlir::dyn_cast<CreateArrayOp>(op)) {
161 for (
unsigned i = 0; i < createArray.getNumOperands(); i++) {
162 auto currentOp = createArray.getOperand(i);
163 auto &opVals = operandVals[currentOp];
164 (void)newArrayVal.getElemFlatIdx(i).setValue(opVals);
167 assert(createArray->getNumResults() == 1);
168 auto res = createArray->getResult(0);
170 propagateIfChanged(after, after->
setValue(res, newArrayVal));
171 }
else if (
auto extractArray = mlir::dyn_cast<ExtractArrayOp>(op)) {
177 propagateIfChanged(after,
fallbackOpUpdate(op, operandVals, before, after));
186 auto updated = mlir::ChangeResult::NoChange;
187 for (
auto res : op->getResults()) {
190 for (
auto &[_, opVal] : operandVals) {
191 (void)cur.update(opVal);
193 updated |= after->
setValue(res, cur);
205 ensure(mlir::isa<ReadArrayOp, ExtractArrayOp>(op),
"wrong type of op provided!");
208 assert(op->getNumResults() == 1);
209 auto res = op->getResult(0);
211 auto array = op->getOperand(0);
212 auto it = operandVals.find(
array);
213 ensure(it != operandVals.end(),
"improperly constructed operandVals map");
214 auto currVals = it->second;
216 std::vector<ConstrainRefIndex> indices;
218 for (
size_t i = 1; i < op->getNumOperands(); i++) {
219 auto currentOp = op->getOperand(i);
220 auto idxIt = operandVals.find(currentOp);
221 ensure(idxIt != operandVals.end(),
"improperly constructed operandVals map");
222 auto &idxVals = idxIt->second;
224 if (idxVals.isSingleValue() && idxVals.getSingleValue().isConstantIndex()) {
226 indices.push_back(idx);
229 auto arrayType = mlir::dyn_cast<ArrayType>(
array.getType());
230 auto lower = mlir::APInt::getZero(64);
231 mlir::APInt upper(64, arrayType.getDimSize(i - 1));
233 indices.push_back(idxRange);
237 auto [newVals, _] = currVals.extract(indices);
239 propagateIfChanged(after, after->
setValue(res, newVals));
245 mlir::ModuleOp m,
StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am
248 if (cdg.computeConstraints(solver, am).failed()) {
249 return mlir::failure();
261 std::set<std::set<ConstrainRef>> sortedSets;
262 for (
auto it = signalSets.begin(); it != signalSets.end(); it++) {
263 if (!it->isLeader()) {
267 std::set<ConstrainRef> sortedMembers;
268 for (
auto mit = signalSets.member_begin(it); mit != signalSets.member_end(); mit++) {
269 sortedMembers.insert(*mit);
274 if (sortedMembers.size() > 1) {
275 sortedSets.insert(sortedMembers);
279 for (
auto &[ref, constSet] : constantSets) {
280 if (constSet.empty()) {
283 std::set<ConstrainRef> sortedMembers(constSet.begin(), constSet.end());
284 sortedMembers.insert(ref);
285 sortedSets.insert(sortedMembers);
288 os <<
"ConstraintDependencyGraph { ";
290 for (
auto it = sortedSets.begin(); it != sortedSets.end();) {
292 for (
auto mit = it->begin(); mit != it->end();) {
295 if (mit != it->end()) {
301 if (it == sortedSets.end()) {
311mlir::LogicalResult ConstraintDependencyGraph::computeConstraints(
312 mlir::DataFlowSolver &solver, mlir::AnalysisManager &am
318 "malformed struct " + mlir::Twine(structDef.getName()) +
" must define a constrain function"
330 this->walkConstrainOp(solver, emitOp);
333 constrainFnOp.walk([
this, &solver](EmitContainmentOp emitOp) {
334 this->walkConstrainOp(solver, emitOp);
344 constrainFnOp.walk([
this, &solver, &am](CallOp fnCall)
mutable {
346 ensure(mlir::succeeded(res),
"could not resolve constrain call");
348 auto fn = res->get();
349 if (!fn.isStructConstrain()) {
353 auto calledStruct = fn.getOperation()->getParentOfType<StructDefOp>();
356 auto lattice = solver.lookupState<ConstrainRefLattice>(fnCall.getOperation());
357 ensure(lattice,
"could not find lattice for call operation");
360 for (
unsigned i = 0; i < fn.getNumArguments(); i++) {
361 auto prefix = ConstrainRef(fn.getArgument(i));
362 auto val = lattice->getOrDefault(fnCall.getOperand(i));
363 translations.push_back({prefix, val});
365 auto &childAnalysis =
366 am.getChildAnalysis<ConstraintDependencyGraphStructAnalysis>(calledStruct);
367 if (!childAnalysis.constructed()) {
369 mlir::succeeded(childAnalysis.runAnalysis(solver, am)),
370 "could not construct CDG for child struct"
373 auto translatedCDG = childAnalysis.getResult().translate(translations);
377 auto &tSets = translatedCDG.signalSets;
378 for (
auto lit = tSets.begin(); lit != tSets.end(); lit++) {
379 if (!lit->isLeader()) {
382 auto leader = lit->getData();
383 for (
auto mit = tSets.member_begin(lit); mit != tSets.member_end(); mit++) {
384 signalSets.unionSets(leader, *mit);
388 for (
auto &[ref, constSet] : translatedCDG.constantSets) {
389 constantSets[ref].insert(constSet.begin(), constSet.end());
393 return mlir::success();
396void ConstraintDependencyGraph::walkConstrainOp(
397 mlir::DataFlowSolver &solver, mlir::Operation *emitOp
399 std::vector<ConstrainRef> signalUsages, constUsages;
400 auto lattice = solver.lookupState<ConstrainRefLattice>(emitOp);
401 ensure(lattice,
"failed to get lattice for emit operation");
403 for (
auto operand : emitOp->getOperands()) {
404 auto latticeVal = lattice->getOrDefault(operand);
405 for (
auto &ref : latticeVal.foldToScalar()) {
406 if (ref.isConstant()) {
407 constUsages.push_back(ref);
409 signalUsages.push_back(ref);
415 if (!signalUsages.empty()) {
416 auto it = signalUsages.begin();
417 auto leader = signalSets.getOrInsertLeaderValue(*it);
418 for (it++; it != signalUsages.end(); it++) {
419 signalSets.unionSets(leader, *it);
423 for (
auto &sig : signalUsages) {
424 constantSets[sig].insert(constUsages.begin(), constUsages.end());
432 ) -> mlir::FailureOr<std::vector<ConstrainRef>> {
433 std::vector<ConstrainRef> refs;
434 for (
auto &[prefix, vals] : translation) {
435 if (!elem.isValidPrefix(prefix)) {
439 if (vals.isArray()) {
441 auto suffix = elem.getSuffix(prefix);
443 mlir::succeeded(suffix),
"failure is nonsensical, we already checked for valid prefix"
446 auto [resolvedVals, _] = vals.extract(suffix.value());
447 auto folded = resolvedVals.foldToScalar();
448 refs.insert(refs.end(), folded.begin(), folded.end());
450 for (
auto &replacement : vals.getScalarValue()) {
451 auto translated = elem.translate(prefix, replacement);
452 if (mlir::succeeded(translated)) {
453 refs.push_back(translated.value());
459 return mlir::failure();
464 for (
auto leaderIt = signalSets.begin(); leaderIt != signalSets.end(); leaderIt++) {
465 if (!leaderIt->isLeader()) {
469 std::vector<ConstrainRef> translatedSignals, translatedConsts;
470 for (
auto mit = signalSets.member_begin(leaderIt); mit != signalSets.member_end(); mit++) {
472 if (mlir::failed(member)) {
475 for (
auto &ref : *member) {
476 if (ref.isConstant()) {
477 translatedConsts.push_back(ref);
479 translatedSignals.push_back(ref);
483 if (
auto it = constantSets.find(*mit); it != constantSets.end()) {
484 auto &origConstSet = it->second;
485 translatedConsts.insert(translatedConsts.end(), origConstSet.begin(), origConstSet.end());
489 if (translatedSignals.empty()) {
494 auto it = translatedSignals.begin();
496 res.signalSets.insert(leader);
497 for (it++; it != translatedSignals.end(); it++) {
498 res.signalSets.insert(*it);
499 res.signalSets.unionSets(leader, *it);
503 for (
auto &ref : translatedSignals) {
504 res.constantSets[ref].insert(translatedConsts.begin(), translatedConsts.end());
512 auto currRef = mlir::FailureOr<ConstrainRef>(ref);
513 while (mlir::succeeded(currRef)) {
515 for (
auto it = signalSets.findLeader(*currRef); it != signalSets.member_end(); it++) {
516 if (currRef.value() != *it) {
521 auto constIt = constantSets.find(*currRef);
522 if (constIt != constantSets.end()) {
523 res.insert(constIt->second.begin(), constIt->second.end());
526 currRef = currRef->getParentPrefix();
534 mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager
538 if (mlir::failed(result)) {
539 return mlir::failure();
542 return mlir::success();
This file implements an (LLZK-tailored) dense data-flow analysis using the MLIR data-flow analysis framework.
void visitCallControlFlowTransfer(mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const ConstrainRefLattice &before, ConstrainRefLattice *after) override
Hook for customizing the behavior of lattice propagation along the call control flow edges.
mlir::ChangeResult fallbackOpUpdate(mlir::Operation *op, const ConstrainRefLattice::ValueMap &operandVals, const ConstrainRefLattice &before, ConstrainRefLattice *after)
void arraySubdivisionOpUpdate(mlir::Operation *op, const ConstrainRefLattice::ValueMap &operandVals, const ConstrainRefLattice &before, ConstrainRefLattice *after)
void setToEntryState(ConstrainRefLattice *lattice) override
Set the dense lattice at the control-flow entry point and propagate an update if it changed.
void visitOperation(mlir::Operation *op, const ConstrainRefLattice &before, ConstrainRefLattice *after) override
Propagate constrain reference lattice values from operands to results.
Defines an index into an LLZK object.
A value at a given point of the ConstrainRefLattice.
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > translate(const TranslationMap &translation) const
For the refs contained in this value, translate them using the given translation map and return the translated value together with whether it changed.
A lattice for use in dense analysis.
mlir::ChangeResult join(const AbstractDenseLattice &rhs) override
Least upper bound (lattice join).
ConstrainRefLatticeValue getOrDefault(mlir::Value v) const
ConstrainRefLatticeValue getReturnValue(unsigned i) const
mlir::DenseMap< mlir::Value, ConstrainRefLatticeValue > ValueMap
mlir::ChangeResult setValue(mlir::Value v, const ConstrainRefLatticeValue &rhs)
Defines a reference to an LLZK object within a constrain function call.
mlir::LogicalResult runAnalysis(mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager, NoContext &_) override
Perform the analysis and construct the Result output.
static mlir::FailureOr< ConstraintDependencyGraph > compute(mlir::ModuleOp mod, component::StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am)
Compute a ConstraintDependencyGraph (CDG) for the given struct.
void print(mlir::raw_ostream &os) const
Print the CDG to the specified output stream.
ConstraintDependencyGraph translate(ConstrainRefRemappings translation) const
Translate the ConstrainRefs in this CDG to those of a different context.
ConstrainRefSet getConstrainingValues(const ConstrainRef &ref) const
Get the values that are connected to the given ref via emitted constraints.
ConstraintDependencyGraph(const ConstraintDependencyGraph &other)
void dump() const
Dumps the CDG to stderr.
component::StructDefOp getStruct() const
mlir::ModuleOp getModule() const
void setResult(ConstraintDependencyGraph &&r)
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs)
Join a lattice with another and propagate an update if it changed.
ConstrainRefLattice * getLattice(mlir::ProgramPoint point) override
mlir::dataflow::AbstractDenseLattice AbstractDenseLattice
mlir::dataflow::CallControlFlowAction CallControlFlowAction
std::vector< std::pair< ConstrainRef, ConstrainRefLatticeValue > > ConstrainRefRemappings
void ensure(bool condition, llvm::Twine errMsg)
mlir::FailureOr< SymbolLookupResult< T > > resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers.