How Good Can Stochastic Trace Estimates Be?

I am excited to share that our paper XTrace: Making the most of every sample in stochastic trace estimation has been published in the SIAM Journal on Matrix Analysis and Applications. (See also our paper on arXiv.)

Spurred by this exciting news, I wanted to take the opportunity to share one of my favorite results in randomized numerical linear algebra: a “speed limit” result of Meyer, Musco, Musco, and Woodruff that establishes a fundamental limitation on how accurate any trace estimation algorithm can be.

Let’s back up. Given an unknown square matrix A, we wish to compute the trace of A, defined to be the sum of its diagonal entries

    \[\tr(A) \coloneqq \sum_{i=1}^n A_{ii}.\]

The catch? We assume that we can only access the matrix A through matrix–vector products (affectionately known as “matvecs”): Given any vector x, we have access to Ax. Our goal is to form an estimate \hat{\tr} that is as accurate as possible while using as few matvecs as we can get away with.

To simplify things, let’s assume the matrix A is symmetric and positive (semi)definite. The classical algorithm for trace estimation is due to Girard and Hutchinson, producing a probabilistic estimate \hat{\tr} with a small average (relative) error:

    \[\expect\left[\frac{|\hat{\tr}-\tr(A)|}{\tr(A)}\right] \le \varepsilon \quad \text{using } m= \frac{\rm const}{\varepsilon^2} \text{ matvecs}.\]

If one wants high accuracy, this algorithm is expensive. To achieve just a 1% error (\varepsilon=0.01) requires roughly m=10,\!000 matvecs!
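To make the cost concrete, here is a minimal sketch of the Girard–Hutchinson estimator in its usual form: average \omega^\top A \omega over m independent random sign vectors \omega. The function name `girard_hutchinson` and the `matvec` callback are my own choices for illustration; the only access to A is through `matvec`, matching the setup above.

```python
import numpy as np

def girard_hutchinson(matvec, n, m, rng=None):
    """Estimate tr(A) using m matvecs with random +/-1 vectors.

    matvec: function mapping a length-n vector x to A @ x (hypothetical interface).
    """
    rng = np.random.default_rng(rng)
    total = 0.0
    for _ in range(m):
        omega = rng.choice([-1.0, 1.0], size=n)  # random sign vector
        total += omega @ matvec(omega)           # one matvec per sample
    return total / m

# Small demonstration on a random psd matrix.
demo_rng = np.random.default_rng(0)
B = demo_rng.standard_normal((200, 200))
A = B @ B.T
estimate = girard_hutchinson(lambda x: A @ x, n=200, m=1000, rng=1)
relative_error = abs(estimate - np.trace(A)) / np.trace(A)
print(f"relative error with 1000 matvecs: {relative_error:.4f}")
```

Each sample \omega^\top A \omega has expectation \tr(A), so the average is unbiased; the error shrinks only like 1/\sqrt{m}, which is where the 1/\varepsilon^2 matvec count comes from.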

This state of affairs was greatly improved by Meyer, Musco, Musco, and Woodruff. Building upon previous work, they proposed the Hutch++ algorithm and proved it outputs an estimate \hat{\tr} satisfying the following bound:

(1)   \[\expect\left[\frac{|\hat{\tr}-\tr(A)|}{\tr(A)}\right] \le \varepsilon \quad \text{using } m= \frac{\rm const}{\varepsilon} \text{ matvecs}.\]

Now, we only require roughly m=100 matvecs to achieve 1% error! Our algorithm, XTrace, satisfies the same error guarantee (1) as Hutch++. On certain problems, XTrace can be quite a bit more accurate than Hutch++.
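For readers who want to see where the savings come from, here is a rough sketch of Hutch++ as it is usually described: spend a third of the matvecs finding an approximate range Q of A, a third computing the trace of A restricted to that range exactly, and a third running Girard–Hutchinson on what is left over. The name `hutchpp` and the block interface `matmat` (which applies A to every column of a matrix, costing one matvec per column) are assumptions of this sketch, not the reference implementation.

```python
import numpy as np

def hutchpp(matmat, n, m, rng=None):
    """Sketch of the Hutch++ trace estimator using about m matvecs.

    matmat: function mapping an (n, k) matrix M to A @ M (hypothetical interface).
    """
    rng = np.random.default_rng(rng)
    k = m // 3
    S = rng.choice([-1.0, 1.0], size=(n, k))
    G = rng.choice([-1.0, 1.0], size=(n, k))
    Q, _ = np.linalg.qr(matmat(S))      # k matvecs: orthonormal basis for range(A S)
    AQ = matmat(Q)                      # k matvecs: A restricted to that basis
    G_perp = G - Q @ (Q.T @ G)          # project test vectors away from range(Q)
    AG = matmat(G_perp)                 # k matvecs: residual part of A
    low_rank_part = np.trace(Q.T @ AQ)
    residual_part = np.trace(G_perp.T @ AG) / k
    return low_rank_part + residual_part
```

When A has rank at most m/3, the residual term vanishes (up to rounding) and the estimate is essentially exact, which gives some intuition for why the method does well on matrices with rapidly decaying eigenvalues.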

The MMMW Trace Estimation “Speed Limit”

Given the dramatic improvement of Hutch++ and XTrace over Girard–Hutchinson, it is natural to hope: Is there an algorithm that does even better than Hutch++ and XTrace? For instance, is there an algorithm satisfying an even slightly better error bound of the form

    \[\expect\left[\frac{|\hat{\tr}-\tr(A)|}{\tr(A)}\right] \le \varepsilon \quad \text{using } m= \frac{\rm const}{\varepsilon^{0.999}} \text{ matvecs}?\]

Unfortunately not. Hutch++ and XTrace are essentially as good as it gets.

Let’s add some fine print. Consider an algorithm for the trace estimation problem. Whenever the algorithm wants, it can present a vector x_i and receive back Ax_i. The algorithm is allowed to be adaptive: It can use the matvecs Ax_1,\ldots,Ax_s it has already collected to decide which vector x_{s+1} to present next. We measure the cost of the algorithm in terms of the number of matvecs alone, and the algorithm knows nothing about the psd matrix A other than what it learns from matvecs.

One final stipulation:

Simple entries assumption. We assume that the entries of the vectors x_i presented by the algorithm are real numbers between -1 and 1 with up to b digits after the decimal place.

To get a feel for this simple entries assumption, suppose we set b=2. Then (-0.92,0.17) would be an allowed input vector, but (0.232,-0.125) would not be (too many digits after the decimal place). Similarly, (18.3,2.4) would not be valid because its entries exceed 1. The simple entries assumption is reasonable as we typically represent numbers on digital computers by storing a fixed number of digits of accuracy. (Footnote: We typically represent numbers on digital computers by floating point numbers, which essentially represent numbers using scientific notation like 1.3278123 \times 10^{27}. For this analysis of trace estimation, we use fixed point numbers like 0.23218, with no powers of ten allowed!)
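The three examples above can be checked mechanically. The helper below is a small sketch of my own (the name `is_simple_vector` is hypothetical); it takes entries as decimal strings so that binary floating point rounding does not blur the digit count.

```python
from decimal import Decimal

def is_simple_vector(entries, b):
    """Check the simple entries assumption for a vector given as decimal strings.

    Every entry must lie in [-1, 1] and have at most b digits after the decimal point.
    """
    scale = Decimal(10) ** b
    for text in entries:
        value = Decimal(text)
        if abs(value) > 1:
            return False                      # entry too large in magnitude
        scaled = value * scale
        if scaled != scaled.to_integral_value():
            return False                      # more than b digits after the decimal point
    return True

print(is_simple_vector(["-0.92", "0.17"], b=2))    # allowed
print(is_simple_vector(["0.232", "-0.125"], b=2))  # too many digits
print(is_simple_vector(["18.3", "2.4"], b=2))      # entries exceed 1
```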

With all these stipulations, we are ready to state the “speed limit” for trace estimation proved by Meyer, Musco, Musco, and Woodruff:

Informal theorem (Meyer, Musco, Musco, Woodruff). Under the assumptions above, there is no trace estimation algorithm producing an estimate \hat{\tr} satisfying

    \[\expect\left[\frac{|\hat{\tr}-\tr(A)|}{\tr(A)}\right] \le \varepsilon \quad \text{using } m= \frac{\rm const}{\varepsilon^{0.999}} \text{ matvecs}.\]

We will see a slightly sharper version of the theorem below, but this statement captures the essence of the result.

Communication Complexity

To prove the MMMW theorem, we have to take a journey to the beautiful subject of communication complexity. The story is this. Alice and Bob are interested in solving a computational problem together. Alice has her input x and Bob has his input y, and they are interested in computing a function f(x,y) of both their inputs.

Unfortunately for the two of them, Alice and Bob are separated by a great distance, and can only communicate by sending single bits (0 or 1) of information over a slow network connection. Every bit of communication is costly. The field of communication complexity is dedicated to determining how efficiently Alice and Bob are able to solve problems of this form.

The Gap-Hamming problem is one example of a problem studied in communication complexity. As inputs, Alice and Bob receive vectors x,y \in \{\pm 1\}^n with +1 and -1 entries from a third party Eve. Eve promises Alice and Bob that their vectors x and y satisfy one of two conditions:

(2)   \[\text{Case 0: } x^\top y \ge\sqrt{n} \quad \text{or} \quad \text{Case 1: } x^\top y \le -\sqrt{n}. \]

Alice and Bob must work together, sending as few bits of communication as possible, to determine which case they are in.

There’s one simple solution to this problem: First, Bob sends his whole input vector y to Alice. Each entry of y takes one of the two values \pm 1 and can therefore be communicated in a single bit. Having received y, Alice computes x^\top y, determines whether they are in case 0 or case 1, and sends Bob a single bit to communicate the answer. This procedure requires n+1 bits of communication.
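As a sanity check on the bit count, here is a toy simulation of this simple protocol. The function name `trivial_gap_hamming` and the explicit bit tally are illustrative choices of mine; x and y are assumed to be integer NumPy arrays with \pm 1 entries satisfying the promise (2).

```python
import numpy as np

def trivial_gap_hamming(x, y):
    """Simulate the simple protocol; return (case, bits_communicated)."""
    bits = 0
    # Bob -> Alice: encode each +/-1 entry of y as a single bit.
    message = (y > 0).astype(np.uint8)
    bits += message.size
    # Alice decodes y and computes the inner product with her own input.
    y_received = 2 * message.astype(np.int64) - 1
    inner = int(x @ y_received)
    # Under the promise, inner >= sqrt(n) (case 0) or inner <= -sqrt(n) (case 1).
    case = 0 if inner >= 0 else 1
    # Alice -> Bob: one bit announcing the answer.
    bits += 1
    return case, bits

x = np.ones(16, dtype=np.int64)
print(trivial_gap_hamming(x, x))    # case 0, using n + 1 = 17 bits
print(trivial_gap_hamming(x, -x))   # case 1, using n + 1 = 17 bits
```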

Can Alice and Bob solve this problem with many fewer than n bits of communication, say \sqrt{n} bits? Unfortunately not. The following theorem of Chakrabarti and Regev shows that roughly n bits of communication are needed to solve this problem:

Theorem (Chakrabarti–Regev). Any algorithm that solves the Gap-Hamming problem with success probability at least 2/3 for every pair of inputs x and y (satisfying one of the conditions (2)) must take \Omega(n) bits of communication.

Here, \Omega(n) is big-Omega notation, closely related to big-O notation \order(n) and big-Theta notation \Theta(n). For the less familiar, it can be helpful to interpret \Omega(n), \order(n), and \Theta(n) as all standing for “proportional to n”. In plain language, the theorem of Chakrabarti and Regev states that there is no algorithm for the Gap-Hamming problem that is much more effective than the basic algorithm where Bob sends his whole input to Alice (in the sense of requiring fewer than \order(n) bits of communication).

Reducing Gap-Hamming to Trace Estimation

This whole state of affairs is very sad for Alice and Bob, but what does it have to do with trace estimation? Remarkably, we can use hardness of the Gap-Hamming problem to show there’s no algorithm that fundamentally improves on Hutch++ and XTrace. The argument goes something like this:

  1. If there were a trace estimation algorithm fundamentally better than Hutch++ and XTrace, we could use it to solve Gap-Hamming in fewer than \order(n) bits of communication.
  2. But no algorithm can solve Gap-Hamming in fewer than \order(n) bits of communication.
  3. Therefore, no trace estimation algorithm is fundamentally better than Hutch++ and XTrace.

Step 2 is the work of Chakrabarti and Regev, and step 3 follows logically from 1 and 2. Therefore, we are left to complete step 1 of the argument.

Protocol

Assume we have access to a really good trace estimation algorithm. We will use it to solve the Gap-Hamming problem. For simplicity, assume n is a perfect square. The basic idea is this:

  • Have Alice and Bob reshape their inputs x,y \in \{\pm 1\}^n into matrices X,Y\in\{\pm 1\}^{\sqrt{n}\times \sqrt{n}}, and consider (but do not form!) the positive semidefinite matrix

        \[A = (X+Y)^\top (X+Y).\]

  • Observe that

        \[\tr(A) = \tr(X^\top X) + 2\tr(X^\top Y) + \tr(Y^\top Y) = 2n + 2(x^\top y).\]

    Thus, the two cases in (2) can be equivalently written in terms of \tr(A):

    (2′)   \[\text{Case 0: } \tr(A)\ge 2n + 2\sqrt{n} \quad \text{or} \quad \text{Case 1: } \tr(A) \le 2n-2\sqrt{n}. \]

  • By working together, Alice and Bob can implement a trace estimation algorithm. Alice will be in charge of running the algorithm, but Alice and Bob must work together to compute matvecs. (Details below!)
  • Using the output of the trace estimation algorithm, Alice determines whether they are in case 0 or 1 (i.e., whether \tr(A) \gg 2n or \tr(A) \ll 2n) and sends the result to Bob.
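The trace identity behind this reduction is easy to check numerically. The sketch below builds A explicitly, which Alice and Bob never do in the actual protocol; the helper name `gap_hamming_matrix` is my own, and n is assumed to be a perfect square.

```python
import math
import numpy as np

def gap_hamming_matrix(x, y):
    """Form A = (X + Y)^T (X + Y) from +/-1 vectors x, y of perfect-square length."""
    n = x.size
    r = math.isqrt(n)
    assert r * r == n, "n must be a perfect square"
    X = x.reshape(r, r)
    Y = y.reshape(r, r)
    return (X + Y).T @ (X + Y)

rng = np.random.default_rng(0)
n = 64
x = rng.choice([-1, 1], size=n)
y = rng.choice([-1, 1], size=n)
A = gap_hamming_matrix(x, y)

# tr(A) = 2n + 2 x^T y, so the sign of x^T y is visible in the trace.
print(np.trace(A), 2 * n + 2 * int(x @ y))
```

Since A is a Gram matrix, it is automatically positive semidefinite, so it is a legitimate input for a trace estimator that assumes psd matrices.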

To complete this procedure, we just need to show how Alice and Bob can implement the matvec procedure using minimal communication. Suppose Alice and Bob want to compute Az for some vector z with entries between -1 and 1 with up to b decimal digits. First, convert z to a vector w\coloneqq 10^b z whose entries are integers between -10^b and 10^b. Since Az = 10^{-b}Aw, interconverting between Az and Aw is trivial. Alice and Bob’s procedure for computing Aw is as follows:

  • Alice sends Bob w.
  • Having received w, Bob forms Yw and sends it to Alice.
  • Having received Yw, Alice computes v\coloneqq Xw+Yw and sends it to Bob.
  • Having received v, Bob computes Y^\top v and sends it to Alice.
  • Alice forms Aw = X^\top v + Y^\top v.

Because X and Y are \sqrt{n}\times \sqrt{n} and have \pm 1 entries, all vectors computed in this procedure are vectors of length \sqrt{n} with integer entries between -4n\cdot 10^b and 4n\cdot 10^b. Each entry therefore takes \order(b+\log n) bits to transmit, and only a constant number of such vectors are exchanged. We conclude the communication cost for one matvec is T\coloneqq\Theta((b+\log n)\sqrt{n}) bits.
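Here is a small simulation of the matvec procedure above that also tallies the bits exchanged. It is a sketch under simplifying assumptions: the name `distributed_matvec` is hypothetical, every message is charged the same fixed-width encoding (one sign bit plus enough bits for magnitudes up to 4n\cdot 10^b), and both parties' computations run in one process.

```python
import math
import numpy as np

def distributed_matvec(X, Y, z, b):
    """Compute A z for A = (X + Y)^T (X + Y); return (A z, bits_communicated)."""
    r = X.shape[0]
    n = r * r
    w = np.rint(z * 10**b).astype(np.int64)   # integer vector w = 10^b z
    # Fixed-width encoding: sign bit plus magnitude bits for values up to 4 n 10^b.
    bits_per_entry = 1 + math.ceil(math.log2(4 * n * 10**b + 1))
    message_bits = r * bits_per_entry
    bits = 0
    bits += message_bits        # Alice -> Bob: w
    Yw = Y @ w                  # Bob's computation
    bits += message_bits        # Bob -> Alice: Y w
    v = X @ w + Yw              # Alice's computation
    bits += message_bits        # Alice -> Bob: v
    YTv = Y.T @ v               # Bob's computation
    bits += message_bits        # Bob -> Alice: Y^T v
    Aw = X.T @ v + YTv          # Alice's computation
    return Aw / 10**b, bits

rng = np.random.default_rng(0)
r, b = 4, 2
X = rng.choice([-1, 1], size=(r, r))
Y = rng.choice([-1, 1], size=(r, r))
z = rng.integers(-100, 101, size=r) / 100     # entries in [-1, 1] with 2 decimal digits
Az, bits = distributed_matvec(X, Y, z, b)
print(Az, bits)
```

Four messages of \sqrt{n} entries each, at \order(b+\log n) bits per entry, reproduces the \Theta((b+\log n)\sqrt{n}) cost per matvec.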

Analysis

Consider an algorithm we’ll call BestTraceAlgorithm. Given any accuracy parameter \varepsilon > 0, BestTraceAlgorithm requires at most m = m(\varepsilon) matvecs and, for any positive semidefinite input matrix A of any size, produces an estimate \hat{\tr} satisfying

(3)   \[\expect\left[\frac{|\hat{\tr}-\tr(A)|}{\tr(A)}\right] \le \varepsilon.\]

We assume that BestTraceAlgorithm is the best possible algorithm in the sense that no algorithm can achieve (3) on all (positive semidefinite) inputs with m' < m matvecs.

To solve the Gap-Hamming problem, Alice and Bob just need enough accuracy in their trace estimation to distinguish between cases 0 and 1. In particular, if

    \[\left| \frac{\hat{\tr} - \tr(A)}{\tr(A)} \right| \le \frac{1}{\sqrt{n}},\]

then Alice and Bob can distinguish between cases 0 and 1 in (2′).
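To see why this accuracy is enough, one can work out what the estimate looks like in each case. This is my own filling-in of the step, assuming the relative error bound above holds. In case 0,

    \[\hat{\tr} \ge \left(1 - \frac{1}{\sqrt{n}}\right)\tr(A) \ge \left(1 - \frac{1}{\sqrt{n}}\right)\left(2n + 2\sqrt{n}\right) = 2n - 2,\]

while in case 1,

    \[\hat{\tr} \le \left(1 + \frac{1}{\sqrt{n}}\right)\tr(A) \le \left(1 + \frac{1}{\sqrt{n}}\right)\left(2n - 2\sqrt{n}\right) = 2n - 2.\]

So Alice can decide by comparing \hat{\tr} with the threshold 2n-2. The two cases only touch when \hat{\tr} equals 2n-2 exactly, and asking for a slightly smaller constant in the error bound removes even that tie.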

Suppose that Alice and Bob apply trace estimation to solve the Gap-Hamming problem, using m matvecs in total. The total communication is m\cdot T = \order(m(b+\log n)\sqrt{n}) bits. Chakrabarti and Regev showed that solving the Gap-Hamming problem with 2/3 probability requires at least cn bits of communication (for some c>0). Thus, if m\cdot T < cn, then Alice and Bob fail to solve the Gap-Hamming problem with at least 1/3 probability. Consequently,

    \[\text{If } m < \frac{cn}{T} = \Theta\left( \frac{\sqrt{n}}{b+\log n} \right), \quad \text{then } \left| \frac{\hat{\tr} - \tr(A)}{\tr(A)} \right| > \frac{1}{\sqrt{n}} \text{ with probability at least } \frac{1}{3}.\]

The contrapositive of this statement is:

    \[\text{If }\left| \frac{\hat{\tr} - \tr(A)}{\tr(A)} \right| \le \frac{1}{\sqrt{n}}\text{ with probability at least } \frac{2}{3}, \quad \text{then } m \ge \Theta\left( \frac{\sqrt{n}}{b+\log n} \right).\]


Say Alice and Bob run BestTraceAlgorithm with parameter \varepsilon = \tfrac{1}{3\sqrt{n}}. Then, by (3) and Markov’s inequality,

    \[\left| \frac{\hat{\tr} - \tr(A)}{\tr(A)} \right| \le \frac{1}{\sqrt{n}} \quad \text{with probability at least }\frac{2}{3}.\]
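Spelled out, the Markov step reads as follows, where E abbreviates the relative error |\hat{\tr}-\tr(A)|/\tr(A) (my shorthand, not notation from the paper):

    \[\mathbb{P}\left\{ E > \frac{1}{\sqrt{n}} \right\} \le \sqrt{n}\cdot\expect[E] \le \sqrt{n}\cdot\varepsilon = \sqrt{n}\cdot\frac{1}{3\sqrt{n}} = \frac{1}{3}.\]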

Therefore, BestTraceAlgorithm requires at least

    \[m \ge \Theta\left( \frac{\sqrt{n}}{b+\log n} \right) \text{ matvecs}.\]

Using the fact that we’ve set \varepsilon = 1/(3\sqrt{n}), so that \sqrt{n} = \Theta(1/\varepsilon) and \log n = \Theta(\log(1/\varepsilon)), we conclude that any trace estimation algorithm, even BestTraceAlgorithm, requires

    \[m \ge \Theta \left( \frac{1}{\varepsilon (b+\log(1/\varepsilon))} \right) \text{ matvecs}.\]

In particular, no trace estimation algorithm can achieve mean relative error \varepsilon using even \order(1/\varepsilon^{0.999}) matvecs. This proves the MMMW theorem.
