A Neat Not-Randomized Algorithm: Polar Express

Every once in a while, a paper comes out that is so delightful that I can’t help but share it on this blog, and I’ve started a little series, Neat Randomized Algorithms, for exactly this purpose. Today’s entry into this collection is The Polar Express: Optimal Matrix Sign Methods and their Application to the Muon Algorithm by Noah Amsel, David Persson, Christopher Musco, and Robert M. Gower. Despite its authors hailing from the world of randomized linear algebra, this paper is actually about a plain-old deterministic algorithm. But it’s just so delightful that I couldn’t help but share it in this series anyway.

The authors of The Polar Express are motivated by the recent Muon algorithm for neural network optimization. The basic idea of Muon is that it helps to orthogonalize the search directions in a stochastic gradient method. That is, rather than update a weight matrix W with search direction G using the update rule

    \[W \gets W - \eta G,\]

instead use the update

    \[W\gets W - \eta \operatorname{polar}(G).\]

Here,

    \[\operatorname{polar}(G) \coloneqq \operatorname*{argmin}_{Q \textrm{ with orthonormal columns}} \norm{G - Q}_{\rm F}\]

is the closest matrix with orthonormal columns to G and is called the (unitary) polar factor of G. (Throughout this post, we shall assume for simplicity that G is tall and full-rank.) Muon relies on efficient algorithms for rapidly approximating \operatorname{polar}(G).

Given a singular value decomposition G = U\Sigma V^\top, the polar factor may be computed in closed form as \operatorname{polar}(G) = UV^\top. But computing the SVD is computationally expensive, particularly in GPU computing environments. Are there more efficient algorithms that avoid the SVD? In particular, can we design algorithms that use only matrix multiplications, for maximum GPU efficiency?
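
For concreteness, here is the direct SVD-based computation in MATLAB (a snippet of my own, with an arbitrary tall test matrix); the rest of this post is about avoiding exactly this SVD call:

% Reference computation of the polar factor via the economy-size SVD
G = randn(1000, 200);        % a tall, full-rank test matrix
[U, ~, V] = svd(G, 'econ');  % thin SVD: U is 1000-by-200, V is 200-by-200
polarG = U*V';               % polar(G) = U*V', closest matrix with orthonormal columns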

The Polar Factor as a Singular Value Transformation

Computing the polar factor \operatorname{polar}(G) of a matrix G effectively applies an operation to G which replaces all of its singular values by one. Such operations are studied in quantum computing, where they are called singular value transformations.

Definition (singular value transformation): Given an odd function f, the singular value transformation of G = U\Sigma V^\top by f is f[G] \coloneqq Uf(\Sigma)V^\top.

On its face, it might seem that the polar factor of G cannot be obtained as a singular value transformation. After all, the constant function f(x) = 1 is not odd. But, to obtain the polar factor, we only need a function f that sends positive inputs to 1. Thus, the polar factor \operatorname{polar}(G) is given by the singular value transformation associated with the sign function:

    \[\operatorname{sign}(x) = \begin{cases} 1, & x > 0, \\ 0, & x = 0, \\ -1, & x < 0. \end{cases}\]

The sign function is manifestly odd, and the polar factor satisfies

    \[\operatorname{polar}(G) = \operatorname{sign}[G].\]

Singular Value Transformations and Polynomials

How might we go about computing the singular value transformation of a matrix? For an (odd) polynomial, this computation can be accomplished using a sequence of matrix multiplications alone. Indeed, for p(x) = a_1 x + a_3 x^3 + \cdots + a_{2k+1} x^{2k+1}, we have

    \[p[G] = a_1 G + a_3 G(G^\top G) + \cdots + a_{2k+1} G(G^\top G)^k.\]
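
As a concrete illustration, here is a small MATLAB helper (a sketch of my own, not code from the paper) that evaluates p[G] using only matrix products, via Horner’s rule in G^\top G:

function PG = odd_poly_svt(G, a)
% ODD_POLY_SVT   Evaluate the singular value transformation p[G] for the odd
% polynomial p(x) = a(1)*x + a(2)*x^3 + ... + a(end)*x^(2*length(a)-1),
% using only matrix-matrix products.
GtG = G'*G;
n = size(GtG, 1);
Q = a(end)*eye(n);
for j = length(a)-1:-1:1
    Q = a(j)*eye(n) + GtG*Q;      % Horner step: Q = a(j)*I + (G'*G)*Q
end
PG = G*Q;                         % p[G] = a(1)*G + a(2)*G*(G'*G) + ...
end

For instance, odd_poly_svt(G, [1, -1/6, 1/120]) evaluates the degree-five Taylor approximation to \sin[G] that appears in the next example.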

For a general (odd) function f, we can approximately compute the singular value transformation f[G] by first approximating f by a polynomial p, and then using p[G] as a proxy for f[G]. Here is an example:

>> G = randn(2)                            % Random test matrix
G =
   0.979389080992349  -0.198317114406418
  -0.252310961830649  -1.242378171072736
>> [U,S,V] = svd(G);
>> fG = U*sin(S)*V'                        % Singular value transformation
fG =
   0.824317193982434  -0.167053523352195
  -0.189850719961322  -0.935356030417109
>> pG = G - (G*G'*G)/6 + (G*G'*G*G'*G)/120 % Polynomial approximation
pG =
   0.824508188218982  -0.167091255945116
  -0.190054681059327  -0.936356677704568

We see that we get reasonably high accuracy by approximating \sin[G] using its degree-five Taylor polynomial.

The Power of Composition

The most basic approach would be to approximate the sign function by a single polynomial of degree 2k+1 and apply it as above. However, the accuracy of this approach improves only slowly as we increase the degree k.

A better strategy is to use compositions. A nice feature of the sign function is the fixed point property: For every x, \operatorname{sign}(x) is a fixed point of the \operatorname{sign} function:

    \[\operatorname{sign}(\operatorname{sign}(x)) = \operatorname{sign}(x) \quad \text{for all } x \in \real.\]

This fixed point property suggests an alternative strategy for computing the sign function using polynomials. Rather than using one polynomial of large degree, we can instead compose many polynomials of low degree. The simplest such compositional algorithm is the Newton–Schulz iteration, which consists of initializing P\gets G and applying the following fixed point iteration until convergence:

    \[P \gets \frac{3}{2} P - \frac{1}{2} PP^\top P.\]

Here is an example execution of the algorithm. (Newton–Schulz converges provided the singular values of the input lie in the interval (0,\sqrt{3}); this is why the test matrix below is scaled to have small norm.)

>> P = randn(100) / 25;
>> [U,~,V] = svd(P); polar = U*V'; % True polar decomposition
>> for i = 1:20
      P = 1.5*P-0.5*P*P'*P; % Newton-Schulz iteration
      fprintf("Iteration %d\terror %e\n",i,norm(P - polar));
   end
Iteration 1	error 9.961421e-01
Iteration 2	error 9.942132e-01
Iteration 3	error 9.913198e-01
Iteration 4	error 9.869801e-01
Iteration 5	error 9.804712e-01
Iteration 6	error 9.707106e-01
Iteration 7	error 9.560784e-01
Iteration 8	error 9.341600e-01
Iteration 9	error 9.013827e-01
Iteration 10	error 8.525536e-01
Iteration 11	error 7.804331e-01
Iteration 12	error 6.759423e-01
Iteration 13	error 5.309287e-01
Iteration 14	error 3.479974e-01
Iteration 15	error 1.605817e-01
Iteration 16	error 3.660929e-02
Iteration 17	error 1.985827e-03
Iteration 18	error 5.911348e-06
Iteration 19	error 5.241446e-11
Iteration 20	error 6.686995e-15

As we see, the initial rate of convergence is very slow, and we obtain only a single digit of accuracy even after 15 iterations. After this burn-in period, the rate of convergence becomes very rapid, and the method achieves machine accuracy after 20 iterations.
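
To see where the burn-in comes from, note that Newton–Schulz acts on each singular value separately through the scalar map p(x) = \tfrac{3}{2}x - \tfrac{1}{2}x^3. While a singular value x is much smaller than one, each step multiplies it by roughly 3/2, so many iterations are needed just to reach the neighborhood of 1 where the fast convergence kicks in. A scalar version of the experiment (an aside of my own, not from the paper) makes this plain:

x = 0.04;                   % a small starting singular value (illustrative choice)
for i = 1:20
    x = 1.5*x - 0.5*x^3;    % scalar Newton-Schulz map
    fprintf("Iteration %d\tx = %f\n", i, x);
end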

The Polar Express

The Newton–Schulz iteration approximates the sign function by composing the same polynomial p with itself over and over. But we can get better approximations by applying a sequence of different polynomials p_1,\ldots,p_t, resulting in an approximation of the form

    \[\operatorname{sign}[G] \approx p_t[p_{t-1}[\cdots p_2[p_1[G]]\cdots]].\]

The Polar Express paper asks the question:

What is the optimal choice of polynomials p_i?

For simplicity, the authors of The Polar Express focus on the case where all of the polynomials p_i have the same (odd) degree 2k+1.

On its face, this problem might seem intractable, as the best choice of polynomial p_{i+1} could seemingly depend in a complicated way on all of the previous polynomials p_1,\ldots,p_i. Fortunately, the authors of The Polar Express show that there is a very simple way of computing the optimal polynomials. Begin by assuming that the singular values of G lie in an interval [\ell_0,u_0]. We then choose p_1 to be the degree-(2k+1) odd polynomial approximation to the sign function on [\ell_0,u_0] that minimizes the L_\infty error:

    \[p_1 = \operatorname*{argmin}_{\text{odd degree-($2k+1$) polynomial } p} \max_{x \in [\ell_0,u_0]} |p(x) - \operatorname{sign}(x)|.\]

This optimal polynomial can be computed by a version of the Remez algorithm provided in the Polar Express paper. After applying p_1 to G, the singular values of p_1[G] lie in a new interval [\ell_1,u_1]. To build the next polynomial p_2, we simply find the optimal approximation to the sign function on this interval:

    \[p_2 = \operatorname*{argmin}_{\text{odd degree-($2k+1$) polynomial } p} \max_{x \in [\ell_1,u_1]} |p(x) - \operatorname{sign}(x)|.\]

Continuing in this way, we can generate as many polynomials p_1,p_2,\ldots as we want.
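
Schematically, the construction can be organized as the loop below. This is my own sketch of the procedure just described, and minimax_odd_sign is a hypothetical helper standing in for the Remez-type routine from the paper:

% Greedy construction of the optimal polynomials (sketch)
l = 1e-3; u = 1; t = 8;        % assumed singular value bounds and number of rounds
C = zeros(t, 3);               % row i holds the coefficients of p_i
for i = 1:t
    C(i, :) = minimax_odd_sign(l, u);            % [a1 a3 a5] of the minimax odd degree-5 polynomial on [l, u]
    x = linspace(l, u, 10000);                   % image of [l, u] under p_i
    px = C(i,1)*x + C(i,2)*x.^3 + C(i,3)*x.^5;
    l = min(px); u = max(px);                    % interval of singular values for the next round
end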

For given values of \ell_0 and u_0, the coefficients of the optimal polynomials p_1,p_2,\ldots can be computed in advance and stored, allowing for rapid deployment at runtime. Moreover, we can always ensure the upper bound is u_0 = 1 by normalizing G\gets G / \norm{G}_{\rm F}. As such, there is only one parameter \ell_0 that we need to know in order to compute the optimal coefficients. The authors of The Polar Express are motivated by applications in deep learning using 16-bit floating point numbers, for which the lower bound \ell_0 = 0.001 is appropriate. (As the authors stress, their method remains convergent even if too large a value of \ell_0 is chosen, though convergence may be slowed somewhat.)
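
In code, the runtime stage might look like the loop below (a sketch of my own, assuming the precomputed degree-5 coefficients are stored in a t-by-3 array C with row i holding the coefficients of p_i); the demo that follows simply unrolls this loop by hand using the paper’s coefficients:

P = G / norm(G, 'fro');                  % normalization ensures u_0 = 1
n = size(P, 2);
for i = 1:size(C, 1)
    PtP = P'*P;                          % only matrix multiplications
    P = P*(C(i,1)*eye(n) + C(i,2)*PtP + C(i,3)*(PtP*PtP));  % P <- p_i[P]
end
% P now approximates polar(G)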

Below, I repeat the experiment from above using (degree-5) Polar Express instead of Newton–Schulz. The coefficients for the optimal polynomials are taken from the Polar Express paper.

>> P = randn(100) / 25;
>> [U,~,V] = svd(P); polar = U*V';
>> P2 = P*P'; P = ((17.300387312530933*P2-23.595886519098837*eye(100))*P2+8.28721201814563*eye(100))*P; fprintf("Iteration 1\terror %e\n",norm(P - polar));
Iteration 1	error 9.921347e-01
>> P2 = P*P'; P = ((0.5448431082926601*P2-2.9478499167379106*eye(100))*P2+4.107059111542203*eye(100))*P; fprintf("Iteration 2\terror %e\n",norm(P - polar));
Iteration 2	error 9.676980e-01
>> P2 = P*P'; P = ((0.5518191394370137*P2-2.908902115962949*eye(100))*P2+3.9486908534822946*eye(100))*P; fprintf("Iteration 3\terror %e\n",norm(P - polar));
Iteration 3	error 8.725474e-01
>> P2 = P*P'; P = ((0.51004894012372*P2-2.488488024314874*eye(100))*P2+3.3184196573706015*eye(100))*P; fprintf("Iteration 4\terror %e\n",norm(P - polar));
Iteration 4	error 5.821937e-01
>> P2 = P*P'; P = ((0.4188073119525673*P2-1.6689039845747493*eye(100))*P2+2.300652019954817*eye(100))*P; fprintf("Iteration 5\terror %e\n",norm(P - polar));
Iteration 5	error 1.551595e-01
>> P2 = P*P'; P = ((0.37680408948524835*P2-1.2679958271945868*eye(100))*P2+1.891301407787398*eye(100))*P; fprintf("Iteration 6\terror %e\n",norm(P - polar));
Iteration 6	error 4.588549e-03
>> P2 = P*P'; P = ((0.3750001645474248*P2-1.2500016453999487*eye(100))*P2+1.8750014808534479*eye(100))*P; fprintf("Iteration 7\terror %e\n",norm(P - polar));
Iteration 7	error 2.286853e-07
>> P2 = P*P'; P = ((0.375*P2-1.25*eye(100))*P2+1.875*eye(100))*P; fprintf("Iteration 8\terror %e\n",norm(P - polar));
Iteration 8	error 1.113148e-14

We see that the Polar Express algorithm converges to machine accuracy in only 8 iterations (24 matrix products), a speedup over the 20 iterations (40 matrix products) required by Newton–Schulz. The Polar Express paper contains further examples with even more significant speedups.

Make sure to check out the Polar Express paper for many details not shared here, including extra tricks to improve stability in 16-bit floating point arithmetic, discussions about how to compute the optimal polynomials, and demonstrations of the Polar Express algorithm for training GPT-2.

References: Muon was first formally described in the blog post Muon: An optimizer for hidden layers in neural networks (2024); for more, see this blog post by Jeremy Bernstein and this paper by Jeremy Bernstein and Laker Newhouse. The Polar Express is proposed in The Polar Express: Optimal Matrix Sign Methods and their Application to the Muon Algorithm (2025) by Noah Amsel, David Persson, Christopher Musco, and Robert M. Gower. For more on the matrix sign function and computing it, chapter 5 of Functions of Matrices: Theory and Computation (2008) by Nicholas H. Higham is an enduringly useful reference.
