A Grid of 1,120 Analog Filters Can Learn MNIST

If you want a neural network that runs on an analog crossbar, there's a short list of things you can't have: no global backward pass, no per-parameter adaptive optimiser state, no mini-batch averaging of gradients, and - critically - no learned input projection if your input wiring is baked into silicon at fab time. What's left when you take all of those away is surprisingly small: a grid of identical analog cells doing local arithmetic, with some local learning rule updating each cell's handful of knobs using only signals physically present at that cell. The question I wanted to answer was: does that actually learn anything, or does it just sit there making noise?

Turns out you can do row-wise sequential MNIST at about \(5\times\) chance with 1,120 learnable scalars arranged in a rectangular grid of 280 nodes, trained online with batch size 1 using plain SGD and zero per-parameter optimiser state. Here's how.

The Architecture

The network is an \(N\times M\) rectangular lattice. Each node \((i,j)\) is the same tiny first-order IIR filter with a \(\tanh\) non-linearity: \[s_{ij}(t{+}1) = a_{ij}\, s_{ij}(t) + c_{ij}\, \big(w^{\text{L}}_{ij}\, y_{i,j-1}(t{+}1) + w^{\text{T}}_{ij}\, y_{i-1,j}(t{+}1)\big)\] \[y_{ij}(t{+}1) = \tanh\!\big(s_{ij}(t{+}1) + b_{ij}\big)\] where \(a_{ij} = e^{-\Delta t/\tau_{ij}}\) and \(c_{ij} = 1 - a_{ij}\) are the usual zero-order-hold coefficients of a leaky integrator with time constant \(\tau_{ij}\). Note the \(t{+}1\) on the neighbour outputs: spatial propagation is instantaneous within a timestep. The forward pass updates cells in raster order, left-to-right and top-to-bottom, so a cell's left and above neighbours have already produced their \(t{+}1\) outputs by the time it fires. Each cell has exactly four learnable scalars: \(\{\log\tau_{ij},\, w^{\text{L}}_{ij},\, w^{\text{T}}_{ij},\, b_{ij}\}\). Every node has two inputs (from its left and above neighbours), and its single output is routed simultaneously to its right and below neighbours - hence "two-in, two-out".
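Concretely, here's a minimal JAX sketch of one timestep of that raster sweep. The function name `grid_step` and the parameter-dict layout are mine, not the repo's, and the shoreline handling (input injected at the left edge, zeros above the top row) anticipates the next section:

```python
import jax.numpy as jnp

def grid_step(params, s, y, x_left, dt=1.0):
    """One raster sweep over an N x M grid. Within the sweep, each cell sees
    its left/above neighbours' outputs from *this* timestep, already updated.

    params: dict of (N, M) arrays: "log_tau", "wL", "wT", "b"
    s, y  : (N, M) filter states / outputs from the previous timestep
    x_left: (N,) input injected at the left shoreline this timestep
    """
    N, M = s.shape
    a = jnp.exp(-dt / jnp.exp(params["log_tau"]))   # a_ij = exp(-dt / tau_ij)
    c = 1.0 - a                                     # zero-order-hold gain
    for i in range(N):
        for j in range(M):
            y_left = x_left[i] if j == 0 else y[i, j - 1]   # left shoreline: input
            y_up   = 0.0       if i == 0 else y[i - 1, j]   # top shoreline: zero
            drive  = params["wL"][i, j] * y_left + params["wT"][i, j] * y_up
            s_ij   = a[i, j] * s[i, j] + c[i, j] * drive
            s = s.at[i, j].set(s_ij)
            y = y.at[i, j].set(jnp.tanh(s_ij + params["b"][i, j]))
    return s, y
```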

The four edges of the grid (the shorelines) each do something specific:

- Left shoreline: input. Each of the \(N\) rows receives one input channel, injected in place of the missing left-neighbour output.
- Top shoreline: nothing. The first row has no above neighbour, so that input is held at zero.
- Bottom shoreline: readout. The last row's outputs are the network's output vector.
- Right shoreline: the last column's outputs leave the grid and are discarded.

For row-wise MNIST that makes the grid \(28\times 10\): 28 left-shoreline rows (one per input pixel channel) and 10 bottom-shoreline cells (one per class). At each of \(T{=}28\) timesteps, one row of the MNIST image arrives at the left shoreline, signal sloshes diagonally across the lattice through the tanh-filter cascades, and the bottom row produces a 10-vector that we sum over time before the usual softmax cross-entropy loss. The encoder is \(I_{28}\) by geometry; the decoder is the identity on the bottom row followed by \(\sum_t\). All representational work happens inside the grid's 1,120 parameters.
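The full rollout is then a loop over rows, reusing `grid_step` from the sketch above (again my names; `forward_loss` is illustrative, not from the repo):

```python
import jax
import jax.numpy as jnp

def forward_loss(params, image, label, T=28, N=28, M=10):
    """image: (T, N) pixel rows of one MNIST digit; label: int class.
    Encoder and decoder are identity maps by geometry: row t of the image
    drives the left shoreline at timestep t, and the bottom row's outputs
    are summed over time before softmax cross-entropy."""
    s = jnp.zeros((N, M))
    y = jnp.zeros((N, M))
    logits = jnp.zeros(M)
    for t in range(T):
        s, y = grid_step(params, s, y, image[t])
        logits = logits + y[N - 1, :]           # bottom shoreline, summed over t
    return -jax.nn.log_softmax(logits)[label]   # softmax cross-entropy
```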

Learning: SnAP-1 on a 2D Lattice

For the learning rule I adapted the SnAP-\(k\) family to this geometry. The short version: each cell maintains eligibility traces that accumulate its parameters' influence on its own filter state forward in time, and at each timestep we combine those traces with an incoming spatial message from the readout to get a local gradient estimate. The reason this is tractable on a grid is the same reason it was tractable on the IIR dynamics I wrote about earlier: the eligibility traces inherit the filter's recursion. For a cell's own weights: \[e^{\text{L}}_{ij}(t{+}1) = a_{ij}\, e^{\text{L}}_{ij}(t) + c_{ij}\, y_{i,j-1}(t{+}1)\] and analogously for \(e^{\text{T}}_{ij}\) (the from-above weight) and for the time-constant trace \(\eta_{ij}\), which picks up an extra term because both \(a_{ij}\) and \(c_{ij}\) depend on \(\tau_{ij}\). The bias doesn't need a trace because \(\partial y/\partial b = 1 - y^2\) is instantaneous.
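As elementwise array updates this is three lines (a sketch; the \(\eta\) recursion is my own derivation of that extra term - \(\partial a/\partial\log\tau = a\,\Delta t/\tau\) and \(\partial c = -\partial a\), hence the \((s - u)\) factor - so treat it as one consistent choice rather than the repo's exact form):

```python
def self_trace_step(a, c, tau, traces, s_prev, u, y_left, y_up):
    """Self-traces for one timestep; all arguments are (N, M) arrays.
    u = wL * y_left + wT * y_up is the drive into each cell,
    tau = exp(log_tau)."""
    eL  = a * traces["eL"] + c * y_left                  # ds/dwL
    eT  = a * traces["eT"] + c * y_up                    # ds/dwT
    # ds/dlog(tau): a and c = 1 - a move together, giving the (s_prev - u) term.
    eta = a * traces["eta"] + (a / tau) * (s_prev - u)   # dt = 1
    return {"eL": eL, "eT": eT, "eta": eta}
```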

So far, so standard. The grid-specific part is the cross-cell sensitivities. On a tree each cell has exactly one descendant per depth level - an ancestor chain - and SnAP-\(k\) just tracks a small chain of cross-traces. On a grid each cell has two immediate descendants (the one to its right and the one below), so SnAP-1 has to track two cross-traces per parameter type. For example, for the left-weight \(w^{\text{L}}_{ij}\): \[\epsilon^{(1,\text{below})}_{ij,\,w^{\text{L}}}(t{+}1) = a_{\text{below}}\,\epsilon^{(1,\text{below})}_{ij,\,w^{\text{L}}}(t) + c_{\text{below}}\, w^{\text{T}}_{\text{below}}\, (1-y_{ij}^2)\, e^{\text{L}}_{ij}(t{+}1)\] and similarly for the right descendant, which couples through \(w^{\text{L}}_{\text{right}}\) instead. Total per-cell state at training time: three self-traces plus eight cross-traces (four parameters \(\times\) two descendants), or eleven scalars on top of the four parameters. Forward pass is a raster sweep. Backward pass is a reverse raster sweep of spatial messages within a timestep - messages only propagate spatially, not across time; the temporal history lives entirely in those eleven cached scalars.
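In code, the two cross-traces for \(w^{\text{L}}_{ij}\) look like this (a sketch of my reading of the recursion above; `eps` carries the previous step's traces, and edge cells simply skip descendants that don't exist):

```python
def cross_traces_wL(params, a, c, eps, e_L, y, i, j):
    """SnAP-1 cross-traces of wL_{ij} into its two immediate descendants.
    a, c, e_L, y: (N, M) arrays; eps: {"below": ..., "right": ...} scalars."""
    dy = (1.0 - y[i, j] ** 2) * e_L[i, j]    # dy_ij/dwL via the self-trace
    # below descendant (i+1, j) couples through its from-above weight wT
    below = a[i + 1, j] * eps["below"] + c[i + 1, j] * params["wT"][i + 1, j] * dy
    # right descendant (i, j+1) couples through its from-left weight wL
    right = a[i, j + 1] * eps["right"] + c[i, j + 1] * params["wL"][i, j + 1] * dy
    return {"below": below, "right": right}
```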

Summing the per-timestep local gradients over \(T\) timesteps gives the full sequence gradient in the usual RTRL way. This is a well-known truncation: SnAP-1 only captures one descendant hop and ignores the rest of the downstream cone, so the gradient estimate is biased. But the sign is right in expectation and the math all fits in a few dozen lines of JAX.
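Spelled out - this is my notation for the combination step, not a formula from the paper - each parameter's gradient estimate sums, over time, the spatial messages \(\delta\) hitting the cell and its two tracked descendants, each multiplied by the matching trace: \[\frac{\partial \mathcal{L}}{\partial \theta_{ij}} \;\approx\; \sum_{t=1}^{T}\Big(\delta_{ij}(t)\,e_{ij,\theta}(t) \;+\; \delta_{i+1,j}(t)\,\epsilon^{(1,\text{below})}_{ij,\theta}(t) \;+\; \delta_{i,j+1}(t)\,\epsilon^{(1,\text{right})}_{ij,\theta}(t)\Big), \qquad \delta_{kl}(t) \equiv \frac{\partial L(t)}{\partial s_{kl}(t)},\] with the \(\delta\)'s supplied by the reverse raster sweep within each timestep.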

Does It Actually Learn Anything?

The hardware-realistic configuration is the one that matters: plain per-group SGD, batch size 1, zero per-parameter optimiser state. Per-group because the gradients have pretty different scales for weights, biases, and \(\log\tau\) - I used \((\eta_w,\, \eta_b,\, \eta_\tau) = (\eta_0,\, \eta_0/20,\, 4\eta_0)\) with \(\eta_0 = 10^{-2}\), the same ratios I'd tuned for the tree variant. This is the minimum configuration a chip would actually implement: no Adam, no batching, one gradient step per training example.
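The whole optimiser, then, is a dictionary of learning rates and a subtraction (a sketch; `sgd_step` and the key names are mine):

```python
ETA0 = 1e-2
LR = {"wL": ETA0, "wT": ETA0, "b": ETA0 / 20, "log_tau": 4 * ETA0}

def sgd_step(params, grads):
    """Plain per-group SGD: one step per training example, no optimiser state."""
    return {k: params[k] - LR[k] * grads[k] for k in params}
```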

Training trajectory on full-scale row-wise MNIST (60,000 training images; random chance is 0.10):

| epoch | train acc | val acc |
|-------|-----------|---------|
| 0     | 0.158     | 0.266   |
| 1     | 0.273     | 0.283   |
| 2     | 0.304     | 0.320   |

Still climbing at 1-2 points per epoch when I cut it off; would probably settle somewhere around 0.4-0.5 given another 20 epochs. If we let ourselves cheat a little and swap in Adam with batch size 32 - still no learned encoder or decoder, still SnAP-1, still the same 1,120 parameters - we get 0.522 val acc after 10 epochs. Neither number is remotely competitive with a linear classifier on raw pixels (which hits \(\sim\!0.92\) with about seven times the parameters), but that's not the point. The point is that the rule works, the architecture trains, and the whole thing runs with the same local arithmetic a chip could implement cell-by-cell.

What This Is, And What It Isn't

I want to be upfront that getting 0.5-ish on MNIST is a low bar. MNIST is the universal signature of "you built a working gradient pipeline" - it falls out of almost any vaguely-RNN-shaped thing trained with almost any not-insane rule. The 0.522 doesn't say anything about the grid topology being good at visual tasks - row-wise MNIST isn't really a visual task, it's "integrate some bits over time and vote at the end" - and it certainly isn't a claim to be competing with transformers.

What it is, I think, is an existence proof: you can take a 2D raster of identical analog cells, bolt a reasonable local learning rule onto them, hand the thing a sequence of bits, and the weights converge to something that does better than chance on a real dataset. No hidden tricks: no conv layers, no residual connections, no learned readout, no batching, no momentum. Just a grid of tanh filters with neighbour-only connectivity, learning by multiplying a backward message against a cached eligibility trace at each cell. If you were going to build this thing in silicon - a 28-by-10 tile of OTAs with capacitors, differential pairs for the tanh, two programmable resistors per cell - the learning dynamics would unfold on the chip, not in a digital controller, because every signal the update rule needs is already there.

Code and the paper draft are up at github.com/ThomasPluck/grid-neurons, alongside its tree-topology sibling which has more thorough characterisation. The next obvious thing to try is SHD (the spoken-digit benchmark that actually needs temporal credit assignment over a longer horizon), followed by swapping the idealised JAX forward pass for a simulated-crossbar one with realistic device noise and a finite-precision weight grid. If any of that goes anywhere interesting I'll write it up.