20#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
21#include <mlir/IR/Value.h>
23#include <llvm/Support/Debug.h>
26#include <unordered_set>
28#define DEBUG_TYPE "llzk-cdg"
35using namespace component;
36using namespace constrain;
37using namespace function;
46 llvm::dbgs() <<
"ConstrainRefAnalysis::visitCallControlFlowTransfer: " << call <<
'\n'
49 ensure(succeeded(fnOpRes),
"could not resolve called function");
52 llvm::dbgs().indent(4) <<
"parent op is ";
53 if (
auto s = call->getParentOfType<
StructDefOp>()) {
54 llvm::dbgs() << s.getName();
55 }
else if (
auto p = call->getParentOfType<
FuncDefOp>()) {
56 llvm::dbgs() << p.getName();
58 llvm::dbgs() <<
"<UNKNOWN PARENT TYPE>";
66 if (action == dataflow::CallControlFlowAction::EnterCallee) {
79 else if (action == dataflow::CallControlFlowAction::ExitCallee) {
83 if (
auto *prev = call->getPrevNode()) {
88 ensure(beforeCall,
"could not get prior lattice");
91 std::unordered_map<ConstrainRef, ConstrainRefLatticeValue, ConstrainRef::Hash> translation;
93 ensure(mlir::succeeded(funcOpRes),
"could not lookup called function");
94 auto funcOp = funcOpRes->get();
96 auto callOp = mlir::dyn_cast<CallOp>(call.getOperation());
97 ensure(callOp,
"call is not a llzk::CallOp");
99 for (
unsigned i = 0; i < funcOp.getNumArguments(); i++) {
101 auto val = beforeCall->
getOrDefault(callOp.getOperand(i));
102 translation[key] = val;
109 mlir::ChangeResult updated = after->
join(*beforeCall);
110 for (
unsigned i = 0; i < callOp.getNumResults(); i++) {
112 auto [translatedVal, _] = retVal.
translate(translation);
113 updated |= after->
setValue(callOp->getResult(i), translatedVal);
115 for (
const auto &[val, refVal] : before.
getMap()) {
116 auto [translatedVal, _] = refVal.translate(translation);
117 updated |= after->
setValue(val, translatedVal);
119 propagateIfChanged(after, updated);
125 else if (action == mlir::dataflow::CallControlFlowAction::ExternalCallee) {
138 LLVM_DEBUG(llvm::dbgs() <<
"ConstrainRefAnalysis::visitOperation: " << *op <<
'\n');
141 for (mlir::OpOperand &operand : op->getOpOperands()) {
142 operandVals[operand.get()] = before.
getOrDefault(operand.get());
150 ChangeResult res = after->
setValues(operandVals);
153 if (
auto fieldRefOp = mlir::dyn_cast<FieldRefOpInterface>(op)) {
155 auto fieldOpRes = fieldRefOp.getFieldDefOp(tables);
156 ensure(mlir::succeeded(fieldOpRes),
"could not find field read");
159 if (fieldRefOp.isRead()) {
160 fieldRefRes = fieldRefOp.getVal();
162 fieldRefRes = fieldRefOp;
165 const auto &ops = operandVals.at(fieldRefOp.getComponent());
166 auto [fieldVals, _] = ops.referenceField(fieldOpRes.value());
168 res |= after->
setValue(fieldRefRes, fieldVals);
169 }
else if (
auto arrayAccessOp = mlir::dyn_cast<ArrayAccessOpInterface>(op)) {
172 }
else if (
auto createArray = mlir::dyn_cast<CreateArrayOp>(op)) {
179 const auto &elements = createArray.getElements();
180 if (!elements.empty()) {
181 for (
unsigned i = 0; i < elements.size(); i++) {
182 auto currentOp = elements[i];
183 auto &opVals = operandVals[currentOp];
184 (void)newArrayVal.getElemFlatIdx(i).setValue(opVals);
188 auto createArrayRes = createArray.getResult();
190 res |= after->
setValue(createArrayRes, newArrayVal);
191 }
else if (
auto structNewOp = mlir::dyn_cast<CreateStructOp>(op)) {
192 auto newOpRes = structNewOp.getResult();
194 res |= after->
setValue(newOpRes, newStructValue);
202 propagateIfChanged(after, res);
210 auto updated = mlir::ChangeResult::NoChange;
211 for (
auto res : op->getResults()) {
214 for (
auto &[_, opVal] : operandVals) {
215 (void)cur.update(opVal);
217 updated |= after->
setValue(res, cur);
231 if (mlir::isa<ReadArrayOp, ExtractArrayOp>(arrayAccessOp)) {
232 res = arrayAccessOp->getResult(0);
238 auto it = operandVals.find(
array);
239 ensure(it != operandVals.end(),
"improperly constructed operandVals map");
240 auto currVals = it->second;
242 std::vector<ConstrainRefIndex> indices;
244 for (
unsigned i = 0; i < arrayAccessOp.
getIndices().size(); ++i) {
245 auto idxOperand = arrayAccessOp.
getIndices()[i];
246 auto idxIt = operandVals.find(idxOperand);
247 ensure(idxIt != operandVals.end(),
"improperly constructed operandVals map");
248 auto &idxVals = idxIt->second;
254 if (idxVals.isSingleValue() && idxVals.getSingleValue().isConstant()) {
256 indices.push_back(idx);
259 auto arrayType = mlir::dyn_cast<ArrayType>(
array.getType());
260 auto lower = mlir::APInt::getZero(64);
261 mlir::APInt upper(64, arrayType.getDimSize(i));
263 indices.push_back(idxRange);
267 auto [newVals, _] = currVals.extract(indices);
269 if (mlir::isa<ReadArrayOp, WriteArrayOp>(arrayAccessOp)) {
270 ensure(newVals.isScalar(),
"array read/write must produce a scalar value");
278 propagateIfChanged(after, after->
setValue(res, newVals));
284 mlir::ModuleOp m,
StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am,
285 bool runIntraprocedural
288 if (cdg.computeConstraints(solver, am).failed()) {
289 return mlir::failure();
301 std::set<std::set<ConstrainRef>> sortedSets;
302 for (
auto it = signalSets.begin(); it != signalSets.end(); it++) {
303 if (!it->isLeader()) {
307 std::set<ConstrainRef> sortedMembers;
308 for (
auto mit = signalSets.member_begin(it); mit != signalSets.member_end(); mit++) {
309 sortedMembers.insert(*mit);
314 if (sortedMembers.size() > 1) {
315 sortedSets.insert(sortedMembers);
319 for (
auto &[ref, constSet] : constantSets) {
320 if (constSet.empty()) {
323 std::set<ConstrainRef> sortedMembers(constSet.begin(), constSet.end());
324 sortedMembers.insert(ref);
325 sortedSets.insert(sortedMembers);
328 os <<
"ConstraintDependencyGraph { ";
330 for (
auto it = sortedSets.begin(); it != sortedSets.end();) {
332 for (
auto mit = it->begin(); mit != it->end();) {
335 if (mit != it->end()) {
341 if (it == sortedSets.end()) {
351mlir::LogicalResult ConstraintDependencyGraph::computeConstraints(
352 mlir::DataFlowSolver &solver, mlir::AnalysisManager &am
358 "malformed struct " + mlir::Twine(structDef.getName()) +
" must define a constrain function"
369 constrainFnOp.walk([
this, &solver](Operation *op) {
374 for (
auto &[ref, vals] : refLattice->getRef2Val()) {
375 ref2Val[ref].insert(vals.begin(), vals.end());
378 if (isa<EmitEqualityOp, EmitContainmentOp>(op)) {
379 this->walkConstrainOp(solver, op);
390 auto fnCallWalker = [
this, &solver, &am](CallOp fnCall)
mutable {
392 ensure(mlir::succeeded(res),
"could not resolve constrain call");
394 auto fn = res->get();
395 if (!fn.isStructConstrain()) {
399 auto calledStruct = fn.getOperation()->getParentOfType<StructDefOp>();
402 auto lattice = solver.lookupState<ConstrainRefLattice>(fnCall.getOperation());
403 ensure(lattice,
"could not find lattice for call operation");
406 for (
unsigned i = 0; i < fn.getNumArguments(); i++) {
407 auto prefix = ConstrainRef(fn.getArgument(i));
408 auto val = lattice->getOrDefault(fnCall.getOperand(i));
409 translations.push_back({prefix, val});
411 auto &childAnalysis =
412 am.getChildAnalysis<ConstraintDependencyGraphStructAnalysis>(calledStruct);
413 if (!childAnalysis.constructed()) {
415 mlir::succeeded(childAnalysis.runAnalysis(solver, am,
false)),
416 "could not construct CDG for child struct"
419 auto translatedCDG = childAnalysis.getResult().translate(translations);
423 auto &tSets = translatedCDG.signalSets;
424 for (
auto lit = tSets.begin(); lit != tSets.end(); lit++) {
425 if (!lit->isLeader()) {
428 auto leader = lit->getData();
429 for (
auto mit = tSets.member_begin(lit); mit != tSets.member_end(); mit++) {
430 signalSets.unionSets(leader, *mit);
434 for (
auto &[ref, constSet] : translatedCDG.constantSets) {
435 constantSets[ref].insert(constSet.begin(), constSet.end());
438 if (!runIntraprocedural) {
439 constrainFnOp.walk(fnCallWalker);
442 return mlir::success();
445void ConstraintDependencyGraph::walkConstrainOp(
446 mlir::DataFlowSolver &solver, mlir::Operation *emitOp
448 std::vector<ConstrainRef> signalUsages, constUsages;
450 const ConstrainRefLattice *refLattice = solver.lookupState<ConstrainRefLattice>(emitOp);
451 ensure(refLattice,
"missing lattice for constrain op");
453 for (
auto operand : emitOp->getOperands()) {
454 auto latticeVal = refLattice->getOrDefault(operand);
455 for (
auto &ref : latticeVal.foldToScalar()) {
456 if (ref.isConstant()) {
457 constUsages.push_back(ref);
459 signalUsages.push_back(ref);
465 if (!signalUsages.empty()) {
466 auto it = signalUsages.begin();
467 auto leader = signalSets.getOrInsertLeaderValue(*it);
468 for (it++; it != signalUsages.end(); it++) {
469 signalSets.unionSets(leader, *it);
473 for (
auto &sig : signalUsages) {
474 constantSets[sig].insert(constUsages.begin(), constUsages.end());
482 ) -> mlir::FailureOr<std::vector<ConstrainRef>> {
483 std::vector<ConstrainRef> refs;
484 for (
auto &[prefix, vals] : translation) {
485 if (!elem.isValidPrefix(prefix)) {
489 if (vals.isArray()) {
491 auto suffix = elem.getSuffix(prefix);
493 mlir::succeeded(suffix),
"failure is nonsensical, we already checked for valid prefix"
496 auto [resolvedVals, _] = vals.extract(suffix.value());
497 auto folded = resolvedVals.foldToScalar();
498 refs.insert(refs.end(), folded.begin(), folded.end());
500 for (
auto &replacement : vals.getScalarValue()) {
501 auto translated = elem.translate(prefix, replacement);
502 if (mlir::succeeded(translated)) {
503 refs.push_back(translated.value());
509 return mlir::failure();
514 for (
auto leaderIt = signalSets.begin(); leaderIt != signalSets.end(); leaderIt++) {
515 if (!leaderIt->isLeader()) {
519 std::vector<ConstrainRef> translatedSignals, translatedConsts;
520 for (
auto mit = signalSets.member_begin(leaderIt); mit != signalSets.member_end(); mit++) {
522 if (mlir::failed(member)) {
525 for (
auto &ref : *member) {
526 if (ref.isConstant()) {
527 translatedConsts.push_back(ref);
529 translatedSignals.push_back(ref);
533 if (
auto it = constantSets.find(*mit); it != constantSets.end()) {
534 auto &origConstSet = it->second;
535 translatedConsts.insert(translatedConsts.end(), origConstSet.begin(), origConstSet.end());
539 if (translatedSignals.empty()) {
544 auto it = translatedSignals.begin();
546 res.signalSets.insert(leader);
547 for (it++; it != translatedSignals.end(); it++) {
548 res.signalSets.insert(*it);
549 res.signalSets.unionSets(leader, *it);
553 for (
auto &ref : translatedSignals) {
554 res.constantSets[ref].insert(translatedConsts.begin(), translatedConsts.end());
559 for (
auto &[ref, vals] : ref2Val) {
561 if (succeeded(translationRes)) {
562 for (
const auto &translatedRef : *translationRes) {
563 res.ref2Val[translatedRef].insert(vals.begin(), vals.end());
573 auto currRef = mlir::FailureOr<ConstrainRef>(ref);
574 while (mlir::succeeded(currRef)) {
576 for (
auto it = signalSets.findLeader(*currRef); it != signalSets.member_end(); it++) {
577 if (currRef.value() != *it) {
582 auto constIt = constantSets.find(*currRef);
583 if (constIt != constantSets.end()) {
584 res.insert(constIt->second.begin(), constIt->second.end());
587 currRef = currRef->getParentPrefix();
595 mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager,
596 bool runIntraprocedural
601 if (mlir::failed(result)) {
602 return mlir::failure();
605 return mlir::success();
This file implements (LLZK-tailored) dense data-flow analysis using the data-flow analysis framework.
void visitCallControlFlowTransfer(mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const ConstrainRefLattice &before, ConstrainRefLattice *after) override
Hook for customizing the behavior of lattice propagation along the call control flow edges.
mlir::ChangeResult fallbackOpUpdate(mlir::Operation *op, const ConstrainRefLattice::ValueMap &operandVals, const ConstrainRefLattice &before, ConstrainRefLattice *after)
void setToEntryState(ConstrainRefLattice *lattice) override
Set the dense lattice at control flow entry point and propagate an update if it changed.
void arraySubdivisionOpUpdate(array::ArrayAccessOpInterface op, const ConstrainRefLattice::ValueMap &operandVals, const ConstrainRefLattice &before, ConstrainRefLattice *after)
void visitOperation(mlir::Operation *op, const ConstrainRefLattice &before, ConstrainRefLattice *after) override
Propagate constrain reference lattice values from operands to results.
Defines an index into an LLZK object.
A value at a given point of the ConstrainRefLattice.
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > translate(const TranslationMap &translation) const
For the refs contained in this value, translate them given the translation map and return the transformed value, paired with a ChangeResult.
A lattice for use in dense analysis.
mlir::DenseMap< ValueTy, ConstrainRefLatticeValue > ValueMap
mlir::ChangeResult join(const AbstractDenseLattice &rhs) override
Least upper bound (lattice join).
ConstrainRefLatticeValue getReturnValue(unsigned i) const
mlir::ChangeResult setValues(const ValueMap &rhs)
mlir::ChangeResult setValue(ValueTy v, const ConstrainRefLatticeValue &rhs)
ConstrainRefLatticeValue getOrDefault(ValueTy v) const
const ValueMap & getMap() const
llvm::PointerUnion< mlir::Value, mlir::Operation * > ValueTy
Defines a reference to a llzk object within a constrain function call.
mlir::LogicalResult runAnalysis(mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager, CDGAnalysisContext &ctx) override
Perform the analysis and construct the Result output.
static mlir::FailureOr< ConstraintDependencyGraph > compute(mlir::ModuleOp mod, component::StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, bool runIntraprocedural)
Compute a ConstraintDependencyGraph (CDG)
void print(mlir::raw_ostream &os) const
Print the CDG to the specified output stream.
ConstraintDependencyGraph translate(ConstrainRefRemappings translation) const
Translate the ConstrainRefs in this CDG to that of a different context.
ConstrainRefSet getConstrainingValues(const ConstrainRef &ref) const
Get the values that are connected to the given ref via emitted constraints.
ConstraintDependencyGraph(const ConstraintDependencyGraph &other)
void dump() const
Dumps the CDG to stderr.
component::StructDefOp getStruct() const
mlir::ModuleOp getModule() const
void setResult(ConstraintDependencyGraph &&r)
::mlir::Operation::operand_range getIndices()
Gets the operand range containing the index for each dimension.
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
Gets the SSA Value for the referenced array.
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs)
Join a lattice with another and propagate an update if it changed.
ConstrainRefLattice * getLattice(mlir::ProgramPoint point) override
mlir::dataflow::CallControlFlowAction CallControlFlowAction
std::vector< std::pair< ConstrainRef, ConstrainRefLatticeValue > > ConstrainRefRemappings
void ensure(bool condition, llvm::Twine errMsg)
mlir::FailureOr< SymbolLookupResult< T > > resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers.