영지식증명

[영지식증명] ZK-CNN 파헤치기 - 2

녹차뽀드득 2024. 5. 2. 10:41

Introduction

안녕하세요, 저번 시간부터 저희는 prover와 verifier가 CNN model weight 을 공유하지 않고도 실제로 input에 대하여 올바른 연산을 진행하였음을 증명할 수 있는 ZK-CNN protocol에 대하여 알아보기 시작했습니다. 

좀 더 구체적으로, 지난 시간 우리는 Fast Fourier Transform (FFT) 에 대한 새로운 Sumcheck Protocol에 대하여 알아보았습니다. 지난 시간 일련의 과정은 CNN 과는 어느정도 거리가 있는 내용으로 느껴지셨을 수 있는데요, 이번 시간에는 지난번에 알아본 FFT 에 대한 Sumcheck protocol 을 가지고, 어떻게 2D convolution을 증명할 수 있는지 알아보도록 하겠습니다.

Proving 2D convolution using FFT

- Inverse FFT.

우리가 일전에 살펴본 FFT의 역과정인 Inverse FFT 역시, 다른 root of unity 를 가진 FFT로 볼 수 있습니다. 이를 식으로 작성해보면: 

$a_j = \sum^{N-1}_{i=0}c_i \omega^{ji} \leftrightarrow \frac{1}{M}\sum^{M-1}_{j=0}a_j \omega^{-ji}$

처럼 나타낼 수 있습니다. 따라서 이 경우 (IFFT) 에도 동일하게 이전 시간에 소개된 FFT 를 위한 sumcheck protocol을 사용할 수 있습니다. 

 

- 2D convolution to 1D convolution. 

CNN 에 널리 사용되는 2D convolution의 경우 다음과 같은 식으로 쓰여질 수 있습니다. 

$U_{j, k} = \sum^{w-1, w-1}_{t=0, l=0}X_{j+t, k+l} \cdot W_{t, l}$

(이번 글에서 2D convolution의 역할, 혹은 수식에 대하여 깊게 다루지는 않도록 하겠습니다) 다음으로, $\tilde{X}, \tilde{W} \in \mathbf{F}^{n^2}$를 아래와 같이 정의하도록 하겠습니다. 

$\tilde{X}_{tn+1} = X_{n-1-t, n-1-l}, 0 \leq t < n, 0 \leq l < n$

$\tilde{W}_{tn+1} = W_{t, l} \text{ if } 0 \leq t, l < w,\text{ 0 otherwise }$

$\tilde{U}_j = \sum^j_{i=0}\tilde{X}_{j-i}\tilde{W}_i$

해당 식을 활용하면, (자세한 유도를 적지는 않겠습니다) 다음과 같이 $U_{j, k}$ 를 정리할 수 있습니다. 

$U_{j, k} = \tilde{U}_{n^2 - 1 - j \cdot n - k}$

식을 조금만 살펴보면, 사실 이것은 1D convolution의 형태임을 알 수 있습니다. 적절한 indexing을 통해 2D convolution을 1D convolution으로 변환한 것입니다. 

- Proving 1D convolution using FFT

1D convolution 연산의 경우 두 univariate 한 polynomial 간의 곱셈과 매우 유사한 형태를 띄고 있습니다. Convolution이 dot product 의 연속으로 이루어진 만큼, 이것을 적절히 linearize 한다면, polynomial의 곱셈의 형태로 생각할 수 있습니다. 만약 $\tilde{X}, \tilde{W}$의 coefficient를 $\tilde{X}(\eta), \tilde{W}(\eta)$ 라고한다면, $\tilde{U}(\eta) = \tilde{X}(\eta)\tilde{W}(\eta) \leftrightarrow \tilde{U}_j = \sum^j_{i=0}\tilde{X}_{j-i}\tilde{W}_i$ 라고도 작성할 수 있습니다. ($\tilde{U}$ 는 $\tilde{U}(\eta)$ 의 첫 $n^2$ coefficient 입니다.)

 

이제 여기까지 봤다면, 마지막으로 polynomial 간의 곱셈은 FFT와 IFFT를 사용해 세 가지 스텝으로 증명할 수 있습니다. 우선 첫번째로 $\tilde{X}(\eta), \tilde{W}(\eta)$ coefficient들을 $FFT(\tilde{X}), FFT(\tilde{W})$를 통해 root of unity 에서의 연산결과로 변환합니다. 그 후, 이들간의 hadamard product (element-wise dot product)를 계산한 후, 이 결과값을 다시 IFFT연산을 통해 coefficient 로 원복시켜주게 됩니다. 즉, 

$\tilde{U} = \tilde{X} \times \tilde{W} = IFFT(FFT(\tilde{X}) \cdot FFT(\tilde{W}))$ 

임을 보임으로, 1D convolution을 앞서 제시한 sumcheck protocol 연산들로 나타낼 수 있고, 해당 1D convolution은 우리가 증명하고자 하는 2D convolution과 동치이므로, 2D convolution 연산을 증명하는 셈이 됩니다. 

 

Conclusion

이번 시간에는, 지난번 제안한 FFT에 대한 새로운 sumcheck protocol을 사용해 어떻게 2D convolution operation을 증명할 수 있는지에 대하여 알아보았습니다. 지금까지 2D convolution을 증명할 수 있는 방법을 알아봤지만, 전체 CNN 구조는 이보다 더 복잡합니다. 다양한 덧셈과 곱셈, 그리고 ReLU activation등 까다로운 연산이 들어가 있습니다. 이에 따라 해당 논문에서는 여러가지 다양한 변형된 protocol적 디테일을 제시하는데요, 다음 시간에는 이 내용들에 대하여 살펴보는 시간을 갖도록 하겠습니다. 읽어주셔서 감사합니다.