20#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
21#include <mlir/IR/Value.h>
23#include <llvm/Support/Debug.h>
26#include <unordered_set>
28#define DEBUG_TYPE "llzk-cdg"
35using namespace component;
36using namespace constrain;
37using namespace function;
45 LLVM_DEBUG(llvm::dbgs() <<
"SourceRefAnalysis::visitCallControlFlowTransfer: " << call <<
'\n');
47 ensure(succeeded(fnOpRes),
"could not resolve called function");
50 llvm::dbgs().indent(4) <<
"parent op is ";
51 if (
auto s = call->getParentOfType<
StructDefOp>()) {
52 llvm::dbgs() << s.getName();
53 }
else if (
auto p = call->getParentOfType<
FuncDefOp>()) {
54 llvm::dbgs() << p.getName();
56 llvm::dbgs() <<
"<UNKNOWN PARENT TYPE>";
64 if (action == dataflow::CallControlFlowAction::EnterCallee) {
77 else if (action == dataflow::CallControlFlowAction::ExitCallee) {
81 ensure(beforeCall,
"could not get prior lattice");
84 std::unordered_map<SourceRef, SourceRefLatticeValue, SourceRef::Hash> translation;
86 ensure(mlir::succeeded(funcOpRes),
"could not lookup called function");
87 auto funcOp = funcOpRes->get();
89 auto callOp = llvm::dyn_cast<CallOp>(call.getOperation());
90 ensure(callOp,
"call is not a CallOp");
92 for (
unsigned i = 0; i < funcOp.getNumArguments(); i++) {
97 Value operand = callOp.getOperand(i);
98 if (Operation *defOp = operand.getDefiningOp()) {
99 operandLattice =
getLattice(getProgramPointAfter(defOp));
102 translation[key] = operandLattice->
getOrDefault(operand);
106 mlir::ChangeResult updated = mlir::ChangeResult::NoChange;
107 for (
unsigned i = 0; i < callOp.getNumResults(); i++) {
109 auto [translatedVal, _] = retVal.
translate(translation);
110 updated |= after->
setValue(callOp->getResult(i), translatedVal);
112 propagateIfChanged(after, updated);
118 else if (action == mlir::dataflow::CallControlFlowAction::ExternalCallee) {
131 LLVM_DEBUG(llvm::dbgs() <<
"SourceRefAnalysis::visitOperation: " << *op <<
'\n');
134 for (OpOperand &operand : op->getOpOperands()) {
137 Value operandVal = operand.get();
138 if (Operation *defOp = operandVal.getDefiningOp()) {
139 prior =
getLattice(getProgramPointAfter(defOp));
142 operandVals[operandVal] = prior->
getOrDefault(operandVal);
147 ChangeResult res = after->
setValues(operandVals);
150 if (
auto fieldRefOp = llvm::dyn_cast<FieldRefOpInterface>(op)) {
152 auto fieldOpRes = fieldRefOp.getFieldDefOp(tables);
153 ensure(mlir::succeeded(fieldOpRes),
"could not find field read");
156 if (fieldRefOp.isRead()) {
157 fieldRefRes = fieldRefOp.getVal();
159 fieldRefRes = fieldRefOp;
162 const auto &ops = operandVals.at(fieldRefOp.getComponent());
163 auto [fieldVals, _] = ops.referenceField(fieldOpRes.value());
165 res |= after->
setValue(fieldRefRes, fieldVals);
166 }
else if (
auto arrayAccessOp = llvm::dyn_cast<ArrayAccessOpInterface>(op)) {
169 }
else if (
auto createArray = llvm::dyn_cast<CreateArrayOp>(op)) {
175 const auto &elements = createArray.getElements();
176 if (!elements.empty()) {
177 for (
unsigned i = 0; i < elements.size(); i++) {
178 auto currentOp = elements[i];
179 auto &opVals = operandVals[currentOp];
184 auto createArrayRes = createArray.getResult();
186 res |= after->
setValue(createArrayRes, newArrayVal);
187 }
else if (
auto structNewOp = llvm::dyn_cast<CreateStructOp>(op)) {
188 auto newOpRes = structNewOp.getResult();
190 res |= after->
setValue(newOpRes, newStructValue);
198 propagateIfChanged(after, res);
199 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"lattice is of size " << after->
size() <<
'\n');
208 auto updated = mlir::ChangeResult::NoChange;
209 for (
auto res : op->getResults()) {
212 for (
auto &[_, opVal] : operandVals) {
213 (void)cur.update(opVal);
215 updated |= after->
setValue(res, cur);
229 if (llvm::isa<ReadArrayOp, ExtractArrayOp>(arrayAccessOp)) {
230 res = arrayAccessOp->getResult(0);
236 auto it = operandVals.find(
array);
237 ensure(it != operandVals.end(),
"improperly constructed operandVals map");
238 auto currVals = it->second;
240 std::vector<SourceRefIndex> indices;
242 for (
unsigned i = 0; i < arrayAccessOp.
getIndices().size(); ++i) {
243 auto idxOperand = arrayAccessOp.
getIndices()[i];
244 auto idxIt = operandVals.find(idxOperand);
245 ensure(idxIt != operandVals.end(),
"improperly constructed operandVals map");
246 auto &idxVals = idxIt->second;
252 if (idxVals.isSingleValue() && idxVals.getSingleValue().isConstant()) {
254 indices.push_back(idx);
257 auto arrayType = llvm::dyn_cast<ArrayType>(
array.getType());
258 auto lower = mlir::APInt::getZero(64);
259 mlir::APInt upper(64, arrayType.getDimSize(i));
261 indices.push_back(idxRange);
265 auto [newVals, _] = currVals.extract(indices);
267 if (llvm::isa<ReadArrayOp, WriteArrayOp>(arrayAccessOp)) {
268 ensure(newVals.isScalar(),
"array read/write must produce a scalar value");
276 propagateIfChanged(after, after->
setValue(res, newVals));
282 mlir::ModuleOp m,
StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am,
286 if (cdg.computeConstraints(solver, am).failed()) {
287 return mlir::failure();
299 std::set<std::set<SourceRef>> sortedSets;
300 for (
auto it = signalSets.begin(); it != signalSets.end(); it++) {
301 if (!it->isLeader()) {
305 std::set<SourceRef> sortedMembers;
306 for (
auto mit = signalSets.member_begin(it); mit != signalSets.member_end(); mit++) {
307 sortedMembers.insert(*mit);
312 if (sortedMembers.size() > 1) {
313 sortedSets.insert(sortedMembers);
317 for (
auto &[ref, constSet] : constantSets) {
318 if (constSet.empty()) {
321 std::set<SourceRef> sortedMembers(constSet.begin(), constSet.end());
322 sortedMembers.insert(ref);
323 sortedSets.insert(sortedMembers);
326 os <<
"ConstraintDependencyGraph { ";
328 for (
auto it = sortedSets.begin(); it != sortedSets.end();) {
330 for (
auto mit = it->begin(); mit != it->end();) {
333 if (mit != it->end()) {
339 if (it == sortedSets.end()) {
349mlir::LogicalResult ConstraintDependencyGraph::computeConstraints(
350 mlir::DataFlowSolver &solver, mlir::AnalysisManager &am
356 "malformed struct " + mlir::Twine(structDef.getName()) +
" must define a constrain function"
367 constrainFnOp.walk([
this, &solver](Operation *op) {
368 ProgramPoint *pp = solver.getProgramPointAfter(op);
373 for (
auto &[ref, vals] : refLattice->getRef2Val()) {
374 ref2Val[ref].insert(vals.begin(), vals.end());
377 if (isa<EmitEqualityOp, EmitContainmentOp>(op)) {
378 this->walkConstrainOp(solver, op);
389 auto fnCallWalker = [
this, &solver, &am](CallOp fnCall)
mutable {
391 ensure(mlir::succeeded(res),
"could not resolve constrain call");
393 auto fn = res->get();
394 if (!fn.isStructConstrain()) {
398 auto calledStruct = fn.getOperation()->getParentOfType<StructDefOp>();
401 ProgramPoint *pp = solver.getProgramPointAfter(fnCall.getOperation());
402 auto *afterCallLattice = solver.lookupState<SourceRefLattice>(pp);
403 ensure(afterCallLattice,
"could not find lattice for call operation");
406 for (
unsigned i = 0; i < fn.getNumArguments(); i++) {
407 SourceRef prefix(fn.getArgument(i));
410 const SourceRefLattice *operandLattice = afterCallLattice;
411 Value operand = fnCall.getOperand(i);
412 if (Operation *defOp = operand.getDefiningOp()) {
413 ProgramPoint *defPoint = solver.getProgramPointAfter(defOp);
414 operandLattice = solver.lookupState<SourceRefLattice>(defPoint);
416 ensure(operandLattice,
"could not find lattice for call operand");
418 SourceRefLatticeValue val = operandLattice->getOrDefault(operand);
419 translations.push_back({prefix, val});
421 auto &childAnalysis =
422 am.getChildAnalysis<ConstraintDependencyGraphStructAnalysis>(calledStruct);
423 if (!childAnalysis.constructed(ctx)) {
425 mlir::succeeded(childAnalysis.runAnalysis(solver, am, {.runIntraprocedural = false})),
426 "could not construct CDG for child struct"
429 auto translatedCDG = childAnalysis.getResult(ctx).translate(translations);
431 const auto &translatedRef2Val = translatedCDG.getRef2Val();
432 ref2Val.insert(translatedRef2Val.begin(), translatedRef2Val.end());
436 auto &tSets = translatedCDG.signalSets;
437 for (
auto lit = tSets.begin(); lit != tSets.end(); lit++) {
438 if (!lit->isLeader()) {
441 auto leader = lit->getData();
442 for (
auto mit = tSets.member_begin(lit); mit != tSets.member_end(); mit++) {
443 signalSets.unionSets(leader, *mit);
447 for (
auto &[ref, constSet] : translatedCDG.constantSets) {
448 constantSets[ref].insert(constSet.begin(), constSet.end());
451 if (!ctx.runIntraproceduralAnalysis()) {
452 constrainFnOp.walk(fnCallWalker);
455 return mlir::success();
458void ConstraintDependencyGraph::walkConstrainOp(
459 mlir::DataFlowSolver &solver, mlir::Operation *emitOp
461 std::vector<SourceRef> signalUsages, constUsages;
463 ProgramPoint *pp = solver.getProgramPointAfter(emitOp);
464 const SourceRefLattice *refLattice = solver.lookupState<SourceRefLattice>(pp);
465 ensure(refLattice,
"missing lattice for constrain op");
467 for (
auto operand : emitOp->getOperands()) {
468 auto latticeVal = refLattice->getOrDefault(operand);
469 for (
auto &ref : latticeVal.foldToScalar()) {
470 if (ref.isConstant()) {
471 constUsages.push_back(ref);
473 signalUsages.push_back(ref);
479 if (!signalUsages.empty()) {
480 auto it = signalUsages.begin();
481 auto leader = signalSets.getOrInsertLeaderValue(*it);
482 for (it++; it != signalUsages.end(); it++) {
483 signalSets.unionSets(leader, *it);
487 for (
auto &sig : signalUsages) {
488 constantSets[sig].insert(constUsages.begin(), constUsages.end());
496 [&translation](
const SourceRef &elem) -> mlir::FailureOr<std::vector<SourceRef>> {
497 std::vector<SourceRef> refs;
498 for (
auto &[prefix, vals] : translation) {
499 if (!elem.isValidPrefix(prefix)) {
503 if (vals.isArray()) {
505 auto suffix = elem.getSuffix(prefix);
507 mlir::succeeded(suffix),
"failure is nonsensical, we already checked for valid prefix"
510 auto [resolvedVals, _] = vals.extract(suffix.value());
511 auto folded = resolvedVals.foldToScalar();
512 refs.insert(refs.end(), folded.begin(), folded.end());
514 for (
auto &replacement : vals.getScalarValue()) {
515 auto translated = elem.translate(prefix, replacement);
516 if (mlir::succeeded(translated)) {
517 refs.push_back(translated.value());
523 return mlir::failure();
528 for (
auto leaderIt = signalSets.begin(); leaderIt != signalSets.end(); leaderIt++) {
529 if (!leaderIt->isLeader()) {
533 std::vector<SourceRef> translatedSignals, translatedConsts;
534 for (
auto mit = signalSets.member_begin(leaderIt); mit != signalSets.member_end(); mit++) {
536 if (mlir::failed(member)) {
539 for (
auto &ref : *member) {
540 if (ref.isConstant()) {
541 translatedConsts.push_back(ref);
543 translatedSignals.push_back(ref);
547 if (
auto it = constantSets.find(*mit); it != constantSets.end()) {
548 auto &origConstSet = it->second;
549 translatedConsts.insert(translatedConsts.end(), origConstSet.begin(), origConstSet.end());
553 if (translatedSignals.empty()) {
558 auto it = translatedSignals.begin();
560 res.signalSets.insert(leader);
561 for (it++; it != translatedSignals.end(); it++) {
562 res.signalSets.insert(*it);
563 res.signalSets.unionSets(leader, *it);
567 for (
auto &ref : translatedSignals) {
568 res.constantSets[ref].insert(translatedConsts.begin(), translatedConsts.end());
573 for (
auto &[ref, vals] : ref2Val) {
575 if (succeeded(translationRes)) {
576 for (
const auto &translatedRef : *translationRes) {
577 res.ref2Val[translatedRef].insert(vals.begin(), vals.end());
587 auto currRef = mlir::FailureOr<SourceRef>(ref);
588 while (mlir::succeeded(currRef)) {
590 for (
auto it = signalSets.findLeader(*currRef); it != signalSets.member_end(); it++) {
591 if (currRef.value() != *it) {
596 auto constIt = constantSets.find(*currRef);
597 if (constIt != constantSets.end()) {
598 res.insert(constIt->second.begin(), constIt->second.end());
601 currRef = currRef->getParentPrefix();
609 mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager,
615 if (mlir::failed(result)) {
616 return mlir::failure();
619 return mlir::success();
This file implements (LLZK-tailored) dense data-flow analysis using the data-flow analysis framework.
mlir::LogicalResult runAnalysis(mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager, const CDGAnalysisContext &ctx) override
Construct a CDG, using the module's analysis manager to query ConstraintDependencyGraph objects for n...
A dependency graph of constraints enforced by an LLZK struct.
void print(mlir::raw_ostream &os) const
Print the CDG to the specified output stream.
static mlir::FailureOr< ConstraintDependencyGraph > compute(mlir::ModuleOp mod, component::StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, const CDGAnalysisContext &ctx)
Compute a ConstraintDependencyGraph (CDG)
ConstraintDependencyGraph(const ConstraintDependencyGraph &other)
SourceRefSet getConstrainingValues(const SourceRef &ref) const
Get the values that are connected to the given ref via emitted constraints.
void dump() const
Dumps the CDG to stderr.
ConstraintDependencyGraph translate(SourceRefRemappings translation) const
Translate the SourceRefs in this CDG to that of a different context.
mlir::ChangeResult fallbackOpUpdate(mlir::Operation *op, const SourceRefLattice::ValueMap &operandVals, const SourceRefLattice &before, SourceRefLattice *after)
mlir::LogicalResult visitOperation(mlir::Operation *op, const SourceRefLattice &before, SourceRefLattice *after) override
Propagate SourceRef lattice values from operands to results.
void setToEntryState(SourceRefLattice *lattice) override
Set the dense lattice at control flow entry point and propagate an update if it changed.
void arraySubdivisionOpUpdate(array::ArrayAccessOpInterface op, const SourceRefLattice::ValueMap &operandVals, const SourceRefLattice &before, SourceRefLattice *after)
void visitCallControlFlowTransfer(mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const SourceRefLattice &before, SourceRefLattice *after) override
Hook for customizing the behavior of lattice propagation along the call control flow edges.
Defines an index into an LLZK object.
A value at a given point of the SourceRefLattice.
std::pair< SourceRefLatticeValue, mlir::ChangeResult > translate(const TranslationMap &translation) const
For the refs contained in this value, translate them given the translation map and return the transfo...
A lattice for use in dense analysis.
mlir::DenseMap< ValueTy, SourceRefLatticeValue > ValueMap
mlir::ChangeResult setValues(const ValueMap &rhs)
SourceRefLatticeValue getOrDefault(ValueTy v) const
mlir::ChangeResult setValue(ValueTy v, const SourceRefLatticeValue &rhs)
llvm::PointerUnion< mlir::Value, mlir::Operation * > ValueTy
SourceRefLatticeValue getReturnValue(unsigned i) const
A reference to a "source", which is the base value from which other SSA values are derived.
component::StructDefOp getStruct() const
void setResult(const CDGAnalysisContext &ctx, ConstraintDependencyGraph &&r)
mlir::ModuleOp getModule() const
::mlir::Operation::operand_range getIndices()
Gets the operand range containing the index for each dimension.
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
Gets the SSA Value for the referenced array.
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs)
Join a lattice with another and propagate an update if it changed.
mlir::ChangeResult setValue(const AbstractLatticeValue &rhs)
Sets this value to be equal to rhs.
const Derived & getElemFlatIdx(unsigned i) const
Directly index into the flattened array using a single index.
SourceRefLattice * getLattice(mlir::LatticeAnchor anchor) override
mlir::dataflow::CallControlFlowAction CallControlFlowAction
std::vector< std::pair< SourceRef, SourceRefLatticeValue > > SourceRefRemappings
void ensure(bool condition, const llvm::Twine &errMsg)
mlir::FailureOr< SymbolLookupResult< T > > resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers.
Parameters and shared objects to pass to child analyses.