Tensors and Strides for NTT

2022 Aug 20



\[ \def\ID{{\mathbf{1}}} \def\FF{{\mathbb{F}}} \def\cols{{\text{cols}}} \def\rows{{\text{rows}}} \def\diag{{\text{diag}}} \def\stride{{\text{stride}}} \def\PP{{\text{P}}} \]

Vectors and matrices

Let \(\FF\) be a field. Given a matrix \(A \in \FF^{m\times n}\), let \(A'\) denote its transpose. Identify (column) vectors \(v\in\FF^n\) with the corresponding \(\FF^{n\times 1}\) matrix, so \(v'\in\FF^{1\times n}\) is a row vector.

Define the column-stacking operator \(\cols\), which stacks the columns of a matrix into a single vector: \[\cols: \FF^{m\times n} \to \FF^{mn}\\ A = (a_{ij}) \mapsto (a_{*,1}', a_{*,2}', \dots, a_{*,n}')'\] Similarly, define \(\rows: \FF^{m\times n}\rightarrow\FF^{mn}\) as stacking the (transposed) rows of a matrix. By construction, \(\rows(A) = \cols(A')\) and \(\cols(A) = \rows(A')\).
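As a quick illustration of the two stacking operators, here is a minimal sketch in plain Python. Representing a matrix as a list of rows, and the helper name `transpose`, are choices made for this example only.

```python
def cols(A):
    """Stack the columns of A (given as a list of rows) into one flat list."""
    m, n = len(A), len(A[0])
    return [A[i][j] for j in range(n) for i in range(m)]

def rows(A):
    """Stack the (transposed) rows of A into one flat list."""
    return [x for row in A for x in row]

def transpose(A):
    """Return A' as a list of rows."""
    return [list(col) for col in zip(*A)]

A = [[1, 2, 3],
     [4, 5, 6]]

print(cols(A))  # [1, 4, 2, 5, 3, 6]
print(rows(A))  # [1, 2, 3, 4, 5, 6]
```

In array-programming terms, `rows` is the row-major flattening and `cols` the column-major flattening of the same matrix.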

Tensor products

The tensor product \(A\otimes B\in\mathbb{F}^{mp\times nq}\) of \(A\in\mathbb{F}^{m\times n}\) and \(B\in\mathbb{F}^{p\times q}\) is defined by replacing each entry \(a_{ij}\) of \(A\) by the block \(a_{ij}B\). Among its properties are: \[(A\otimes B)(C\otimes D) = (AC)\otimes (BD) \tag{1a}\] \[(A\otimes B)' = A'\otimes B' \tag{1b}\]
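The block definition and properties (1a) and (1b) can be checked on small integer matrices. The following is a sketch only; the helper names `kron`, `matmul` and `transpose` and the sample matrices are assumptions of this example.

```python
def transpose(A):
    return [list(col) for col in zip(*A)]

def matmul(A, B):
    """Plain matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*B)] for row in A]

def kron(A, B):
    """Tensor product: replace each entry a_ij of A by the block a_ij * B."""
    return [[x * y for x in row_a for y in row_b] for row_a in A for row_b in B]

A = [[1, 2], [3, 4]]          # 2 x 2
B = [[0, 1], [1, 0]]          # 2 x 2
C = [[1, 0, 2], [0, 1, 1]]    # 2 x 3
D = [[2, 1], [1, 3]]          # 2 x 2

print(kron(A, B))
# [[0, 1, 0, 2], [1, 0, 2, 0], [0, 3, 0, 4], [3, 0, 4, 0]]

# (1a): mixed-product property
assert matmul(kron(A, B), kron(C, D)) == kron(matmul(A, C), matmul(B, D))
# (1b): transposition distributes over the tensor product
assert transpose(kron(A, B)) == kron(transpose(A), transpose(B))
```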

Roth's Lemma. \[\text{cols}(ABC) = (C'\otimes A)\text{cols}(B) \tag{2}\] Proof: For \(x\in\FF^m\) and \(y\in\FF^n\) we have \(x\otimes y = (x_1 y_1,x_1 y_2,\dots,x_1 y_n,x_2 y_1,\dots, x_m y_n)' = \text{cols}(yx')\). Using this twice together with \(\text{(1a)}\), we get \[\text{cols}((Ae_i)(e'_jC)) = (C'e_j)\otimes(Ae_i) = (C'\otimes A)(e_j\otimes e_i) =(C'\otimes A)\text{cols}(e_i e'_j).\] Writing \(B=\sum b_{ij}e_ie'_j\), the lemma follows by linearity. QED.

Corollaries are:

- \(\cols(AB) = (\ID\otimes A)\cols(B)\quad\;\text{parallel}\)
- \(\cols(BC) = (C'\otimes \ID)\cols(B)\quad\,\text{vector}\)
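Roth's Lemma and its two corollaries are easy to sanity-check numerically. The sketch below uses plain Python integers; all helper names and the sample matrices are assumptions of this example, not part of the text above.

```python
def transpose(A):
    return [list(col) for col in zip(*A)]

def matmul(A, B):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*B)] for row in A]

def matvec(A, v):
    return [sum(x * y for x, y in zip(row, v)) for row in A]

def cols(A):
    """Stack the columns of A into one flat list."""
    return [A[i][j] for j in range(len(A[0])) for i in range(len(A))]

def kron(A, B):
    return [[x * y for x in row_a for y in row_b] for row_a in A for row_b in B]

def identity(n):
    return [[int(i == j) for j in range(n)] for i in range(n)]

A = [[1, 2, 0], [3, 1, 4]]          # 2 x 3
B = [[2, 1], [0, 5], [1, 1]]        # 3 x 2
C = [[1, 0, 2, 1], [3, 1, 0, 2]]    # 2 x 4

# Roth's Lemma (2): cols(ABC) = (C' kron A) cols(B)
assert cols(matmul(matmul(A, B), C)) == matvec(kron(transpose(C), A), cols(B))

# "parallel": an independent copy of A acts on each column of B
assert cols(matmul(A, B)) == matvec(kron(identity(2), A), cols(B))

# "vector": C' combines whole columns of B, entry by entry
assert cols(matmul(B, C)) == matvec(kron(transpose(C), identity(3)), cols(B))
```

The identity factor in the "parallel" case has one row per column of \(B\), and in the "vector" case one row per row of \(B\).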

Strides

We use Pythonic notation \(v(i::k) = v(i:(n - 1):k)\) for the vector of elements of \(v=(v_j)_{0\le j<n}\) with \(j\equiv i \pmod k\).

For \(n=ab\) define the \((a,b)\)-stride operator \(\PP_{a,b} = {}_a\stride_b: \FF^n \to \FF^n\) by the equation \[\cols = {}_a\stride_b\circ\rows: \FF^{a\times b}\to\FF^n.\] In words: \({}_a\stride_b\) transposes the underlying \(a\times b\) matrix, turning its row-major layout into its column-major layout (TODO: draw commutative diagram). In the literature, this is equivalently called the mod-\(a\) perfect shuffle and the mod-\(b\) sort operator. A frequent special case is \({}_a\stride_2\), the even-odd sort operator.

The first name comes from the idea of splitting the entries of a vector into \(a\) piles, and pulling from each pile in turn. The second name is based on how \({}_a\stride_b\) sorts all entries of a vector with the same index mod \(b\) to be consecutive.

Example: \[\PP_{3,5}(x) = (x_0,x_5,x_{10}|x_1,x_6,x_{11}|x_2,x_7,x_{12}|x_3,x_8,x_{13}|x_4,x_9,x_{14})\]
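The stride operator is a one-liner once the index arithmetic is spelled out: entry \(x_{ib+j}\) of \(\rows(A)\) moves to position \(ja+i\) of \(\cols(A)\). The sketch below (the function name `stride` is chosen for this example) applies it to a vector of indices, so the output shows directly where each entry comes from.

```python
def stride(a, b, x):
    """Apply P_{a,b}: send rows(A) to cols(A) for an a-by-b matrix A."""
    assert len(x) == a * b
    # Output position j*a + i receives input entry i*b + j.
    return [x[i * b + j] for j in range(b) for i in range(a)]

x = list(range(15))
print(stride(3, 5, x))
# [0, 5, 10, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14]

# Even-odd sort: the special case b = 2.
print(stride(4, 2, list(range(8))))
# [0, 2, 4, 6, 1, 3, 5, 7]
```

Equivalently, the result is the concatenation of the slices `x[j::b]` for `j = 0, ..., b-1`, which is the "sort mod \(b\)" reading of the operator.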

\(P_{a,b}\in\FF^{n\times n}\) can be constructed explicitly as a \((b\times a)\) grid of blocks, each of size \((a\times b)\), where block \((j,i)\) is zero except for a single entry \(1\) at position \((i,j)\).

Properties:

- \(P_{a,b}' = P_{a,b}^{-1} = P_{b,a}\)

For \(X\in\FF^{n\times n}\) we have:

- \(X\circ\PP_{a,b}\) = columns of \(X\) sorted mod \(a\) (3a)
- \(\PP_{a,b}\circ X\) = rows of \(X\) sorted mod \(b\) (3b)
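To make the matrix form concrete, the sketch below builds \(P_{a,b}\) from the block description and checks the inverse/transpose property together with (3a) and (3b) for \(a=2\), \(b=3\). The helper names and the test matrix `X` are assumptions of this example.

```python
def stride_matrix(a, b):
    """Permutation matrix of P_{a,b}: block (j, i) has a single 1 at position (i, j)."""
    n = a * b
    P = [[0] * n for _ in range(n)]
    for i in range(a):
        for j in range(b):
            P[j * a + i][i * b + j] = 1
    return P

def transpose(A):
    return [list(col) for col in zip(*A)]

def matmul(A, B):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*B)] for row in A]

a, b = 2, 3
n = a * b
P = stride_matrix(a, b)
I = [[int(r == c) for c in range(n)] for r in range(n)]

# P' = P^{-1} = P_{b,a}
assert transpose(P) == stride_matrix(b, a)
assert matmul(P, transpose(P)) == I

# X[r][c] = n*r + c, so every entry reveals its original row and column.
X = [[n * r + c for c in range(n)] for r in range(n)]

# (3a): X P lists the columns of X sorted mod a.
print(matmul(X, P)[0])                    # [0, 2, 4, 1, 3, 5]
# (3b): P X lists the rows of X sorted mod b.
print([row[0] // n for row in matmul(P, X)])  # [0, 3, 1, 4, 2, 5]
```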

Transform

Fix a primitive \(n\)-th root of unity \(r = r_n\) of \(\FF\), assumed to exist. For \(v\in\FF^n\) define the transform \[T_n(v) = \left(\sum_j r_n^{ij} v_j\right)_{0\le i<n}\in\FF^n\tag{4}\] For finite fields this is called the number-theoretic transform. If \(p(x)=\sum_ip_ix^i\in\FF[x]\) is a polynomial of degree less than \(n\), then \(T_n(p)\) is the evaluation of \(p\) at the \(n\)-th roots of unity of \(\FF\), and \(T_n\) is an isomorphism between the coefficient and value representations of such polynomials. This allows for fast multiplication of large polynomials. For \(\FF = \mathbb{C}\) it is called the (discrete) Fourier transform and has a long history.
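A direct, quadratic-time implementation of (4) makes the evaluation and multiplication claims tangible. The sketch below assumes the prime field \(\FF_{17}\) with \(n=4\) and \(r=4\) (a primitive 4th root of unity mod 17); the modulus, the function names and the sample polynomials are illustrative choices, and the inverse uses the standard formula \(T_n^{-1} = n^{-1}\,T_n\) taken with \(r^{-1}\) in place of \(r\).

```python
P = 17  # prime modulus; F_17 is an illustrative choice
R = 4   # primitive 4th root of unity mod 17: 4^2 = 16 = -1, 4^4 = 1

def ntt(v, r):
    """Naive transform (4): entry i is sum_j r^(i*j) * v_j over F_P."""
    n = len(v)
    return [sum(pow(r, i * j, P) * v[j] for j in range(n)) % P for i in range(n)]

def intt(w, r):
    """Inverse transform: (1/n) times the transform with r^(-1)."""
    n_inv = pow(len(w), P - 2, P)   # Fermat inverse, valid since P is prime
    r_inv = pow(r, P - 2, P)
    return [n_inv * x % P for x in ntt(w, r_inv)]

p = [1, 2, 0, 0]   # 1 + 2x, padded to length n = 4
q = [3, 1, 0, 0]   # 3 + x

# T_n(p) evaluates p at the powers of r: 1, 4, 16, 13.
print(ntt(p, R))   # [3, 9, 16, 10]

# Multiply pointwise in the value representation, then transform back.
values = [x * y % P for x, y in zip(ntt(p, R), ntt(q, R))]
print(intt(values, R))  # [3, 7, 2, 0], i.e. 3 + 7x + 2x^2
```

The product has degree 2, which is less than \(n=4\), so no wrap-around occurs; for larger products the inputs need more zero padding.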

There is a lot of periodicity in the entries \(r_n^{ij}\) of \(T_n\): for \(n=ab\), setting \(r_a = r_n^b\) and \(r_b = r_n^a\), we have \(r_n^{ax} = r_b^x\), \(r_n^{by} = r_a^y\) and \(r_n^n = 1\), leading to:

Splitting Lemma.

For any factorization \(n=ab\), we can factor \(T_n\) as a product of \((n\times n)\) matrices: \[T_n = (T_a\otimes\ID_b)\diag_{0\le \ell<a}(\Omega_{a,b}^\ell)(\ID_a\otimes T_b)P_{b,a},\] where \(\Omega_{a,b}=\diag(1,r,\dots,r^{b - 1})\in\FF^{b\times b}\), with \(r=r_n\).

Note that \(T_a\otimes\ID_b\) operates on strides with step \(b\) as a vector operation, whereas \(\ID_a\otimes T_b\) operates on consecutive data as a parallel operation.
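Read from right to left, the factorization is an algorithm on vectors. The sketch below applies the four factors in turn and compares against the naive transform; it assumes \(\FF_{17}\) with \(n=8\) and \(r=2\) (a primitive 8th root of unity mod 17), and the names `ntt` and `split_ntt` are chosen for this example.

```python
P = 17  # prime modulus (illustrative choice); 2 has order 8 mod 17

def ntt(v, r):
    """Naive transform: entry i is sum_j r^(i*j) * v_j over F_P."""
    n = len(v)
    return [sum(pow(r, i * j, P) * v[j] for j in range(n)) % P for i in range(n)]

def split_ntt(v, a, b, r):
    """Compute T_n v for n = a*b by applying the four factors right to left."""
    n = a * b
    assert len(v) == n
    r_a, r_b = pow(r, b, P), pow(r, a, P)
    # P_{b,a}: block l (length b) collects the entries v(l::a).
    blocks = [v[l::a] for l in range(a)]
    # I_a (x) T_b: a independent size-b transforms on consecutive data.
    blocks = [ntt(blk, r_b) for blk in blocks]
    # diag(Omega^l): scale entry i of block l by r^(i*l).
    blocks = [[pow(r, i * l, P) * blk[i] % P for i in range(b)]
              for l, blk in enumerate(blocks)]
    # T_a (x) I_b: for each offset i, one size-a transform across the blocks.
    out = [0] * n
    for i in range(b):
        column = ntt([blocks[l][i] for l in range(a)], r_a)
        for k in range(a):
            out[k * b + i] = column[k]
    return out

v = [3, 1, 4, 1, 5, 9, 2, 6]
assert split_ntt(v, 2, 4, 2) == ntt(v, 2)
assert split_ntt(v, 4, 2, 2) == ntt(v, 2)
```

The last loop touches entries `i, i + b, i + 2b, ...`, which is the stride-\(b\) access pattern, while the second step works on contiguous blocks.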

Proof: We will show that the LHS and RHS are composed of equal \(b\times b\) block matrices.

LHS: Write \(T_nP_{a,b} = (T_n(:,0::a) | T_n(:, 1::a)|\cdots|T_n(:,a-1::a)) =: (G_{k,\ell})_{0\le k,\ell < a}\) with \(G_{k,\ell}\in\FF^{b\times b}\), using (3a). Hence \(G_{k,\ell} = T_n(kb:(k+1)b - 1, \ell::a)\) has \((i,j)\) entry \[(G_{k,\ell})_{i,j} = r_n^{(kb+i)(\ell+ja)} = r_a^{k\ell}\cdot r_b^{ij}\cdot r_n^{i\ell},\] where the remaining factor \(r_n^{kbja} = r_n^{nkj}\) equals \(1\).

RHS: Since \(P_{b,a} = P_{a,b}^{-1}\), it suffices to compare \(T_nP_{a,b}\) with the remaining factors. Write \((T_a\otimes\ID_b)\diag_{0\le \ell < a}(\Omega_{a,b}^\ell\cdot T_b) =: (H_{k,\ell})_{0\le k,\ell < a}\). Now \(H_{k,\ell} = r_a^{k\ell}\cdot\Omega_{a,b}^\ell T_b\) has \((i,j)\) entry \[(H_{k,\ell})_{i,j} = r_a^{k\ell}\cdot(\Omega_{a,b}^\ell)_{i,i}(T_b)_{i,j} = r_a^{k\ell}\cdot r_n^{i\ell}\cdot r_b^{ij}.\] Hence \(G_{k,\ell} = H_{k,\ell}\) for all \(k,\ell\). QED.

The operator \(B_{a,b} = (T_a\otimes\ID_b)D_{a,b}\) is called the radix-\(a\) butterfly, where \(D_{a,b}=\diag(\ID_b,\Omega_{a,b},\dots,\Omega_{a,b}^{a - 1})\) are the twiddles.
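For \(a=2\) the butterfly takes a familiar shape: \(T_2\) has rows \((1,1)\) and \((1,-1)\), so \(B_{2,b}\) maps a pair of half-length vectors \((u,w)\) to \((u+\Omega w,\; u-\Omega w)\). Applying the splitting lemma recursively with \(a=2\) then gives the usual radix-2 recursion. The sketch below assumes \(\FF_{17}\), a power-of-two length, and \(r=2\) as a primitive 8th root of unity; the function names are chosen for this example.

```python
P = 17  # odd prime modulus (illustrative choice); 2 has order 8 mod 17

def ntt(v, r):
    """Naive transform, used only as a reference."""
    n = len(v)
    return [sum(pow(r, i * j, P) * v[j] for j in range(n)) % P for i in range(n)]

def radix2_ntt(v, r):
    """Recursive transform for len(v) a power of two, r a primitive len(v)-th root."""
    n = len(v)
    if n == 1:
        return v[:]
    b = n // 2
    r_b = r * r % P
    # Even-odd sort followed by two half-size transforms: (I_2 (x) T_b) P_{b,2}.
    u = radix2_ntt(v[0::2], r_b)
    w = radix2_ntt(v[1::2], r_b)
    # Radix-2 butterfly B_{2,b}: twiddle the odd half, then add and subtract.
    out = [0] * n
    twiddle = 1  # runs through 1, r, ..., r^(b-1), the diagonal of Omega_{2,b}
    for i in range(b):
        t = twiddle * w[i] % P
        out[i] = (u[i] + t) % P
        out[i + b] = (u[i] - t) % P
        twiddle = twiddle * r % P
    return out

v = [3, 1, 4, 1, 5, 9, 2, 6]
assert radix2_ntt(v, 2) == ntt(v, 2)
```

Each level of the recursion costs \(n/2\) twiddle multiplications and \(n\) additions, which is where the \(O(n\log n)\) operation count of the fast transform comes from.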