A corner of the AI-accelerator design space that has been getting a little more attention recently, at least in physics departments, is "intelligent metamaterials". I studied one of these for my master's thesis. The idea is that you have a sparse arrangement of elements (e.g. a grid) that respond to perturbations or vibrations and settle into a quasi-static equilibrium response in either state space or frequency space. Training typically works either through an external simulator or through fancy physical algorithms like equilibrium propagation.
The one nagging question I had once I'd finished reading all of this was: how do we get this to work out of equilibrium, which is where the fun stuff like language modelling lives? And better yet, is there a local message-passing algorithm that lets a sparse system learn through time?
It turns out one does exist. You can do row-wise sequential MNIST at about \(3\times\) chance (and still climbing) with 1,120 learnable scalars arranged in a rectangular grid of 280 nodes, trained online with batch size 1 using plain SGD and zero per-parameter optimiser state, or at about \(5\times\) chance if you relax that to Adam with batch size 32. Here's how.
The network is an \(N\times M\) rectangular lattice. Each node \((i,j)\) is the same tiny first-order IIR filter with a \(\tanh\) non-linearity: \[s_{ij}(t{+}1) = a_{ij}\, s_{ij}(t) + c_{ij}\, \big(w^{\text{L}}_{ij}\, y_{i,j-1}(t{+}1) + w^{\text{T}}_{ij}\, y_{i-1,j}(t{+}1)\big)\] \[y_{ij}(t{+}1) = \tanh\!\big(s_{ij}(t{+}1) + b_{ij}\big)\] where \(a_{ij} = e^{-\Delta t/\tau_{ij}}\) and \(c_{ij} = 1 - a_{ij}\) are the usual zero-order-hold coefficients of a leaky integrator with time constant \(\tau_{ij}\). The neighbour outputs carry the index \(t{+}1\) because the forward pass is a raster sweep: by the time cell \((i,j)\) updates, its left and upper neighbours have already been updated in the current step. Each cell has exactly four learnable scalars: \(\{\log\tau_{ij},\, w^{\text{L}}_{ij},\, w^{\text{T}}_{ij},\, b_{ij}\}\). Every node has two inputs (from its left and upper neighbours), and its single output is routed simultaneously to its right and lower neighbours, hence "two-in, two-out".
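A minimal plain-Python sketch of one forward step, for illustration only (the post's own implementation is in JAX and isn't reproduced). It assumes, as the raster-sweep forward pass and the trace equations below imply, that each cell reads its left and upper neighbours' outputs from the current step. It also assumes that left-shoreline cells take the input channel `x[i]` as their left input and that top-row cells see zero from above. The name `grid_step` and the nested-list layout are hypothetical.

```python
import math


def grid_step(log_tau, w_left, w_top, bias, s, x, dt=1.0):
    """One raster sweep of an N x M lattice of tanh-IIR cells (hypothetical helper).

    log_tau, w_left, w_top, bias and s are N x M nested lists; x holds one
    input value per row and drives the left shoreline. Returns (s_new, y_new).
    """
    n_rows, n_cols = len(s), len(s[0])
    s_new = [[0.0] * n_cols for _ in range(n_rows)]
    y_new = [[0.0] * n_cols for _ in range(n_rows)]
    for i in range(n_rows):            # raster order: top to bottom ...
        for j in range(n_cols):        # ... and left to right
            a = math.exp(-dt / math.exp(log_tau[i][j]))  # zero-order-hold decay
            c = 1.0 - a
            # Assumption: left shoreline reads the input, top shoreline reads zero.
            left = x[i] if j == 0 else y_new[i][j - 1]
            top = 0.0 if i == 0 else y_new[i - 1][j]
            s_new[i][j] = a * s[i][j] + c * (w_left[i][j] * left + w_top[i][j] * top)
            y_new[i][j] = math.tanh(s_new[i][j] + bias[i][j])
    return s_new, y_new
```

For row-wise MNIST, `x` would be one 28-pixel image row per call, with a 28 x 10 grid.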
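The four edges of the grid (the shorelines) each have a specific job; the two that matter most here are the left shoreline, where one input channel enters each row, and the bottom shoreline, where each column's output is read out.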
For row-wise MNIST that makes the grid \(28\times 10\): 28 left-shoreline rows (one per input pixel channel), 10 bottom-shoreline cells (one per class). At each of \(T{=}28\) timesteps, one row of the MNIST image arrives at the left shoreline, signal sloshes diagonally across the lattice through the tanh-filter cascades, and the bottom row produces a 10-vector that we sum over time before the usual softmax cross-entropy loss. The encoder is \(I_{28}\) by geometry, the decoder is the identity followed by \(\sum_t\). All representational work happens inside the grid's 1,120 parameters.
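The readout side is small enough to sketch directly in plain Python. Because the decoder is just a sum over time of the bottom-shoreline outputs, the gradient of the softmax cross-entropy with respect to the bottom row is the same vector at every timestep: softmax of the logits minus the one-hot label. The name `readout_loss` is hypothetical.

```python
import math


def readout_loss(bottom_rows, label):
    """Softmax cross-entropy on time-summed bottom-shoreline outputs.

    bottom_rows: T lists, each holding the M bottom-row outputs at one step.
    Returns (loss, grad), where grad = dL/dy_bottom(t), identical for every t.
    """
    n_classes = len(bottom_rows[0])
    logits = [sum(row[k] for row in bottom_rows) for k in range(n_classes)]
    m = max(logits)                                  # for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    loss = m + math.log(total) - logits[label]       # log-sum-exp form
    probs = [e / total for e in exps]
    grad = [p - (1.0 if k == label else 0.0) for k, p in enumerate(probs)]
    return loss, grad
```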
For the learning rule I adapted the SnAp-\(k\) family to this geometry. The short version: each cell maintains eligibility traces that accumulate, forward in time, the influence of its parameters on its own filter state, and at each timestep we combine those traces with an incoming spatial message from the readout to get a local gradient estimate. The reason this is tractable on a grid is the same reason it was tractable on the IIR dynamics I wrote about earlier: the eligibility traces inherit the filter's recursion. For a cell's own weights: \[e^{\text{L}}_{ij}(t{+}1) = a_{ij}\, e^{\text{L}}_{ij}(t) + c_{ij}\, y_{i,j-1}(t{+}1)\] and symmetrically for \(e^{\text{T}}_{ij}\) (the from-above weight) and \(e^{\tau}_{ij}\) (the trace for \(\log\tau_{ij}\)). The bias doesn't need a trace because \(\partial y_{ij}/\partial b_{ij} = 1 - y_{ij}^2\) is instantaneous.
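A plain-Python sketch of the three self-traces for one isolated cell driven by fixed neighbour outputs, with hypothetical names. The time-constant trace is my own derivation from the forward equation, not taken from the post: since \(\partial a/\partial\log\tau = a\,\Delta t/\tau\) and \(c = 1 - a\), differentiating gives \(e^{\tau}(t{+}1) = a\,e^{\tau}(t) + a\,(\Delta t/\tau)\,\big(s(t) - u(t{+}1)\big)\), where \(u\) is the weighted input \(w^{\text{L}} y_{\text{left}} + w^{\text{T}} y_{\text{top}}\).

```python
import math


def run_cell(log_tau, w_left, w_top, left_in, top_in, dt=1.0):
    """Drive one isolated cell with given neighbour outputs (hypothetical helper).

    Returns the final filter state and the three self eligibility traces
    (ds/dw_left, ds/dw_top, ds/dlog_tau), accumulated forward in time.
    """
    tau = math.exp(log_tau)
    a = math.exp(-dt / tau)
    c = 1.0 - a
    da_dlogtau = a * dt / tau
    s = e_left = e_top = e_tau = 0.0
    for y_l, y_t in zip(left_in, top_in):
        u = w_left * y_l + w_top * y_t
        # Time-constant trace: uses s(t) before it is overwritten below.
        e_tau = a * e_tau + da_dlogtau * (s - u)
        # Weight traces inherit the filter's own recursion.
        e_left = a * e_left + c * y_l
        e_top = a * e_top + c * y_t
        s = a * s + c * u
    return s, (e_left, e_top, e_tau)
```

A central finite difference on each parameter should match the corresponding trace, which is a cheap way to check the recursion before wiring it into the grid.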
So far, so standard. The grid-specific part is the cross-cell sensitivities. On a tree each cell feeds exactly one cell at the next depth level, so its downstream influence is a single chain, and SnAp-\(k\) just tracks a short chain of cross-traces. On a grid each cell has two immediate descendants (the one to its right and the one below), so SnAp-1 has to track two cross-traces per parameter type. For example, for the left weight \(w^{\text{L}}_{ij}\): \[\epsilon^{(1,\text{below})}_{ij,\,w^{\text{L}}}(t{+}1) = a_{\text{below}}\,\epsilon^{(1,\text{below})}_{ij,\,w^{\text{L}}}(t) + c_{\text{below}}\, w^{\text{T}}_{\text{below}}\, \big(1-y_{ij}(t{+}1)^2\big)\, e^{\text{L}}_{ij}(t{+}1)\] and similarly for the right descendant, which couples through \(w^{\text{L}}_{\text{right}}\) instead. Total per-cell state at training time: three self-traces plus eight cross-traces (two descendants times four parameter types), or eleven scalars on top of the four parameters. The forward pass is a raster sweep. The backward pass is a reverse raster sweep of spatial messages within a timestep: messages only propagate spatially, not across time, and the temporal history lives entirely in those eleven cached scalars.
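The below cross-trace can be checked on an isolated vertical pair of cells. This is a plain-Python sketch with hypothetical names. With only two cells there is no further downstream cone, so here the one-hop SnAp-1 trace should match a finite-difference derivative exactly; on a full grid it would not. The upper cell's top input is assumed to be zero, and the lower cell reads the upper cell's output from the same step, matching the \(t{+}1\) indices in the cross-trace equation.

```python
import math


def zoh(log_tau, dt=1.0):
    """Zero-order-hold coefficients (a, c) for time constant exp(log_tau)."""
    a = math.exp(-dt / math.exp(log_tau))
    return a, 1.0 - a


def run_pair(up, down, left_up, left_down, dt=1.0):
    """Upper cell (i, j) and the cell directly below it, (i+1, j).

    up, down: (log_tau, w_left, w_top, bias). Returns the lower cell's final
    state and the cross-trace d s_below / d w_left_up.
    """
    a_u, c_u = zoh(up[0], dt)
    a_d, c_d = zoh(down[0], dt)
    s_u = s_d = 0.0
    e_left_u = 0.0   # self trace of the upper cell
    eps_below = 0.0  # cross-trace carried by the upper cell
    for x_u, x_d in zip(left_up, left_down):
        s_u = a_u * s_u + c_u * up[1] * x_u          # top shoreline input = 0
        y_u = math.tanh(s_u + up[3])
        e_left_u = a_u * e_left_u + c_u * x_u
        s_d = a_d * s_d + c_d * (down[1] * x_d + down[2] * y_u)
        eps_below = a_d * eps_below + c_d * down[2] * (1 - y_u ** 2) * e_left_u
    return s_d, eps_below
```

Repeating this for the right descendant (through \(w^{\text{L}}_{\text{right}}\)) and for all four parameter types gives the \(2\times 4 = 8\) cross-traces per cell.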
Summing the per-timestep local gradient estimates over \(T\) timesteps gives an approximation of the full sequence gradient, in the usual RTRL way. The approximation is a well-known truncation: SnAp-1 captures only one descendant hop and ignores the rest of the downstream cone, so the gradient estimate is biased. But it stays aligned closely enough with the true gradient to train, as the results below show, and the math all fits in a few dozen lines of JAX.
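The within-timestep spatial message pass from the readout can be sketched as a reverse raster sweep. This plain-Python sketch assumes same-step neighbour coupling and that the right shoreline's outputs go nowhere. It computes \(\delta_{ij} = \partial\mathcal{L}_t/\partial s_{ij}(t)\) for one step, given the readout gradient `g_bottom` on the bottom row. How these messages are weighted against the self- and cross-traces is the SnAp-1 bookkeeping described above and is not reproduced here; `spatial_messages` is a hypothetical name.

```python
import math


def spatial_messages(log_tau, w_left, w_top, y, g_bottom, dt=1.0):
    """Reverse raster sweep of dL/ds within one timestep (hypothetical helper).

    y: N x M outputs at this step; g_bottom: dL/dy for the bottom row.
    Returns delta with delta[i][j] = dL/ds_ij at this step.
    """
    n_rows, n_cols = len(y), len(y[0])
    c = [[1.0 - math.exp(-dt / math.exp(log_tau[i][j])) for j in range(n_cols)]
         for i in range(n_rows)]
    delta = [[0.0] * n_cols for _ in range(n_rows)]
    for i in reversed(range(n_rows)):          # bottom to top ...
        for j in reversed(range(n_cols)):      # ... right to left
            dy = g_bottom[j] if i == n_rows - 1 else 0.0
            if j + 1 < n_cols:   # output feeds the right neighbour via its w_left
                dy += delta[i][j + 1] * c[i][j + 1] * w_left[i][j + 1]
            if i + 1 < n_rows:   # output feeds the lower neighbour via its w_top
                dy += delta[i + 1][j] * c[i + 1][j] * w_top[i + 1][j]
            delta[i][j] = dy * (1.0 - y[i][j] ** 2)   # back through tanh
    return delta
```

Because each cell only needs the already-computed messages from its right and lower neighbours, the sweep uses the same nearest-neighbour wiring as the forward pass, just run in reverse.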
The hardware-realistic configuration is the one that matters: plain per-group SGD, batch size 1, zero per-parameter optimiser state. Per-group because the gradients have pretty different scales for weights, biases, and \(\log\tau\) - I used \((\eta_w,\, \eta_b,\, \eta_\tau) = (\eta_0,\, \eta_0/20,\, 4\eta_0)\) with \(\eta_0 = 10^{-2}\), same ratios I'd tuned for the tree variant. This is the minimum configuration a chip would actually implement: no Adam, no batching, one gradient step per training example.
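The update itself needs nothing beyond three global learning rates. A plain-Python sketch using the ratios above; grouping both input weights under one `"w"` key and the dict layout are illustrative choices, not taken from the original code.

```python
ETA0 = 1e-2
# One learning rate per parameter group; no per-parameter optimiser state.
LEARNING_RATES = {"w": ETA0, "b": ETA0 / 20, "log_tau": 4 * ETA0}


def sgd_step(params, grads):
    """In-place plain SGD on a dict of parameter groups (lists of floats)."""
    for group, values in params.items():
        lr = LEARNING_RATES[group]
        for k, g in enumerate(grads[group]):
            values[k] -= lr * g
```

With batch size 1 this runs once per training image, after that image's sequence gradient estimate has been accumulated.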
Training trajectory on full-scale row-wise MNIST (60,000 training images; random chance is 0.10):
| epoch | train acc | val acc |
|---|---|---|
| 0 | 0.158 | 0.266 |
| 1 | 0.273 | 0.283 |
| 2 | 0.304 | 0.320 |
Still climbing at a few points per epoch when I cut it off; it would probably settle somewhere around 0.4-0.5 given another 20 epochs. If we let ourselves cheat a little and swap in Adam with batch size 32 (still no learned encoder or decoder, still SnAp-1, still the same 1,120 parameters), we get 0.522 val acc after 10 epochs. Neither number is remotely competitive with a linear classifier on raw pixels (which hits \(\sim\!0.92\) with about seven times the parameters), but that's not the point. The point is that the rule works, the architecture trains, and the whole thing runs on the same local arithmetic a chip could implement cell by cell.
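For reference, the parameter counts behind the "about seven times" comparison, assuming the linear classifier is a standard 784-to-10 affine map with biases, plus the trace state the grid carries during training:

```python
N_ROWS, N_COLS = 28, 10
grid_params = 4 * N_ROWS * N_COLS        # four scalars per cell
linear_params = 28 * 28 * 10 + 10        # weights + biases of a 784 -> 10 affine map
ratio = linear_params / grid_params
training_state = 11 * N_ROWS * N_COLS    # eleven cached trace scalars per cell

print(grid_params, linear_params, round(ratio, 2), training_state)
# 1120 7850 7.01 3080
```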
I was pretty pleased to finally find something that works, and a little surprised that it does, so I got ahead of myself and asked some people in the field who study the equilibrium case to take a look at this result. Their obvious first point was that this whole architecture is very sensitive to vanishing gradients, and any noise in the backward message pass (not in the traces, incidentally) breaks the algorithm. That gives a pretty straightforward answer to the question of whether we've found something goofy that will survive translation into hardware: no.
The standard fixes for this are residual lines or small-world connectivity, which undercut *the point* of even attempting this: the cellular-automaton-style connection pattern. And once you realise these are unavoidable for anything with stable gradients, you end up designing what amounts to a fancy crossbar with a first-order response at each terminus instead of an exotic resistor element... in which case, we might as well just move those filters to the terminus... and what do you know, it's a diagonal RNN running SnAp-1 ;)
Getting this to learn was a worthwhile exercise in futility, and I learnt a lot about the vivo-silico gap and about what non-equilibrium response learning *really* means for anyone foolhardy enough to try to tape it out. My ultimate conclusion, though, is that I just want a diagonal-LTC RNN accelerator with SnAp-1 traces parked next to my weights that I can clock backwards, which is probably what I'll be doing next.
One other interesting ablation: changing the grid into a balanced binary-tree pattern substantially improves performance (~52% -> 70%), and you can look at the stuff I wrote alongside it here. There is an auto-slopped accompanying paper, and if friends motivate me sufficiently I'll try to get it indexed on the arXiv; at the moment, that's a PITA.