LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
LLZKPCLLoweringPass.cpp
Go to the documentation of this file.
1//===-- LLZKPCLLoweringPass.cpp --------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
15#include "llzk/Config/Config.h"
25#include "r1cs/Dialect/IR/Attrs.h"
26#include "r1cs/Dialect/IR/Ops.h"
27#include "r1cs/Dialect/IR/Types.h"
28
29#include <pcl/Dialect/IR/Dialect.h>
30#include <pcl/Dialect/IR/Ops.h>
31#include <pcl/Dialect/IR/Types.h>
32
33#include <mlir/Dialect/Func/IR/FuncOps.h>
34#include <mlir/IR/BuiltinOps.h>
35
36#include <llvm/ADT/DenseMap.h>
37#include <llvm/ADT/DenseMapInfo.h>
38#include <llvm/ADT/SmallVector.h>
39#include <llvm/Support/Debug.h>
40
41#include <deque>
42#include <memory>
43
44// Include the generated base pass class definitions.
45namespace llzk {
46#define GEN_PASS_DECL_PCLLOWERINGPASS
47#define GEN_PASS_DEF_PCLLOWERINGPASS
49} // namespace llzk
50
51using namespace mlir;
52using namespace llzk;
53using namespace llzk::cast;
54using namespace llzk::boolean;
55using namespace llzk::constrain;
56using namespace llzk::felt;
57using namespace llzk::function;
58using namespace llzk::component;
59
60namespace {
61
62static FailureOr<Value> lookup(Value v, llvm::DenseMap<Value, Value> &m, Operation *onError) {
63 if (auto it = m.find(v); it != m.end()) {
64 return it->second;
65 }
66 return onError->emitError("missing operand mapping");
67}
68
69static void rememberResult(Value from, Value to, llvm::DenseMap<Value, Value> &m) {
70 (void)m.try_emplace(from, to);
71}
72
73// Convert binary LLZK op to corresponding binary PCL op
74template <typename SrcBinOp, typename DstBinOp>
75static LogicalResult
76lowerBinaryLike(OpBuilder &b, SrcBinOp src, llvm::DenseMap<Value, Value> &mapping) {
77 auto loc = src.getLoc();
78 auto op = src.getOperation();
79 auto lhs = lookup(src.getLhs(), mapping, op);
80 if (failed(lhs)) {
81 return failure();
82 }
83 auto rhs = lookup(src.getRhs(), mapping, op);
84 if (failed(rhs)) {
85 return failure();
86 }
87
88 auto dst = b.create<DstBinOp>(loc, *lhs, *rhs);
89 rememberResult(src.getResult(), dst.getRes(), mapping);
90 return success();
91}
92
93static LogicalResult
94lowerConst(OpBuilder &b, FeltConstantOp cst, llvm::DenseMap<Value, Value> &mapping) {
95 auto attr = pcl::FeltAttr::get(b.getContext(), cst.getValue());
96 auto dst = b.create<pcl::ConstOp>(cst.getLoc(), attr);
97 rememberResult(cst.getResult(), dst.getRes(), mapping);
98 return success();
99}
100
101class PCLLoweringPass : public llzk::impl::PCLLoweringPassBase<PCLLoweringPass> {
102
103private:
104 void getDependentDialects(DialectRegistry &registry) const override {
105 registry.insert<pcl::PCLDialect, func::FuncDialect>();
106 }
107
109 LogicalResult validateStruct(StructDefOp structDef) {
110 for (auto field : structDef.getFieldDefs()) {
111 auto fieldType = field.getType();
112 if (!llvm::isa<FeltType>(fieldType)) {
113 return field.emitError() << "Field must be felt type. Found " << fieldType
114 << " for field: " << field.getName();
115 }
116 }
117 return success();
118 }
119
131 static LogicalResult
132 emitAssertEqOptimized(OpBuilder &b, Location loc, Value lhsVal, Value rhsVal) {
133 // --- Small helpers --------------------------------------------------------
134 auto isBool = [](Value v) { return llvm::isa<pcl::BoolType>(v.getType()); };
135
136 auto getConstAPInt = [](Value v) -> std::optional<llvm::APInt> {
137 if (auto c = llvm::dyn_cast_if_present<pcl::ConstOp>(v.getDefiningOp())) {
138 // Chain: ConstOp -> FeltAttr (or BoolAttr-as-int) -> IntegerAttr -> APInt
139 return c.getValue().getValue().getValue();
140 }
141 return std::nullopt;
142 };
143
144 auto isConstOne = [&](Value v) {
145 if (auto ap = getConstAPInt(v)) {
146 return ap->isOne();
147 }
148 return false;
149 };
150 auto isConstZero = [&](Value v) {
151 if (auto ap = getConstAPInt(v)) {
152 return ap->isZero();
153 }
154 return false;
155 };
156
157 auto emitEqAssert = [&](Value l, Value r) {
158 auto eq = b.create<pcl::CmpEqOp>(loc, l, r);
159 b.create<pcl::AssertOp>(loc, eq.getRes());
160 };
161
162 auto emitAssertTrue = [&](Value pred) { b.create<pcl::AssertOp>(loc, pred); };
163
164 auto emitAssertFalse = [&](Value pred) {
165 auto neg = b.create<pcl::NotOp>(loc, pred);
166 b.create<pcl::AssertOp>(loc, neg.getRes());
167 };
168
169 // Optimized handling of boolean patterns
170 if (isBool(lhsVal) && isConstOne(rhsVal)) {
171 // bool == 1 → assert(bool)
172 emitAssertTrue(lhsVal);
173 return success();
174 }
175 if (isBool(rhsVal) && isConstOne(lhsVal)) {
176 // 1 == bool → assert(bool)
177 emitAssertTrue(rhsVal);
178 return success();
179 }
180 if (isBool(lhsVal) && isConstZero(rhsVal)) {
181 // bool == 0 → assert(!bool)
182 emitAssertFalse(lhsVal);
183 return success();
184 }
185 if (isBool(rhsVal) && isConstZero(lhsVal)) {
186 // 0 == bool → assert(!bool)
187 emitAssertFalse(rhsVal);
188 return success();
189 }
190
191 // Fallback to assert(lhs == rhs)
192 emitEqAssert(lhsVal, rhsVal);
193 return success();
195
// Translate the body of `structDef`'s constrain function into `dstFunc`.
// Constrain arguments after `self` are mapped 1-1 onto `dstFunc`'s arguments
// (the counts must match), every struct field becomes a `pcl.var` (public fields
// are also collected as the function's outputs), and each op of the single-block
// constrain body is lowered in order. Returns failure if the arg counts differ,
// the body has more than one block, or a supported op cannot be lowered.
197 LogicalResult lowerStructToPCLBody(StructDefOp structDef, func::FuncOp dstFunc) {
198 // As we build, map llzk values to their pcl ones
199 llvm::DenseMap<Value, Value> llzkToPcl;
200 OpBuilder b(dstFunc.getBody());
201 // Map field name to PCL vars; public fields are outputs, privates are intermediates
202 llvm::DenseMap<StringRef, Value> field2pclvar;
203 llvm::SmallVector<Value> outVars;
204
205 auto srcFunc = structDef.getConstrainFuncOp();
// The first constrain argument is `self`; it has no PCL counterpart.
206 auto srcArgs = srcFunc.getArguments().drop_front();
207 auto dstArgs = dstFunc.getArguments();
208 if (dstArgs.size() != srcArgs.size()) {
209 return srcFunc.emitError("arg count mismatch after dropping self");
210 }
212 // 1-1 mapping of args from constraint args to PCL args
213 for (auto [src, dst] : llvm::zip(srcArgs, dstArgs)) {
214 llzkToPcl.try_emplace(src, dst);
215 }
216 for (auto fieldDef : structDef.getFieldDefs()) {
217 // Create a PCL var for each struct field. Public fields are outputs in PCL
218 auto pclVar =
219 b.create<pcl::VarOp>(fieldDef.getLoc(), fieldDef.getName(), fieldDef.hasPublicAttr());
220 field2pclvar.insert({fieldDef.getName(), pclVar});
221 if (fieldDef.hasPublicAttr()) {
222 outVars.push_back(pclVar);
223 }
224 }
225 if (!srcFunc.getBody().hasOneBlock()) {
226 return srcFunc.emitError(
227 "llzk-to-pcl translation assumes the constrain function body has 1 block"
228 );
229 }
230 Block &srcEntry = srcFunc.getBody().front();
231 // Translate each op. Almost 1-1 and currently only support Felt ops.
232 // TODO: Support calls, if-else, globals/lookups.
233 for (Operation &op : srcEntry) {
// Each case writes its outcome into `res`; the loop stops at the first failure.
234 LogicalResult res = success();
235 llvm::TypeSwitch<Operation *, void>(&op)
236 .Case<FeltConstantOp>([&b, &llzkToPcl, &res](auto c) {
237 res = lowerConst(b, c, llzkToPcl);
238 })
239 .Case<AddFeltOp>([&b, &llzkToPcl, &res](auto a) {
240 res = lowerBinaryLike<AddFeltOp, pcl::AddOp>(b, a, llzkToPcl);
241 })
242 .Case<SubFeltOp>([&b, &llzkToPcl, &res](auto s) {
243 res = lowerBinaryLike<SubFeltOp, pcl::SubOp>(b, s, llzkToPcl);
244 })
245 .Case<MulFeltOp>([&b, &llzkToPcl, &res](auto m) {
246 res = lowerBinaryLike<MulFeltOp, pcl::MulOp>(b, m, llzkToPcl);
247 })
248 .Case<IntToFeltOp>([&llzkToPcl, &res](auto m) {
// No PCL op is emitted: the result simply reuses the operand's PCL value.
249 auto arg = lookup(m.getValue(), llzkToPcl, m.getOperation());
250 if (failed(arg)) {
251 res = failure();
252 return;
253 }
254 rememberResult(m.getResult(), arg.value(), llzkToPcl);
255 })
256 .Case<CmpOp>([&b, &llzkToPcl, &res](auto cmp) {
257 auto pred = cmp.getPredicate();
258 switch (pred) {
260 res = lowerBinaryLike<CmpOp, pcl::CmpEqOp>(b, cmp, llzkToPcl);
261 break;
263 // Translate not-equals as an equality followed by a negation
264 auto eq = lowerBinaryLike<CmpOp, pcl::CmpEqOp>(b, cmp, llzkToPcl);
265 if (failed(eq)) {
266 res = eq;
267 break;
268 }
269 // Get the result from the `pcl::CmpEqOp` to pass into `Neg`
270 auto eqRes = lookup(cmp.getResult(), llzkToPcl, cmp.getOperation());
271 if (failed(eqRes)) {
272 res = failure();
273 break;
274 }
275 auto loc = cmp.getLoc();
// NOTE(review): emitAssertEqOptimized negates booleans with pcl::NotOp; confirm
// that pcl::NegOp (not NotOp) is the intended op for negating a CmpEqOp result.
276 auto neg = b.create<pcl::NegOp>(loc, *eqRes);
277 // Associate the result of the llzk-op with the result of the pcl-neg
// NOTE(review): this only takes effect if rememberResult overwrites an existing
// entry; a try_emplace-based insert keeps the CmpEqOp result recorded above,
// which would lower `!=` exactly like `==`.
278 rememberResult(cmp.getResult(), neg.getResult(), llzkToPcl);
279 break;
280 }
282 res = lowerBinaryLike<CmpOp, pcl::CmpLtOp>(b, cmp, llzkToPcl);
283 break;
285 res = lowerBinaryLike<CmpOp, pcl::CmpLeOp>(b, cmp, llzkToPcl);
286 break;
288 res = lowerBinaryLike<CmpOp, pcl::CmpGtOp>(b, cmp, llzkToPcl);
289 break;
291 res = lowerBinaryLike<CmpOp, pcl::CmpGeOp>(b, cmp, llzkToPcl);
292 break;
293 }
294 })
295 .Case<EmitEqualityOp>([&b, &llzkToPcl, &res](auto eq) {
// Both operands are looked up before checking, so a double miss reports two errors.
296 auto lhs = lookup(eq.getLhs(), llzkToPcl, eq.getOperation());
297 auto rhs = lookup(eq.getRhs(), llzkToPcl, eq.getOperation());
298 if (failed(lhs) || failed(rhs)) {
299 res = failure();
300 return;
301 }
302
303 Value lhsVal = *lhs, rhsVal = *rhs;
304 auto loc = eq.getLoc();
305 if (failed(emitAssertEqOptimized(b, loc, lhsVal, rhsVal))) {
306 res = failure();
307 return;
308 }
309 })
310 .Case<FieldReadOp>([&field2pclvar, &llzkToPcl, &srcFunc](auto read) {
311 // At this point every field in the struct should have a var associated with it
312 // so we should simply retrieve the var associated with the field.
313 (void)srcFunc; // to silence unused variable warning if asserts are disabled
314 assert(read.getComponent() == srcFunc.getArguments()[0]);
315 if (auto it = field2pclvar.find(read.getFieldName()); it != field2pclvar.end()) {
316 rememberResult(read.getResult(), it->getSecond(), llzkToPcl);
317 } else {
318 llvm_unreachable("Every field should have been mapped to a pcl var");
319 }
320 })
321 .Case<ReturnOp>([&b, &outVars](auto ret) {
322 // We return all the output vars we defined above.
323 b.create<pcl::ReturnOp>(
324 ret.getLoc(), (llvm::SmallVector<Value>(outVars.begin(), outVars.end()))
325 );
326 }).Default([](Operation *unknown) {
// NOTE(review): this branch emits an error but never sets `res = failure()`, so an
// unsupported op is skipped and the lowering still returns success. Looks unintended.
327 unknown->emitError("unsupported op in PCL lowering: ") << unknown->getName();
328 });
329 if (failed(res)) {
330 return failure();
331 }
332 }
333 return success();
334 }
335
336 FailureOr<func::FuncOp> buildPCLFunc(StructDefOp structDef) {
337 SmallVector<Type> pclInputTypes, pclOutputTypes;
338 auto constrainFunc = structDef.getConstrainFuncOp();
339 auto ctx = structDef.getContext();
340 for (auto arg : constrainFunc.getArguments().drop_front()) {
341 auto argType = arg.getType();
342 if (!llvm::isa<FeltType>(argType)) {
343 return constrainFunc.emitError()
344 << "Constrain function's args are expected to be felts. Found " << argType
345 << "for arg #: " << arg.getArgNumber();
346 }
347 pclInputTypes.push_back(pcl::FeltType::get(ctx));
348 }
349 for (auto field : structDef.getFieldDefs()) {
350 auto fieldType = field.getType();
351 if (!llvm::isa<FeltType>(fieldType)) {
352 return structDef.emitError() << "Field must be felt type. Found " << fieldType
353 << " for field: " << field.getName();
354 }
355 if (field.hasPublicAttr()) {
356 pclOutputTypes.push_back(pcl::FeltType::get(ctx));
357 }
358 }
359 FunctionType fty = FunctionType::get(ctx, pclInputTypes, pclOutputTypes);
360 auto func = func::FuncOp::create(constrainFunc.getLoc(), structDef.getName(), fty);
361 func.addEntryBlock();
362 return func;
363 }
364
365 // PCL programs require a module-level attribute specifying the prime.
366 void setPrime(ModuleOp &newMod) {
367 // Add an extra bit to avoid the prime being represented as a negative number
368 auto newBitWidth = prime.getBitWidth() + 1;
369 auto ty = IntegerType::get(newMod.getContext(), newBitWidth);
370 auto intAttr = IntegerAttr::get(ty, prime.zext(newBitWidth));
371 newMod->setAttr("pcl.prime", pcl::PrimeAttr::get(newMod.getContext(), intAttr));
372 }
373
374 void runOnOperation() override {
375 ModuleOp moduleOp = getOperation();
376 // check PCLDialect is loaded.
377 assert(moduleOp->getContext()->getLoadedDialect<pcl::PCLDialect>() && "PCL dialect not loaded");
378 // Create the PCL module
379 auto newMod = ModuleOp::create(moduleOp.getLoc());
380 // Set the prime attribute
381 setPrime(newMod);
382 // Convert each struct to a PCL function
383 auto walkResult = moduleOp.walk([this, &newMod](StructDefOp structDef) -> WalkResult {
384 // 1) verify the struct can be converted to PCL
385 if (failed(validateStruct(structDef))) {
386 return WalkResult::interrupt();
387 }
388 // 2) Construct the PCL function op but with an empty body
389 FailureOr<func::FuncOp> pclFuncOp = buildPCLFunc(structDef);
390 if (failed(pclFuncOp)) {
391 return WalkResult::interrupt();
392 }
393 // 3) Fill in the PCL function body
394 newMod.getBody()->push_back(*pclFuncOp);
395 if (failed(lowerStructToPCLBody(structDef, pclFuncOp.value()))) {
396 return WalkResult::interrupt();
397 }
398
399 return WalkResult::advance();
400 });
401 if (walkResult.wasInterrupted()) {
402 signalPassFailure();
403 return;
404 }
405 // clear the original ops
406 moduleOp.getRegion().takeBody(newMod.getBodyRegion());
407 // Replace the module attributes
408 moduleOp->setAttrs(newMod->getAttrDictionary());
409 newMod.erase();
410 }
411};
412} // namespace
413
414std::unique_ptr<Pass> llzk::createPCLLoweringPass() { return std::make_unique<PCLLoweringPass>(); }
This file implements helper methods for constructing DynamicAPInts.
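Apache License, Version 2.0, January 2004. TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION. 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work. "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.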
Definition LICENSE.txt:45
::std::vector< FieldDefOp > getFieldDefs()
Get all FieldDefOp in this structure.
Definition Ops.cpp:419
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:433
::llvm::APInt getValue()
Definition Ops.cpp.inc:714
::mlir::TypedValue<::llzk::felt::FeltType > getResult()
Definition Ops.h.inc:772
::mlir::Region & getBody()
Definition Ops.h.inc:607
ExpressionValue cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)