25#include "r1cs/Dialect/IR/Attrs.h"
26#include "r1cs/Dialect/IR/Ops.h"
27#include "r1cs/Dialect/IR/Types.h"
29#include <pcl/Dialect/IR/Dialect.h>
30#include <pcl/Dialect/IR/Ops.h>
31#include <pcl/Dialect/IR/Types.h>
33#include <mlir/Dialect/Func/IR/FuncOps.h>
34#include <mlir/IR/BuiltinOps.h>
36#include <llvm/ADT/DenseMap.h>
37#include <llvm/ADT/DenseMapInfo.h>
38#include <llvm/ADT/SmallVector.h>
39#include <llvm/Support/Debug.h>
46#define GEN_PASS_DECL_PCLLOWERINGPASS
47#define GEN_PASS_DEF_PCLLOWERINGPASS
/// Looks up `v` in the LLZK-to-PCL value map `m`.
///
/// When `v` has no entry, a "missing operand mapping" error is emitted on
/// `onError` and that diagnostic is returned as the failure.
/// NOTE(review): the body of the found-branch is not visible in this excerpt;
/// presumably it returns the mapped value — confirm against the full file.
62static FailureOr<Value> lookup(Value v, llvm::DenseMap<Value, Value> &m, Operation *onError) {
63 if (
auto it = m.find(v); it != m.end()) {
66 return onError->emitError(
"missing operand mapping");
/// Records that the source value `from` corresponds to the lowered value `to`.
///
/// Uses `try_emplace`, so an entry that already exists for `from` is left
/// untouched: a second call with the same key does NOT replace the mapping.
69static void rememberResult(Value
from, Value to, llvm::DenseMap<Value, Value> &m) {
70 (void)m.try_emplace(
from, to);
/// Lowers a two-operand source op (`SrcBinOp`) to its PCL counterpart
/// (`DstBinOp`).
///
/// Both operands are resolved through `mapping` (errors are reported on the
/// source op), a `DstBinOp` is created at the source location, and the source
/// result is recorded as mapping to the new op's result.
/// NOTE(review): the return type and the operand-failure checks are on lines
/// not visible in this excerpt.
74template <
typename SrcBinOp,
typename DstBinOp>
76lowerBinaryLike(OpBuilder &b, SrcBinOp src, llvm::DenseMap<Value, Value> &mapping) {
77 auto loc = src.getLoc();
78 auto op = src.getOperation();
// Resolve the already-lowered operands; `op` is used only for diagnostics.
79 auto lhs = lookup(src.getLhs(), mapping, op);
83 auto rhs = lookup(src.getRhs(), mapping, op);
88 auto dst = b.create<DstBinOp>(loc, *lhs, *rhs);
89 rememberResult(src.getResult(), dst.getRes(), mapping);
/// Lowers a felt constant to a `pcl::ConstOp`.
///
/// The constant's value is wrapped in a `pcl::FeltAttr`, the PCL constant is
/// created at the source op's location, and the source result is recorded in
/// `mapping` as corresponding to the new constant's result.
94lowerConst(OpBuilder &b,
FeltConstantOp cst, llvm::DenseMap<Value, Value> &mapping) {
95 auto attr = pcl::FeltAttr::get(b.getContext(), cst.
getValue());
96 auto dst = b.create<pcl::ConstOp>(cst.getLoc(), attr);
97 rememberResult(cst.
getResult(), dst.getRes(), mapping);
/// Registers the dialects whose ops this pass creates (PCL and func) in
/// `registry`.
104 void getDependentDialects(DialectRegistry &registry)
const override {
105 registry.insert<pcl::PCLDialect, func::FuncDialect>();
/// Checks that a struct is eligible for PCL lowering.
///
/// A field whose type is not `FeltType` produces an error on that field
/// naming the offending type and the field.
/// NOTE(review): the loop that yields `field` is on a line not visible in this
/// excerpt; presumably it iterates the struct's field definitions.
109 LogicalResult validateStruct(StructDefOp structDef) {
111 auto fieldType = field.getType();
112 if (!llvm::isa<FeltType>(fieldType)) {
113 return field.emitError() <<
"Field must be felt type. Found " << fieldType
114 <<
" for field: " << field.getName();
/// Emits PCL ops asserting `lhsVal == rhsVal`, using a shorter form when one
/// side is a PCL bool and the other is a recognised constant.
///
/// Visible cases, checked in this order:
///   - bool vs. constant one  -> assert the bool directly
///   - bool vs. constant zero -> assert the `pcl::NotOp` of the bool
///   - otherwise              -> `pcl::CmpEqOp` followed by `pcl::AssertOp`
/// NOTE(review): the return type and the bodies of `isConstOne`/`isConstZero`
/// are on lines not visible here; their meaning is inferred from their names.
132 emitAssertEqOptimized(OpBuilder &b, Location loc, Value lhsVal, Value rhsVal) {
134 auto isBool = [](Value v) {
return llvm::isa<pcl::BoolType>(v.getType()); };
// Extracts the integer payload when `v` is defined by a pcl::ConstOp.
// dyn_cast_if_present tolerates a null defining op (e.g. block arguments).
136 auto getConstAPInt = [](Value v) -> std::optional<llvm::APInt> {
137 if (
auto c = llvm::dyn_cast_if_present<pcl::ConstOp>(v.getDefiningOp())) {
139 return c.getValue().getValue().getValue();
144 auto isConstOne = [&](Value v) {
145 if (
auto ap = getConstAPInt(v)) {
150 auto isConstZero = [&](Value v) {
151 if (
auto ap = getConstAPInt(v)) {
// General form: compare for equality, then assert the comparison result.
157 auto emitEqAssert = [&](Value l, Value r) {
158 auto eq = b.create<pcl::CmpEqOp>(loc, l, r);
159 b.create<pcl::AssertOp>(loc, eq.getRes());
162 auto emitAssertTrue = [&](Value pred) { b.create<pcl::AssertOp>(loc, pred); };
// Asserts that `pred` is false by asserting its logical negation.
164 auto emitAssertFalse = [&](Value pred) {
165 auto neg = b.create<pcl::NotOp>(loc, pred);
166 b.create<pcl::AssertOp>(loc,
neg.getRes());
// The bool/constant shortcuts are tried with the operands in both orders.
170 if (isBool(lhsVal) && isConstOne(rhsVal)) {
172 emitAssertTrue(lhsVal);
175 if (isBool(rhsVal) && isConstOne(lhsVal)) {
177 emitAssertTrue(rhsVal);
180 if (isBool(lhsVal) && isConstZero(rhsVal)) {
182 emitAssertFalse(lhsVal);
185 if (isBool(rhsVal) && isConstZero(lhsVal)) {
187 emitAssertFalse(rhsVal);
// No shortcut applied: fall back to the generic equality assertion.
192 emitEqAssert(lhsVal, rhsVal);
/// Fills the body of `dstFunc` with the PCL lowering of the struct's
/// constrain function.
///
/// Steps visible in this excerpt:
///   1. Map each source argument (after dropping the leading self argument)
///      to the corresponding `dstFunc` argument; a count mismatch is an error.
///   2. Create one `pcl::VarOp` per field, remember it by field name, and
///      collect the public ones in `outVars`.
///   3. Require the source body to have exactly one block.
///   4. Dispatch on each op of the entry block via `llvm::TypeSwitch`,
///      recording every lowered result in `llzkToPcl`; unknown ops emit an
///      "unsupported op" error.
/// NOTE(review): `srcFunc`, `srcEntry`, the field loop header and the final
/// result handling are on lines not visible in this excerpt.
197 LogicalResult lowerStructToPCLBody(
StructDefOp structDef, func::FuncOp dstFunc) {
199 llvm::DenseMap<Value, Value> llzkToPcl;
200 OpBuilder b(dstFunc.getBody());
202 llvm::DenseMap<StringRef, Value> field2pclvar;
203 llvm::SmallVector<Value> outVars;
// The first source argument is the struct itself and has no PCL counterpart.
206 auto srcArgs = srcFunc.getArguments().drop_front();
207 auto dstArgs = dstFunc.getArguments();
208 if (dstArgs.size() != srcArgs.size()) {
209 return srcFunc.emitError(
"arg count mismatch after dropping self");
213 for (
auto [src, dst] : llvm::zip(srcArgs, dstArgs)) {
214 llzkToPcl.try_emplace(src, dst);
219 b.create<pcl::VarOp>(fieldDef.getLoc(), fieldDef.getName(), fieldDef.hasPublicAttr());
220 field2pclvar.insert({fieldDef.getName(), pclVar});
221 if (fieldDef.hasPublicAttr()) {
222 outVars.push_back(pclVar);
225 if (!srcFunc.getBody().hasOneBlock()) {
226 return srcFunc.emitError(
227 "llzk-to-pcl translation assumes the constrain function body has 1 block"
233 for (Operation &op : srcEntry) {
234 LogicalResult res = success();
235 llvm::TypeSwitch<Operation *, void>(&op)
237 res = lowerConst(b, c, llzkToPcl);
239 .Case<AddFeltOp>([&b, &llzkToPcl, &res](
auto a) {
240 res = lowerBinaryLike<AddFeltOp, pcl::AddOp>(b, a, llzkToPcl);
242 .Case<SubFeltOp>([&b, &llzkToPcl, &res](
auto s) {
243 res = lowerBinaryLike<SubFeltOp, pcl::SubOp>(b, s, llzkToPcl);
245 .Case<MulFeltOp>([&b, &llzkToPcl, &res](
auto m) {
246 res = lowerBinaryLike<MulFeltOp, pcl::MulOp>(b, m, llzkToPcl);
248 .Case<IntToFeltOp>([&llzkToPcl, &res](
auto m) {
249 auto arg = lookup(m.getValue(), llzkToPcl, m.getOperation());
// The conversion emits no PCL op: its result simply aliases the operand.
254 rememberResult(m.getResult(), arg.value(), llzkToPcl);
256 .Case<CmpOp>([&b, &llzkToPcl, &res](
auto cmp) {
257 auto pred =
cmp.getPredicate();
260 res = lowerBinaryLike<CmpOp, pcl::CmpEqOp>(b,
cmp, llzkToPcl);
264 auto eq = lowerBinaryLike<CmpOp, pcl::CmpEqOp>(b,
cmp, llzkToPcl);
270 auto eqRes = lookup(
cmp.getResult(), llzkToPcl,
cmp.getOperation());
275 auto loc =
cmp.getLoc();
// NOTE(review): the boolean negation used for assertions in this file is
// pcl::NotOp; confirm pcl::NegOp is the intended op for negating an equality.
276 auto neg = b.create<pcl::NegOp>(loc, *eqRes);
// lowerBinaryLike already mapped cmp's result to the equality value, and
// rememberResult (try_emplace) never replaces an existing entry, so the
// entry must be overwritten explicitly for the negated value to be used.
278 llzkToPcl[cmp.getResult()] = neg.getResult();
282 res = lowerBinaryLike<CmpOp, pcl::CmpLtOp>(b,
cmp, llzkToPcl);
285 res = lowerBinaryLike<CmpOp, pcl::CmpLeOp>(b,
cmp, llzkToPcl);
288 res = lowerBinaryLike<CmpOp, pcl::CmpGtOp>(b,
cmp, llzkToPcl);
291 res = lowerBinaryLike<CmpOp, pcl::CmpGeOp>(b,
cmp, llzkToPcl);
295 .Case<EmitEqualityOp>([&b, &llzkToPcl, &res](
auto eq) {
296 auto lhs = lookup(eq.getLhs(), llzkToPcl, eq.getOperation());
297 auto rhs = lookup(eq.getRhs(), llzkToPcl, eq.getOperation());
298 if (failed(lhs) || failed(rhs)) {
303 Value lhsVal = *lhs, rhsVal = *rhs;
304 auto loc = eq.getLoc();
305 if (failed(emitAssertEqOptimized(b, loc, lhsVal, rhsVal))) {
310 .Case<FieldReadOp>([&field2pclvar, &llzkToPcl, &srcFunc](
auto read) {
// Only reads from the struct's own self argument are expected here.
314 assert(read.getComponent() == srcFunc.getArguments()[0]);
315 if (
auto it = field2pclvar.find(read.getFieldName()); it != field2pclvar.end()) {
316 rememberResult(read.getResult(), it->getSecond(), llzkToPcl);
318 llvm_unreachable(
"Every field should have been mapped to a pcl var");
321 .Case<ReturnOp>([&b, &outVars](
auto ret) {
// The PCL function returns the variables created for the public fields.
323 b.create<pcl::ReturnOp>(
324 ret.getLoc(), (llvm::SmallVector<Value>(outVars.begin(), outVars.end()))
326 }).Default([](Operation *unknown) {
327 unknown->emitError(
"unsupported op in PCL lowering: ") << unknown->getName();
/// Creates the (still empty-bodied) `func::FuncOp` that a struct lowers to.
///
/// Inputs: one `pcl::FeltType` per constrain-function argument after the
/// leading self argument; a non-felt argument is an error.
/// Outputs: one `pcl::FeltType` per public field; a non-felt field is an
/// error.
/// The function is named after the struct, located at the constrain function,
/// and given an entry block.
/// NOTE(review): the definition of `constrainFunc`, the field loop header and
/// the final return are on lines not visible in this excerpt.
336 FailureOr<func::FuncOp> buildPCLFunc(StructDefOp structDef) {
337 SmallVector<Type> pclInputTypes, pclOutputTypes;
339 auto ctx = structDef.getContext();
// Skip the self argument; every remaining argument must be a felt.
340 for (
auto arg : constrainFunc.getArguments().drop_front()) {
341 auto argType = arg.getType();
342 if (!llvm::isa<FeltType>(argType)) {
343 return constrainFunc.emitError()
344 <<
"Constrain function's args are expected to be felts. Found " << argType
345 <<
" for arg #: " << arg.getArgNumber();
347 pclInputTypes.push_back(pcl::FeltType::get(ctx));
350 auto fieldType = field.getType();
351 if (!llvm::isa<FeltType>(fieldType)) {
352 return structDef.emitError() <<
"Field must be felt type. Found " << fieldType
353 <<
" for field: " << field.getName();
// Only public fields become results of the PCL function.
355 if (field.hasPublicAttr()) {
356 pclOutputTypes.push_back(pcl::FeltType::get(ctx));
359 FunctionType fty = FunctionType::get(ctx, pclInputTypes, pclOutputTypes);
360 auto func = func::FuncOp::create(constrainFunc.getLoc(), structDef.getName(), fty);
361 func.addEntryBlock();
/// Attaches the field prime to `newMod` as the `pcl.prime` attribute.
///
/// The prime is zero-extended to one bit wider than its own bit width and
/// stored as an `IntegerAttr` of that width wrapped in a `pcl::PrimeAttr`.
/// NOTE(review): `prime` is defined on a line not visible in this excerpt.
/// The extra bit presumably keeps the top bit clear so the value is not read
/// as negative — confirm.
366 void setPrime(ModuleOp &newMod) {
368 auto newBitWidth = prime.getBitWidth() + 1;
369 auto ty = IntegerType::get(newMod.getContext(), newBitWidth);
370 auto intAttr = IntegerAttr::get(ty, prime.zext(newBitWidth));
371 newMod->setAttr(
"pcl.prime", pcl::PrimeAttr::get(newMod.getContext(), intAttr));
/// Pass entry point: rebuilds the module as PCL functions.
///
/// A fresh module is created at the original module's location. For every
/// `StructDefOp` found by the walk, the struct is validated, its PCL function
/// is built and appended to the new module, and its body is lowered; any
/// failure interrupts the walk. Afterwards the original module takes over the
/// new module's body region and attribute dictionary.
/// NOTE(review): the handling of an interrupted walk is on lines not visible
/// in this excerpt; presumably it signals pass failure and returns.
374 void runOnOperation()
override {
375 ModuleOp moduleOp = getOperation();
377 assert(moduleOp->getContext()->getLoadedDialect<pcl::PCLDialect>() &&
"PCL dialect not loaded");
379 auto newMod = ModuleOp::create(moduleOp.getLoc());
383 auto walkResult = moduleOp.walk([
this, &newMod](StructDefOp structDef) -> WalkResult {
385 if (failed(validateStruct(structDef))) {
386 return WalkResult::interrupt();
389 FailureOr<func::FuncOp> pclFuncOp = buildPCLFunc(structDef);
390 if (failed(pclFuncOp)) {
391 return WalkResult::interrupt();
// The function is inserted into the new module before its body is lowered.
394 newMod.getBody()->push_back(*pclFuncOp);
395 if (failed(lowerStructToPCLBody(structDef, pclFuncOp.value()))) {
396 return WalkResult::interrupt();
399 return WalkResult::advance();
401 if (walkResult.wasInterrupted()) {
// Replace the original module's contents and attributes with the new ones.
406 moduleOp.getRegion().takeBody(newMod.getBodyRegion());
408 moduleOp->setAttrs(newMod->getAttrDictionary());
/// Factory for the PCL lowering pass: returns a newly allocated
/// `PCLLoweringPass` owned by the caller.
414std::unique_ptr<Pass> llzk::createPCLLoweringPass() {
return std::make_unique<PCLLoweringPass>(); }
This file implements helper methods for constructing DynamicAPInts.
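Apache License, Version 2.0, January 2004. TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION. 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work. "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from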
::std::vector< FieldDefOp > getFieldDefs()
Get all FieldDefOp in this structure.
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::mlir::TypedValue<::llzk::felt::FeltType > getResult()
::mlir::Region & getBody()
ExpressionValue cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)