안녕하세요, 지난시간에 저희는 GKR protocol 에 대하여 살펴보았습니다.
이번 시간에는, GKR 구현체 코드에 대하여 살펴보며, 어떻게 실제로 GKR protocol을 구현할 수 있는지 살펴보도록 하겠습니다. 흥미롭게도 한국인 이더리움 개발자께서 만들어 두셨네요.
오늘 글에서는 모든 소스 코드를 살펴보기보다, 핵심 파일인 gkr.py 를 위주로 살펴보며, 해당 레포지토리에 코드가 있지만 설명이 필요한 경우는 글을 통해 설명하며 글을 진행해 보도록 하겠습니다.
본격적으로 들어가기에 앞서, 해당 구현은 지난 시간 살펴본 구현보다 조금의 테크닉이 더 가미되어 있습니다. 좀 더 정확히는, 링크의 65 페이지를 참조하면 되는데요, 코드를 살펴보며 해당 부분이 나오면 언급드리며 살펴보도록 하겠습니다.

증명 과정
우선 코드는 증명에 필요한 class들을 정의하며 시작합니다. 우리가 지난시간 살펴본 것처럼, 게이트에 해당하는 Node, 증명을 위한 각 Layer, 그리고 증명을 실제로 진행할 Circuit을 살펴볼 수 있습니다.
class Node:
def __init__(self, binary_index: list[int], value, left=None, right=None):
self.binary_index = binary_index
self.value = value
self.left = left
self.right = right
class Layer:
def __init__(self) -> None:
self.nodes = []
def def_mult(self, mult):
self.mult = mult
def def_add(self, add):
self.add = add
def get_node(self, index) -> Node:
return self.nodes[index]
def add_node(self, index, node) -> None:
self.nodes.insert(index, node)
def add_func(self, func):
self.func = func
def len(self):
return len(self.nodes)
class Circuit:
def __init__(self, depth):
layers = []
for _ in range(depth):
layers.append(Layer())
self.layers : list[Layer] = layers # type: ignore
def get_node(self, layer, index):
return self.layers[layer].get_node(index)
def add_node(self, layer, index, binary_index, value, left=None, right=None):
self.layers[layer].add_node(index, Node(binary_index, value, left, right))
def depth(self):
return len(self.layers)
def layer_length(self, layer):
return self.layers[layer].len()
def k_i(self, layer):
return int(math.log2(self.layer_length(layer)))
def add_i(self, i):
return self.layers[i].add
def mult_i(self, i):
return self.layers[i].mult
def w_i(self, i):
return self.layers[i].func
또한 reduce_multiple_polynomial 함수가 있습니다. 위에 있는 GKR protocol 사진을 살펴보면, 게이트의 두 input을 $b$와 $c$ 로 나타내는 것을 볼 수 있습니다. 나중에 우리는 임의의 $r$ 포인트에서 $f_i$식을 검증해야 하기 때문에, 두 input $b$, 그리고 $c$를 interpolate 하는 linear polynomial을 만들어내는 과정이 표기되어 있습니다. 해당 polynomial은 $0$에서 $b$의 값을, $1$에서 $c$의 값을 가짐을 알 수 있습니다. 이렇게 해줌으로, verifier는 랜덤값 $r$에서의 evaluation 만으로 실제 circuit에서의 값 $w$를 알아낼 수 없습니다.
또한 그 이후에 verifier에게 전달될 Proof에 대한 class 정의가 되어있는것을 확인할 수 있습니다.
def reduce_multiple_polynomial(b: list[field.FQ], c: list[field.FQ], w: polynomial) -> list[field.FQ]:
assert(len(b) == len(c))
t = []
new_poly_terms = []
for b_i, c_i in zip(b, c):
new_const = b_i
gradient = c_i - b_i
t.append(term(gradient, 1, new_const))
for mono in w.terms:
new_terms = []
for each in mono.terms:
new_term = t[each.x_i - 1] * each.coeff
new_term.const += each.const
new_terms.append(new_term)
new_poly_terms.append(monomial(mono.coeff, new_terms))
poly = polynomial(new_poly_terms, w.constant)
return poly.get_all_coefficients()
class Proof:
def __init__(self, proofs, r, f, D, q, z, r_stars, d, w, adds, mults, k) -> None:
self.sumcheck_proofs : list[list[list[field.FQ]]] = proofs
self.sumcheck_r : list[list[field.FQ]] = r
self.f : list[field.FQ] = f
self.D : list[list[field.FQ]] = D
self.q : list[list[field.FQ]] = q
self.z : list[list[field.FQ]] = z
self.r : list[field.FQ] = r_stars
# circuit info
self.d : int = d
self.input_func : list[list[field.FQ]] = w
self.add : list[list[list[field.FQ]]] = adds
self.mult : list[list[list[field.FQ]]] = mults
self.k : list[int] = k
def to_dict(self):
to_serialize = dict()
to_serialize['sumcheckProof'] = list(map(lambda x: list(map(lambda y: list(map(lambda z: repr(z), y)), x)), self.sumcheck_proofs))
to_serialize['sumcheckr'] = list(map(lambda x: list(map(lambda y: repr(y), x)), self.sumcheck_r))
to_serialize['f'] = list(map(lambda x: repr(x), self.f))
to_serialize['q'] = list(map(lambda x: list(map(lambda y: repr(y), x)), self.q))
to_serialize['z'] = list(map(lambda x: list(map(lambda y: repr(y), x)), self.z))
to_serialize['D'] = list(map(lambda x: list(map(lambda y: repr(y), x)), self.D))
to_serialize['r'] = list(map(lambda x: repr(x), self.r))
to_serialize['inputFunc'] = list(map(lambda x: list(map(lambda y: repr(y), x)), self.input_func))
to_serialize['add'] = list(map(lambda x: list(map(lambda y: list(map(lambda z: repr(z), y)), x)), self.add))
to_serialize['mult'] = list(map(lambda x: list(map(lambda y: list(map(lambda z: repr(z), y)), x)), self.mult))
return to_serialize
이제 본격적인 prove 함수를 살펴보겠습니다.
def prove(circuit: Circuit, D):
start_time = time.time()
D_poly = get_multi_ext(D, circuit.k_i(0))
z = [[]] * circuit.depth()
z[0] = [field.FQ.zero()] * circuit.k_i(0)
sumcheck_proofs = []
q = []
f_res = []
sumcheck_r = []
r_stars = []
for i in range(len(z[0])):
z[0][i] = field.FQ.random() # This initial value is unsafe
우선, proof generation에 사용될 랜덤 값들을 $z$에 담아주겠습니다.
for i in range(circuit.depth() - 1):
add_i_ext = get_ext(circuit.add_i(i), circuit.k_i(i) + 2 * circuit.k_i(i + 1))
for j, r in enumerate(z[i]):
add_i_ext = add_i_ext.eval_i(r, j + 1)
mult_i_ext = get_ext(circuit.mult_i(i), circuit.k_i(i) + 2 * circuit.k_i(i + 1))
for j, r in enumerate(z[i]):
mult_i_ext = mult_i_ext.eval_i(r, j + 1)
w_i_ext_b = get_ext_from_k(circuit.w_i(i + 1), circuit.k_i(i + 1), circuit.k_i(i) + 1)
w_i_ext_c = get_ext_from_k(circuit.w_i(i + 1), circuit.k_i(i + 1), circuit.k_i(i) + circuit.k_i(i + 1) + 1)
first = add_i_ext * (w_i_ext_b + w_i_ext_c)
second = mult_i_ext * w_i_ext_b * w_i_ext_c
f = first + second
start_idx = circuit.k_i(i) + 1
sumcheck_proof, r = prove_sumcheck(f, 2 * circuit.k_i(i + 1), start_idx)
sumcheck_proofs.append(sumcheck_proof)
sumcheck_r.append(r)
b_star = r[0: circuit.k_i(i + 1)]
c_star = r[circuit.k_i(i + 1):(2 * circuit.k_i(i + 1))]
next_w = get_ext(circuit.w_i(i + 1), circuit.k_i(i + 1))
q_i = reduce_multiple_polynomial(b_star, c_star, next_w)
q.append(q_i)
그 후에는, 앞선 글에서 살펴본 것처럼 각 circuit의 layer마다 proof를 생성해 주게 됩니다.
우선 각 레이어에서 $add$ 와 $mult$에 대한 multilinear extension을 만들어주게 된 후, 해당 extension을 랜덤한 값 $r$에서 evaluate 해주게 됩니다. 또한 각 게이트에서의 실제 값 $w$ 또한 multilinear extension을 통한 polynomial을 만들어주고, 이들을 결합하여 Sum-Check protocol에 사용할 함수 $f$를 만들어주게 됩니다.
여기서는 다루지 않겠지만 해당 레포지토리에는 Sum-Check protocol또한 구현이 되어 있습니다. (나중에 이 부분도 다뤄보도록 하겠습니다.)
Sum-Check protocol의 마지막 과정은 verifier 가 $f$에 랜덤한 값을 넣고, 이를 prover가 Sum-Check protocol 마지막에 claim 한 값과 비교하는 것입니다. 하지만 $W$ 값들은 circuit 내부의 값이기 때문에 verifier가 직접 계산을 하지 않는 이상 알 수 없는 값입니다. 따라서, prover는 verifier가 $W$ 값을 계산할 수 있도록, reduce_multiple_polynomial 을 사용해 $b$와 $c$값을 interpolate 하는 직선을 만들고, 해당 직선을 $f$ 함수에 합성하여 verifier에게 전달해주게 됩니다.
이러한 trick을 통해 verifier는 $W$ 값들을 몰라도, 단순히 주어진 polynomial 에 $0$과 $1$ 값을 대입함으로 랜덤한 포인트에서의 $f$ 값을 계산할 수 있습니다. ($0$과 $1$을 넣어 verifier가 확인할 수 있는 다항식을 $q$라고 부르고 있습니다.)
f_result = polynomial(f.terms, f.constant)
f_result_value = field.FQ.zero()
for j, x in enumerate(r):
if j == len(r) - 1:
f_result_value = f_result.eval_univariate(x)
f_result = f_result.eval_i(x, j + start_idx)
f_res.append(f_result_value)
r_star = field.FQ(mimc.mimc_hash(list(map(lambda x : int(x), sumcheck_proof[len(sumcheck_proof) - 1]))))
next_r = ell(b_star, c_star, r_star)
z[i + 1] = next_r # r_(i + 1)
r_stars.append(r_star)
그 다음작업은 실제로 prover가 verifier에게 제공하기 위해 실제로 $r$값에서 함수 결과를 계산하는 과정입니다.
또한 그 다음 작업은 fiat-shamir heuristic를 사용하여 다음 layer에서 사용될 새로운 random 값을 만들어내는 과정입니다.
w_input = get_multi_ext(circuit.w_i(circuit.depth() - 1), circuit.k_i(circuit.depth() - 1))
adds = []
mults = []
k = []
for i in range(circuit.depth() - 1):
adds.append(get_multi_ext(circuit.add_i(i), circuit.k_i(i) + 2 * circuit.k_i(i + 1)))
mults.append(get_multi_ext(circuit.mult_i(i), circuit.k_i(i) + 2 * circuit.k_i(i + 1)))
k.append(circuit.k_i(i))
k.append(circuit.k_i(circuit.depth() - 1))
proof = Proof(sumcheck_proofs, sumcheck_r, f_res, D_poly, q, z, r_stars, circuit.depth(), w_input, adds, mults, k)
print("proving time :", time.time() - start_time)
return proof
마지막으로 input layer에서의 w 값까지 계산을 진행한 후, 모든 정보를 proof 에 남아 verifier에게 전달하게 됩니다.
검즘 과정
def verify(proof: Proof):
start = time.time()
m = [field.FQ.zero()]*proof.d
m[0] = eval_expansion(proof.D, proof.z[0])
for i in range(proof.d - 1):
valid = verify_sumcheck(m[i], proof.sumcheck_proofs[i], proof.sumcheck_r[i], 2 * proof.k[i + 1])
if not valid:
return False
else:
q_i = proof.q[i]
q_zero = eval_univariate(q_i, field.FQ.zero())
q_one = eval_univariate(q_i, field.FQ.one())
modified_f = eval_expansion(proof.add[i], proof.z[i] + proof.sumcheck_r[i]) * (q_zero + q_one) \
+ eval_expansion(proof.mult[i], proof.z[i] + proof.sumcheck_r[i]) * (q_zero * q_one)
sumcheck_p = proof.sumcheck_proofs[i]
sumcheck_p_hash = field.FQ(mimc.mimc_hash(list(map(lambda x : int(x), sumcheck_p[len(sumcheck_p) - 1]))))
if (proof.f[i] != modified_f) or (sumcheck_p_hash != proof.r[i]):
print("verifying time :", time.time() - start)
return False
else:
m[i + 1] = eval_univariate(q_i, proof.r[i])
if m[proof.d - 1] != eval_expansion(proof.input_func, proof.z[proof.d - 1]):
print("verifying time :", time.time() - start)
return False
print("verifying time :", time.time() - start)
return True
검증과정 또한 증명과 비슷하게 각 레이어를 돌며 진행됩니다. 우선 Sum-Check protocol이 valid 한지 확인한 후, 이전에 제공된 $q$ 함수를 사용하여 랜덤값 $r$에서 함수 $f$값을 계산합니다. 또한 랜덤한 값에서 계산한 $f$의 결과값이, prover가 제공한 값과 일치하는지 확인합니다. input layer 까지 확인이 끝난 후 검증은 마무리 되게 됩니다.
Conclusion
오늘은 실제로 GKR protocol의 구현체에 대하여 살펴보는 시간을 가졌습니다. Protocol 을 이해하는 것과 이것을 구현해내는 과정에는 큰 차이가 있다는 것을 저 또한 느꼈습니다 :) 어떻게 보면 순서가 잘못되었지만, 다음에는 해당 글을 완벽하게 이해하기 위하여, 이번시간에는 다루지 않은 Sum-Check protocol의 구현 또한 같이 살펴보면 좋겠네요.
긴글 읽어주셔서 감사합니다.
'영지식증명' 카테고리의 다른 글
| [영지식증명] ZK-CNN 파헤치기 - 1 (0) | 2024.04.30 |
|---|---|
| [영지식 증명] Lookup Argument 파헤치기 (0) | 2024.04.23 |
| [영지식증명] GKR protocol 파헤치기 (0) | 2024.03.13 |
| [영지식증명] Sum-Check Protocol 응용: 행렬곱 증명 (1) | 2024.02.15 |
| [영지식증명] Sum-Check Protocol (2) | 2024.02.14 |