The neuronal credit assignment problem as causal inference


Ben Lansdell, Bioengineering UPenn

Drexel University. February 27th 2020

Two complementary tasks to understand intelligence

  1. Build artificial systems
  2. Understand the brain


Learning is central to both human and artificial intelligence


$\Rightarrow$ Advances in each domain can inspire the other

Machine learning, neuroscience, and causality

Causation relates to a number of challenges in both machine learning and neuroscience

Messerli, N Engl J Med 2012


Claim: progress in both machine learning and neuroscience can come from explicitly casting problems as causal learning problems

Outline

  1. The neuronal credit assignment problem as causal inference
  2. Learning to solve the credit assignment problem

Learning in the brain

Find parameters $w$ that minimize a loss/maximize a reward function, $R$

What are the synaptic update rules used by neurons that provide efficient and flexible learning?

Must:

  • Be consistent with known neurophysiology
  • Be good enough at learning complicated tasks

The neuronal credit assignment problem

To learn, a neuron must know its effect on the reward function



In spiking neural networks, this means something like:

  • If, for a given input, a spike increases the reward, the weights leading to that spike should increase
  • If, for a given input, a spike decreases the reward, the weights leading to that spike should decrease

The problem: noise correlations and confounding

$\Rightarrow$ Viewing learning as a causal inference problem may provide insight

Causality

  • Defined in terms of counterfactuals or interventions
  • The causal effect: $\beta = \mathbb{E}(R|H\leftarrow 1) - \mathbb{E}(R|H\leftarrow 0)$
  • How can we predict the causal effect from observation?


Credit assignment as causal inference

What is a neuron's causal effect on reward, and so how should it change to improve performance? $$ \beta_i = \mathbb{E}(R| H_i \leftarrow 1) - \mathbb{E}(R| H_i \leftarrow 0) $$

$\Rightarrow$ How can a neuron perform causal inference?

Credit assignment as causal inference

One solution: Randomization

If independent (unconfounded) noise is added to the system, it can be correlated with reward to estimate the reward gradient

In fact, the REINFORCE algorithm correlates reward with independent perturbations in activity, $\xi^i$: $$ \mathbb{E}( R\xi^i ) \approx \sigma^2 \frac{\partial R}{\partial h^i} $$
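
As a concrete illustration, here is a minimal numerical sketch of this estimator for a single unit with a smooth, hypothetical reward function (all names and values are illustrative assumptions):

```python
import numpy as np

# Minimal sketch of the REINFORCE-style estimator: correlating reward with
# IID perturbations of a unit's activity estimates the reward gradient.
rng = np.random.default_rng(0)
sigma = 0.1                      # perturbation standard deviation
h = 1.0                          # baseline activity of the unit

def reward(h):
    # Hypothetical smooth reward function of activity.
    return -(h - 2.0) ** 2

T = 100_000
xi = rng.normal(0.0, sigma, size=T)      # IID perturbations xi^i
R = reward(h + xi)                       # reward under perturbed activity

# E[R * xi] ~= sigma^2 * dR/dh (Taylor expansion of R around h)
grad_est = np.mean(R * xi) / sigma**2
grad_true = -2.0 * (h - 2.0)             # analytic gradient, equals 2.0 at h = 1
print(grad_est, grad_true)               # estimate is close to 2.0
```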

But:

  • Requires that each neuron measure an IID noise source, $\xi^i$, or know its output relative to some expected output
  • Only well characterized in specific circuits, e.g. birdsong learning (Fiete and Seung 2007)

Causal learning without randomization

An observation: decisions made with arbitrary thresholds let us observe counterfactuals

Adapted from Moscoe et al, J Clin Epid 2015

Known as regression discontinuity design (RDD) in economics

Two more observations:

  1. A neuron only spikes if its input is above a threshold
  2. A spike can have a measurable effect on outcome and reward

Suggests regression discontinuity design can be used by a neuron to estimate its causal effect.

RDD for solving credit assignment

Lansdell and Kording, bioRxiv 2019

  • Inputs that place the neuron close to threshold provide an unbiased estimate of its causal effect
  • Estimate the piecewise-constant model: $$R = \gamma_i + \beta_i H_i$$

A small demonstration

The two-neuron network with noise correlations

A small demonstration

Can use RDD to estimate the causal effect

A small demonstration

Works in cases where a correlational estimator fails

$$\beta = \mathbb{E}(R|H\leftarrow 1) - \mathbb{E}(R|H\leftarrow 0), \quad \beta_{OD} = \mathbb{E}(R|H=1) - \mathbb{E}(R|H=0)$$

where $\beta$ is the causal effect (intervention) and $\beta_{OD}$ is the confounded observed-difference estimate (conditioning)
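
As a minimal simulation sketch of this comparison, consider an assumed two-neuron generative model with a shared (confounding) noise source; the model, window width, and variable names here are illustrative, not the paper's exact simulation:

```python
import numpy as np

# Spiking-discontinuity (RDD) estimate vs. naive observed difference,
# in a confounded two-neuron setup.
rng = np.random.default_rng(1)
T, theta, p = 500_000, 1.0, 0.05          # trials, spike threshold, RDD window

shared = rng.normal(0.0, 1.0, size=T)     # confounder: shared input drive
z1 = theta + 0.5 * shared + rng.normal(0.0, 0.3, size=T)   # drive to neuron 1
z2 = theta + 0.5 * shared + rng.normal(0.0, 0.3, size=T)   # drive to neuron 2
h1 = (z1 > theta).astype(float)           # spike iff drive exceeds threshold
h2 = (z2 > theta).astype(float)

beta_true = 1.0                           # neuron 1's causal effect on reward
R = beta_true * h1 + 2.0 * h2 + rng.normal(0.0, 0.1, size=T)

# Naive observed difference: confounded, since h1 and h2 are correlated.
beta_od = R[h1 == 1].mean() - R[h1 == 0].mean()

# RDD: compare only trials where the drive barely cleared or barely missed
# threshold; there, spiking is nearly independent of the confounder.
near = np.abs(z1 - theta) < p
beta_rdd = R[near & (h1 == 1)].mean() - R[near & (h1 == 0)].mean()

# beta_od overestimates (~2), beta_rdd is close to 1; the residual RDD bias
# shrinks as the window p narrows.
print(f"observed difference: {beta_od:.2f}, RDD: {beta_rdd:.2f}, true: {beta_true}")
```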

A small demonstration

Under some assumptions

$$ \frac{\partial R}{\partial w^i_j} \approx \frac{\partial H^i}{\partial w^i_j} \beta^i $$
  • Can relate the causal effect to gradients (see the sketch below) $\Rightarrow$ derive a stochastic gradient descent learning rule

  • Learning trajectories are less biased and converge faster
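
One way to see this relation (a sketch, assuming the piecewise-constant reward model above is accurate near threshold):

$$ \frac{\partial R}{\partial w^i_j} = \frac{\partial R}{\partial H^i}\frac{\partial H^i}{\partial w^i_j} \approx \beta^i \frac{\partial H^i}{\partial w^i_j}, \qquad \text{since } R \approx \gamma^i + \beta^i H^i \implies \frac{\partial R}{\partial H^i} \approx \beta^i $$

so an RDD estimate $\hat{\beta}^i$ can stand in for the reward sensitivity in a stochastic gradient update that climbs the reward, $\Delta w^i_j \propto \hat{\beta}^i\, \partial H^i / \partial w^i_j$.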

A larger example

Reward/cost function trains one neuron to have a firing rate different from the rest of the population

  • Performance and learning curve independent of noise correlations
  • RDD-trained network learns more quickly which neuron is special

Application to brain-computer interface learning

  • In single-unit BCIs, individual neurons are trained through biofeedback
  • Here, the causal effect of a neuron is known by construction
  • How does the network change the activity of the control neuron specifically?
  • To do so, it must solve a causal inference problem

Lansdell et al IEEE Trans NSRE 2020

Is this plausible?

  • It would require:
    • sub-threshold-dependent plasticity
    • neuromodulator-dependent plasticity

Ngezahayo et al 2000, Seol et al 2007

Why spike?

  • Neurons need to communicate over large distances

Calcium imaging in Hydra. Dupre and Yuste 2017

  • But computationally, a spiking discontinuity is inconvenient for learning
  • What are the computational benefits of spiking?
  • With RDD-based learning, spiking is a feature and not a bug

Part 1 summary

  • RDD can be used to estimate causal effects, and can provide a solution to the credit assignment problem in spiking neural networks
  • Shows a neuron can do causal inference without needing to randomize
  • Relies on the fact that neurons spike when input exceeds a threshold – spiking is a feature not a bug

Outline

  1. The neuronal credit assignment problem as causal inference
  2. Learning to solve the credit assignment problem

How to scale to large problems?

Richards et al Nature Neuroscience 2019

  • Either implicitly or explicitly, many learning rules use gradient information to optimize their weights
  • They vary in the bias and variance of their gradient estimates
  • Backpropagation is the standard for challenging ML problems

How to scale to large problems?

Richards et al Nature Neuroscience 2019

  • Reinforcement learning-based algorithms do not lead to efficient learning – their gradient estimators have high variance
  • $\Rightarrow$ Need higher-dimensional error signal to learn from

How to scale to large problems?

  • Is this plausible?
    • Cortical structure has both feedforward and feedback connections
    • Pyramidal neurons contain both apical and basal compartments, allowing for separate sites of integration (Kording and Konig 2001, Guerguiev et al 2017)
  • How much feedback is needed?
$\Rightarrow$ Investigate in the setting of artificial neural networks: $$ \mathbf{h}^i = \sigma(W^i\mathbf{h}^{i-1}) $$

Biologically implausible backpropagation

  • The gradient algorithm, backpropagation, suggests one form for such a network: $$ \mathbf{e}^i =\left((W^{i+1})^\mathsf{T} \mathbf{e}^{i+1}\right)\circ \sigma'(W^{i}\mathbf{h}^{i-1}) $$ with $\frac{\partial R}{\partial \mathbf{h}^i} = (W^{i+1})^\mathsf{T} \mathbf{e}^{i+1}$
  • Low bias, low variance – 'the workhorse of deep learning'
  • But the weights of this feedback network would be the transpose of the feedforward weights – i.e. it requires weight transport

Learning without weight transport

However:
  • Weight transport can be avoided by using random, fixed feedback weights, $B^i$ – feedback alignment (see the sketch below)
  • Works on small fully-connected networks
  • Doesn't work on deep networks, CNNs, or networks with bottleneck layers
Suggests that high-dimensional error signals help, even if very biased

$\Rightarrow$ Can we improve on feedback alignment by learning weights $B_i$?
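
To make the contrast concrete, here is a minimal sketch of one backward pass under backprop and under feedback alignment, for a toy two-layer network (shapes, names, and the dummy target are illustrative assumptions):

```python
import numpy as np

# One backward pass: backprop uses the transpose of the forward weights
# (weight transport); feedback alignment replaces it with fixed random B.
rng = np.random.default_rng(2)
n_in, n_hid, n_out = 8, 16, 4
W1 = rng.normal(0, 0.5, (n_hid, n_in))    # forward weights, layer 1
W2 = rng.normal(0, 0.5, (n_out, n_hid))   # forward weights, layer 2
B2 = rng.normal(0, 0.5, (n_out, n_hid))   # fixed random feedback weights

sigma = np.tanh
dsigma = lambda a: 1.0 - np.tanh(a) ** 2

x = rng.normal(0, 1, n_in)
target = rng.normal(0, 1, n_out)          # dummy regression target
a1 = W1 @ x; h1 = sigma(a1)
a2 = W2 @ h1; y = sigma(a2)

e2 = (y - target) * dsigma(a2)            # output-layer error
e1_bp = (W2.T @ e2) * dsigma(a1)          # backprop: requires W2^T
e1_fa = (B2.T @ e2) * dsigma(a1)          # feedback alignment: fixed random B2
```

Training the forward weights with the feedback-alignment error tends, over time, to bring $W^2$ into alignment with $B^2$, which is why the random feedback signal becomes useful.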

Learning feedback weights with perturbations

Minimize $\mathbb{E}(\mathcal{L}(\mathbf{x}, \mathbf{y}))$. Assume activations are noisy:
$$ \mathbf{h}_t^i = \sigma\left(W^i \mathbf{h}_t^{i-1}\right) + c_h\xi^i_t $$
A REINFORCE-type estimator of the error gradient:
$$ \hat{\lambda}^i = \left(\tilde{\mathcal{L}}(\mathbf{x},\mathbf{y},\xi)-\mathcal{L}(\mathbf{x},\mathbf{y})\right) \frac{\xi^i}{c_h} \approx \frac{\partial \mathcal{L}}{\partial \mathbf{h}^i} $$
where $\tilde{\mathcal{L}}$ is the loss under the perturbed activations. Train the feedback network to provide a useful error signal:
$$ \hat{B}^{i+1} = \text{argmin}_{B}\, \mathbb{E}\left\| B^\mathsf{T}\mathbf{e}^{i+1} - \hat{\lambda}^i \right\|_2^2 $$

Learning feedback weights with perturbations

Update $B$ by gradient descent on this objective:
$$ \Delta B^{i+1} \propto -\,\mathbf{e}^{i+1}\left((B^{i+1})^\mathsf{T}\mathbf{e}^{i+1} - \hat{\lambda}^i\right)^\mathsf{T} $$
where
$$ \mathbf{e}^i =\left((B^{i+1})^\mathsf{T} \mathbf{e}^{i+1}\right)\circ \sigma'(W^{i}\mathbf{h}^{i-1}) $$
Update $W$ with the feedback-propagated errors:
$$ \Delta W^i \propto -\,\mathbf{e}^i(\mathbf{h}^{i-1})^\mathsf{T} $$
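
Here is a minimal end-to-end sketch of these updates on a toy regression task; the network size, task, and learning rates are assumptions, and only the update structure follows the equations above:

```python
import numpy as np

# Learn feedback weights B by regressing them onto REINFORCE-type
# perturbation estimates of the error gradient, while the forward
# weights W train on the feedback-propagated errors.
rng = np.random.default_rng(3)
n_in, n_hid, n_out = 8, 16, 4
W1 = rng.normal(0, 0.3, (n_hid, n_in))
W2 = rng.normal(0, 0.3, (n_out, n_hid))
B2 = rng.normal(0, 0.3, (n_out, n_hid))   # feedback weights, to be learned
c_h, eta_w, eta_b = 0.01, 0.01, 0.1       # noise scale and learning rates

sigma = np.tanh
dsigma = lambda a: 1.0 - np.tanh(a) ** 2
loss = lambda y, t: 0.5 * np.sum((y - t) ** 2)

for step in range(10_000):
    x = rng.normal(0, 1, n_in)
    target = np.sin(x[:n_out])            # toy regression target
    # Clean forward pass
    a1 = W1 @ x; h1 = sigma(a1)
    a2 = W2 @ h1; y = sigma(a2)
    # Noisy forward pass: perturb the hidden layer
    xi1 = rng.normal(0, 1, n_hid)
    y_tilde = sigma(W2 @ (h1 + c_h * xi1))
    # REINFORCE-type estimate of dL/dh^1
    lam1 = (loss(y_tilde, target) - loss(y, target)) * xi1 / c_h
    # Errors propagated through the learned feedback weights
    e2 = (y - target) * dsigma(a2)
    e1 = (B2.T @ e2) * dsigma(a1)
    # Gradient step on ||B^T e - lambda_hat||^2
    B2 += eta_b * np.outer(e2, lam1 - B2.T @ e2)
    # Forward weights train on the feedback-derived errors
    W2 -= eta_w * np.outer(e2, h1)
    W1 -= eta_w * np.outer(e1, x)
```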

Learning feedback weights with perturbations

If we only update $B$, the feedback weights in the final layer converge to the corresponding forward weights $W$, in the following sense

Theorem 1: The least squares estimator
$$ (\hat{B}^{N+1})^\mathsf{T} = \hat{\lambda}^N (\mathbf{e}^{N+1})^\mathsf{T}\left(\mathbf{e}^{N+1}(\mathbf{e}^{N+1})^\mathsf{T}\right)^{-1} $$
converges to the true feedback matrix, in the sense that:
$$ \lim_{c_h\to 0}\,\text{plim}_{T\to\infty}\, \hat{B}^{N+1} = W^{N+1}, $$
where $\text{plim}$ denotes convergence in probability.
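
A numerical sanity check of this statement, in the simplest assumed setting of a linear output layer with squared-error loss (all names are illustrative):

```python
import numpy as np

# Regress perturbation estimates of dL/dh^N onto output errors e^{N+1}:
# the least-squares solution should recover (W^{N+1})^T.
rng = np.random.default_rng(4)
n_hid, n_out, T, c_h = 5, 3, 50_000, 1e-3
W = rng.normal(0, 1, (n_out, n_hid))       # true final-layer weights

H = rng.normal(0, 1, (n_hid, T))           # hidden activity over T samples
targets = rng.normal(0, 1, (n_out, T))
E = W @ H - targets                        # output errors (linear layer)

Xi = rng.normal(0, 1, (n_hid, T))          # hidden-layer perturbations
L0 = 0.5 * np.sum((W @ H - targets) ** 2, axis=0)
L1 = 0.5 * np.sum((W @ (H + c_h * Xi) - targets) ** 2, axis=0)
Lam = (L1 - L0) * Xi / c_h                 # per-sample estimates of dL/dh

# Least-squares estimator from Theorem 1: B_hat^T = Lam E^T (E E^T)^{-1}
B_hat_T = Lam @ E.T @ np.linalg.inv(E @ E.T)
# Entries match W^T up to O(c_h) bias and O(1/sqrt(T)) sampling noise
print(np.max(np.abs(B_hat_T - W.T)))
```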


Learning feedback weights with perturbations

If we only update $B$, then for a linear network the feedback weights in all layers converge to the corresponding forward weights $W$

Theorem 2: For $\sigma(x) = x$, the least squares estimator
$$ (\hat{B}^{n})^\mathsf{T} = \hat{\lambda}^{n-1} (\tilde{\mathbf{e}}^{n})^\mathsf{T}\left(\tilde{\mathbf{e}}^{n}(\tilde{\mathbf{e}}^{n})^\mathsf{T}\right)^{-1}, \qquad 1 \le n \le N+1, $$
converges to the true feedback matrix, in the sense that:
$$ \lim_{c_h\to 0}\,\text{plim}_{T\to\infty}\, \hat{B}^{n} = W^{n}, \qquad 1 \le n \le N+1. $$

A small example

  • Test on a 4-layer network solving MNIST
  • Learns to approximate the true gradient more closely than random feedback weights do

Lansdell, Prakash and Kording, ICLR 2020

A (slightly) larger example

  • Test on a 5-layer autoencoding network on MNIST
  • Feedback alignment fails to solve this task
  • Node perturbation learns faster than backprop with stochastic gradient descent
  • Comparable to backprop with the Adam optimizer

A larger example

  • Also leads to improved performance on CNNs
  • (The network is too deep to propagate the approximate error signals through all layers
    $\Rightarrow$ use direct feedback alignment instead)


Dataset     BP               NP               DFA
CIFAR10     76.9 $\pm$ 0.1   74.8 $\pm$ 0.2   72.4 $\pm$ 0.2
CIFAR100    51.2 $\pm$ 0.1   48.1 $\pm$ 0.2   47.3 $\pm$ 0.1

Mean test accuracy of a CNN over 5 runs, trained with backpropagation (BP), node perturbation (NP), and direct feedback alignment (DFA)

$\Rightarrow$ Shows challenging computer vision problems can be solved without weight transport

Summary

  • Shown how:
    • neurons can use their spiking threshold to estimate their causal effect on reward
    • a perturbation-based learning rule can be used to train a feedback network to provide useful error information
  • Applications in:
    • Neuromorphic hardware – learning with spiking networks
    • Application-specific integrated circuits (ASICs) – learning without weight transport
  • A combination of these can provide biologically plausible and scalable learning systems

Acknowledgments


  • Konrad Kording (U Penn)
  • Kording lab
    • Ari Benjamin
    • David Rolnick
    • Roozbeh Farhoodi
    • Prashanth Prakash
  • Adrienne Fairhall (UW)
  • Fairhall lab
    • Rich Pang
    • Alison Duffy

RDD as a way for a neuron to solve credit assignment

Lansdell and Kording, bioRxiv 2019

Is this plausible?

  • Consistent with:
    • current models of sub-threshold-dependent plasticity
    • current models of neuromodulator-dependent plasticity

How to test?

  • Over a fixed time window, a reward is administered when the neuron spikes
  • Stimuli are identified that place the neuron's input drive close to spiking threshold
  • RDD-based learning predicts increased synaptic changes for a set of stimuli that contains a high proportion of near-threshold inputs but keeps the overall firing rate constant
OK, so those are the two projects I wanted to go into in some detail. But now I want to take a step back, look at how these projects fit into a broader picture, and outline future work. Both projects fit into a general topic: the combination of causal inference and decision making.