The neuronal credit assignment problem as causal inference
Ben Lansdell, Bioengineering UPenn
Drexel University. February 27th 2020
Two complementary tasks to understand intelligence
1. Build artificial systems
2. Study human intelligence
Learning is central to both human and artificial intelligence
$\Rightarrow$ Advances in each domain can inspire the other
Machine learning, neuroscience, and causality
Causation relates to a number of challenges in both machine learning and neuroscience
Messerli, N Engl J Med 2012
In ML:
Causal models are more robust to changes in environment/distribution: better transfer, generalization
Fairness: strong associations are not causal, and may be unfair/biased/prejudiced
Safety: observational data may not say what happens when we act/intervene/change distributions
Machine learning, neuroscience, and causality
Causation relates to a number of challenges in both machine learning and neuroscience
In neuroscience:
Data analysis:
neural datasets are generally hugely undersampled, leading to confounding and problems of interpretation
increased ability to perturb specific circuits
Efficient learning, transfer, generalization
Causal learning
Machine learning, neuroscience, and causality
Causation relates to a number of challenges in both machine learning and neuroscience
Claim: progress in both machine learning and neuroscience can come from explicitly casting problems as causal learning problems
Outline
The neuronal credit assignment problem as causal inference
Learning to solve the credit assignment problem
Learning in the brain
Find parameters $w$ that minimize a loss/maximize a reward function, $R$
Learning in the brain
What are the synaptic update rules used by neurons that provide efficient and flexible learning?
Must:
Be consistent with known neurophysiology
Be good enough at learning complicated tasks
The neuronal credit assignment problem
To learn, a neuron must know its effect on the reward function
In spiking neural networks, this means something like:
If, for a given input, a spike increases the reward, the weights leading to that spike should increase
If, for a given input, a spike decreases the reward, the weights leading to that spike should decrease
The problem: noise correlations and confounding
$\Rightarrow$ Viewing learning as a causal inference problem may provide insight
Causality
Defined in terms of counterfactuals or interventions
The causal effect: $\beta = \mathbb{E}(R|H\leftarrow 1) - \mathbb{E}(R|H\leftarrow 0)$
How can we predict the causal effect from observation?
Credit assignment as causal inference
What is a neuron's causal effect on reward, and so how should it change to improve performance?
$$
\beta_i = \mathbb{E}(R| H_i \leftarrow 1) - \mathbb{E}(R| H_i \leftarrow 0)
$$
$\Rightarrow$ How can a neuron perform causal inference?
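To make the confounding concrete, here is a minimal numpy sketch (illustrative only, not from the talk): reward depends only on a second, correlated neuron, so the observational contrast $\mathbb{E}(R|H_1=1) - \mathbb{E}(R|H_1=0)$ is large even though the interventional effect $\beta_1$ is zero. All parameters are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n_trials = 100_000

def simulate(h1_intervene=None):
    """Two neurons share a confounding drive; reward depends only on neuron 2."""
    common = rng.normal(size=n_trials)                  # shared (confounding) input
    z1 = common + 0.5 * rng.normal(size=n_trials)       # neuron 1's drive
    z2 = common + 0.5 * rng.normal(size=n_trials)       # neuron 2's drive
    if h1_intervene is None:
        h1 = (z1 > 0).astype(float)                     # spike if drive above threshold
    else:
        h1 = np.full(n_trials, float(h1_intervene))     # intervention: set H1 directly
    h2 = (z2 > 0).astype(float)
    reward = 2.0 * h2 + rng.normal(scale=0.1, size=n_trials)
    return h1, reward

# Observational contrast for neuron 1: E[R | H1 = 1] - E[R | H1 = 0]
h1, r = simulate()
observational = r[h1 == 1].mean() - r[h1 == 0].mean()

# Interventional contrast: H1 <- 1 vs H1 <- 0, ignoring neuron 1's input
_, r1 = simulate(h1_intervene=1)
_, r0 = simulate(h1_intervene=0)
beta_1 = r1.mean() - r0.mean()

print(f"observational estimate: {observational:.2f}")  # large: confounded by shared drive
print(f"interventional beta_1:  {beta_1:.2f}")         # ~0: neuron 1 has no causal effect
```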
Credit assignment as causal inference
One solution: Randomization
If independent (unconfounded) noise is added to the system, it can be correlated with reward to estimate the reward gradient
In fact, the REINFORCE algorithm correlates reward with independent perturbations in activity, $\xi^i$:
$$
\mathbb{E}( R\xi^i ) \approx \sigma^2 \frac{\partial R}{\partial h^i}
$$
But:
Requires that each neuron measures an IID noise source, $\xi^i$, or knows its output relative to some expected output
Only well characterized in specific circuits e.g. birdsong learning (Fiete and Seung 2007)
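A minimal numerical sketch of this relation (illustrative only, with an arbitrary quadratic reward): correlating reward with a private perturbation $\xi$ recovers $\sigma^2 \, \partial R/\partial h$.

```python
import numpy as np

rng = np.random.default_rng(1)

# An arbitrary smooth "reward" as a function of a single unit's activity h
R = lambda h: -(h - 1.5) ** 2
dR_dh = lambda h: -2.0 * (h - 1.5)

h0, sigma = 0.5, 0.1
xi = rng.normal(scale=sigma, size=500_000)       # private, unconfounded perturbations
rewards = R(h0 + xi)

# Correlating reward with the perturbation recovers the reward gradient:
#   E[R xi] ~= sigma^2 dR/dh
estimate = (rewards * xi).mean() / sigma ** 2
print(f"perturbation estimate: {estimate:.2f}")  # ~ 2.0
print(f"true gradient:         {dR_dh(h0):.2f}")
```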
Causal learning without randomization
An observation: decisions made with arbitrary thresholds let us observe counterfactuals
Adapted from Moscoe et al, J Clin Epid 2015
Known as regression discontinuity design (RDD) in economics
Two more observations:
A neuron only spikes if its input is above a threshold
A spike can have a measurable effect on outcome and reward
Suggests regression discontinuity design can be used by a neuron to estimate its causal effect.
RDD for solving credit assignment
Lansdell and Kording, bioRxiv 2019
Inputs that place the neuron close to threshold give an unbiased estimate of its causal effect (sketched below)
Can relate causal effect to gradients $\Rightarrow$ derive stochastic gradient descent learning rule
Learning trajectories are less biased and converge faster
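A minimal sketch of the idea (not the paper's implementation): when reward is confounded by shared drive, comparing trials just above and just below the spiking threshold gives an approximately unbiased estimate of $\beta$, while the naive observational contrast does not. The threshold, noise scales and window width below are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(2)
n_trials, theta, window = 200_000, 1.0, 0.05

common = rng.normal(size=n_trials)                     # confounding drive shared with the network
z = common + 0.5 * rng.normal(size=n_trials)           # this neuron's total input drive
spike = (z > theta).astype(float)                      # spike iff drive exceeds threshold

true_beta = 0.4                                        # the neuron's true causal effect on reward
reward = true_beta * spike + common + rng.normal(scale=0.1, size=n_trials)

# Naive observational contrast is confounded by the shared drive
naive = reward[spike == 1].mean() - reward[spike == 0].mean()

# RDD: restrict to near-threshold trials, where crossing the threshold is effectively random
near = np.abs(z - theta) < window
rdd = reward[near & (spike == 1)].mean() - reward[near & (spike == 0)].mean()

print(f"naive estimate: {naive:.2f}")                  # badly biased
print(f"RDD estimate:   {rdd:.2f}   (true beta = {true_beta})")
```

As in standard RDD, fitting separate local regressions on either side of the threshold further reduces the residual bias from the finite window.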
A larger example
Reward/cost function trains one neuron to have a different firing rate from the rest of the population
Performance and learning curve independent of noise correlations
RDD-trained network learns more quickly which neuron is special
Application to brain-computer interface learning
In single-unit BCIs, individual neurons are trained through biofeedback
Here, the causal effect of a neuron is known by construction
How does the network change specifically the control neuron's activity?
Must solve causal inference problem
Lansdell et al IEEE Trans NSRE 2020
Is this plausible?
It would require:
sub-threshold dependent plasticity
neuromodulator dependent plasticity
Ngezahayo et al 2000, Seol et al 2007
Why spike?
Neurons need to communicate over large distances
Calcium imaging in Hydra. Dupre and Yuste 2017
But computationally, a spiking discontinuity is inconvenient for learning
What are the computational benefits of spiking?
With RDD-based learning, spiking is a feature and not a bug
Part 1 summary
RDD can be used to estimate causal effects, and can provide a solution to the credit assignment problem in spiking neural networks
Shows a neuron can do causal inference without needing to randomize
Relies on the fact that neurons spike when input exceeds a threshold – spiking is a feature not a bug
Outline
The neuronal credit assignment problem as causal inference
Learning to solve the credit assignment problem
How to scale to large problems?
Richards et al Nature Neuroscience 2019
Either implicitly or explicitly, many learning rules use gradient information to optimize their weights
They vary according to their bias and variance
Backpropagation is the standard for challenging ML problems
How to scale to large problems?
Richards et al Nature Neuroscience 2019
Reinforcement-learning-based algorithms do not lead to efficient learning – they are high-variance estimators
$\Rightarrow$ Need higher-dimensional error signal to learn from
How to scale to large problems?
Is this plausible?
Cortical structure has both feedforward and feedback connections
Pyramidal neurons contain both apical and basal compartments, allowing for separate sites of integration (Kording and Konig 2001, Guerguiev et al 2017)
How much feedback is needed?
$\Rightarrow$ Investigate in the setting of artificial neural networks:
$$
\mathbf{h}^i = \sigma(W^i\mathbf{h}^{i-1})
$$
Biologically implausible backpropagation
The gradient algorithm, backpropagation, suggests one form for such a feedback network:
$$
\mathbf{e}^i =\left((W^{i+1})^\mathsf{T} \mathbf{e}^{i+1}\right)\circ \sigma'(W^{i}\mathbf{h}^{i-1})
$$
with $\frac{\partial R}{\partial \mathbf{h}^i} = (W^{i+1})^T \mathbf{e}^{i+1}$
Low bias, low variance – 'the workhorse of deep learning'
But the feedback weights of this network would be the transpose of the feedforward weights – i.e. they require weight transport
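A minimal numpy sketch of these two passes (illustrative only; layer widths and weights are arbitrary). The point to notice is that the backward recursion reuses the transposed forward weights $(W^{i+1})^\mathsf{T}$ – the weight-transport requirement.

```python
import numpy as np

rng = np.random.default_rng(3)
sigma = np.tanh
dsigma = lambda a: 1.0 - np.tanh(a) ** 2

sizes = [10, 20, 20, 5]                                   # arbitrary layer widths
W = [rng.normal(scale=0.3, size=(m, n)) for n, m in zip(sizes[:-1], sizes[1:])]

# Forward pass: h^i = sigma(W^i h^{i-1})
x = rng.normal(size=sizes[0])
h, a = [x], []
for Wi in W:
    a.append(Wi @ h[-1])
    h.append(sigma(a[-1]))

# Backward pass: e^i = ((W^{i+1})^T e^{i+1}) * sigma'(W^i h^{i-1})
y = rng.normal(size=sizes[-1])                            # dummy target
e = (h[-1] - y) * dsigma(a[-1])                           # output-layer error
grads = [np.outer(e, h[-2])]                              # gradient for the last W
for i in reversed(range(len(W) - 1)):
    # Weight transport: the backward pass reuses the transposed forward weights.
    # Feedback alignment instead substitutes a fixed random matrix B^{i+1} here.
    e = (W[i + 1].T @ e) * dsigma(a[i])
    grads.insert(0, np.outer(e, h[i]))                    # dR/dW^i = e^i (h^{i-1})^T
```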
Learning without weight transport
However:
Weight transport can be avoided by using random, fixed feedback weights, $B^i$ – feedback alignment
Works on small fully-connected networks
Doesn't work on deep networks, CNNs, networks with bottleneck layers
Suggests high-dimensional error signals help, even if very biased
$\Rightarrow$ Can we improve on feedback alignment by learning the weights $B^i$?
Update $B$ according to
$$
\Delta (B^{i+1})^\mathsf{T} \propto \left((B^{i+1})^\mathsf{T}\mathbf{e}^{i+1} - \hat{\lambda}^i\right)(\mathbf{e}^{i+1})^\mathsf{T}
$$
where
$$
\mathbf{e}^i =\left((B^{i+1})^\mathsf{T} \mathbf{e}^{i+1}\right)\circ \sigma'(W^{i}\mathbf{h}^{i-1})
$$
Update $W$ according to
$$
\Delta W^i \propto \mathbf{e}^i(\mathbf{h}^{i-1})^\mathsf{T}
$$
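A hedged sketch of how these updates could fit together (illustrative, not the paper's code): hidden activities are perturbed with small noise $\xi^i$, the induced change in loss gives the node-perturbation estimate $\hat{\lambda}^i \approx \partial R/\partial \mathbf{h}^i$, the feedback weights $B$ are nudged so that $B^\mathsf{T}\mathbf{e}$ matches it, and $W$ is updated with errors propagated through $B$. The task, layer sizes, learning rates and noise scale are all arbitrary.

```python
import numpy as np

rng = np.random.default_rng(4)
sigma = np.tanh
dsigma = lambda a: 1.0 - np.tanh(a) ** 2

sizes = [10, 20, 5]                                       # one hidden layer, arbitrary widths
lr, lr_B, c = 0.01, 0.01, 0.02                            # learning rates, perturbation scale
W = [rng.normal(scale=0.3, size=(m, n)) for n, m in zip(sizes[:-1], sizes[1:])]
B = [rng.normal(scale=0.3, size=w.shape) for w in W]      # feedback weights, to be learned

def forward(x, xi=None):
    """Forward pass; optionally add a perturbation xi[i] to each hidden layer."""
    h, a = [x], []
    for i, Wi in enumerate(W):
        a.append(Wi @ h[-1])
        hi = sigma(a[-1])
        if xi is not None and i < len(W) - 1:
            hi = hi + xi[i]
        h.append(hi)
    return h, a

for _ in range(5000):
    x = rng.normal(size=sizes[0])
    y = np.tanh(x[:sizes[-1]])                            # an arbitrary target function
    xi = [rng.normal(scale=c, size=s) for s in sizes[1:-1]]

    h, a = forward(x)
    hp, _ = forward(x, xi)
    loss = 0.5 * np.sum((h[-1] - y) ** 2)
    loss_p = 0.5 * np.sum((hp[-1] - y) ** 2)

    # Node-perturbation estimate of dR/dh^i for each hidden layer (lambda_hat)
    lam = [(loss_p - loss) * xi_i / c ** 2 for xi_i in xi]

    # Errors propagated through the learned feedback weights B (in place of W^T)
    e = [None] * len(W)
    e[-1] = (h[-1] - y) * dsigma(a[-1])
    for i in reversed(range(len(W) - 1)):
        e[i] = (B[i + 1].T @ e[i + 1]) * dsigma(a[i])

    # Update B so that B^T e^{i+1} matches the perturbation estimate lambda_hat^i
    for i in range(len(W) - 1):
        B[i + 1] -= lr_B * np.outer(e[i + 1], B[i + 1].T @ e[i + 1] - lam[i])

    # Update W with the feedback-propagated errors: Delta W^i ~ e^i (h^{i-1})^T
    for i in range(len(W)):
        W[i] -= lr * np.outer(e[i], h[i])
```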
Learning feedback weights with perturbations
If we only update $B$, then the feedback weights in the final layer converge to the corresponding forward weights, in the following sense
Theorem 1: The least squares estimator
\begin{equation*}
(\hat{B}^{N+1})^T = \hat{\lambda}^N (\mathbf{e}^{N+1})^T\left(\mathbf{e}^{N+1}(\mathbf{e}^{N+1})^T\right)^{-1},
\end{equation*}
converges to the true feedback matrix, in the sense that:
$$
\lim_{c_h\to 0}\text{plim}_{T\to\infty} \hat{B}^{N+1} = W^{N+1},
$$
where $\text{plim}$ indicates convergence in probability.
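A minimal numerical check of the estimator's form (not the paper's code): modeling $\hat{\lambda}^N$ as $(W^{N+1})^\mathsf{T}\mathbf{e}^{N+1}$ plus independent noise standing in for the finite perturbation size $c_h$, the least-squares solution recovers $W^{N+1}$.

```python
import numpy as np

rng = np.random.default_rng(5)
n_out, n_hidden, T, c_h = 5, 20, 50_000, 0.1

W_last = rng.normal(size=(n_out, n_hidden))            # true final-layer weights W^{N+1}

# Columns of E are output errors e^{N+1} over T trials; lam_hat stands in for the
# node-perturbation estimate: (W^{N+1})^T e^{N+1} plus noise of scale c_h
E = rng.normal(size=(n_out, T))
lam_hat = W_last.T @ E + c_h * rng.normal(size=(n_hidden, T))

# The least-squares estimator from the theorem: B_hat^T = lam_hat E^T (E E^T)^{-1}
B_hat = (lam_hat @ E.T @ np.linalg.inv(E @ E.T)).T

print(np.abs(B_hat - W_last).max())                    # small; -> 0 as c_h -> 0, T -> inf
```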
Learning feedback weights with perturbations
If we only update $B$, then the feedback weights in all layers converge to the corresponding forward weights, for a linear network
Theorem 2: For $\sigma(x) = x$, the least squares estimator
\begin{equation*}
(\hat{B}^{n})^T = \hat{\lambda}^{n-1} (\mathbf{\tilde{e}}^{n})^T\left(\mathbf{\tilde{e}}^{n}(\mathbf{\tilde{e}}^{n})^T\right)^{-1}\qquad 1 \le n \le N+1,
\end{equation*}
converges to the true feedback matrix, in the sense that:
$$
\lim_{c_h\to 0}\text{plim}_{T\to\infty} \hat{B}^{n} = W^{n}, \qquad 1 \le n \le N+1.
$$
A small example
Test on a 4-layer network solving MNIST
Learns to approximate the true gradient more closely than random feedback weights do
Lansdell, Prakash and Kording, ICLR 2020
A (slightly) larger example
Test on a 5-layer autoencoding network on MNIST
Feedback alignment fails to solve this task
Node perturbation learns faster than backprop with stochastic gradient descent
Comparable to BP with the Adam optimizer
A larger example
Also leads to improved performance on CNNs
(Too deep to propagate approximate error signals through all layers $\Rightarrow$ use direct feedback alignment instead)
dataset     BP             NP             DFA
CIFAR10     76.9$\pm$0.1   74.8$\pm$0.2   72.4$\pm$0.2
CIFAR100    51.2$\pm$0.1   48.1$\pm$0.2   47.3$\pm$0.1
Mean test accuracy (%) of a CNN over 5 runs, trained with backpropagation (BP), node perturbation (NP) and direct feedback alignment (DFA)
$\Rightarrow$ Shows challenging computer vision problems can be solved without weight transport
Summary
Shown how:
neurons can use their spiking threshold to estimate their causal effect on reward
a perturbation-based learning rule can be used to train a feedback network to provide useful error information
Applications in:
Neuromorphic hardware – learning with spiking networks
Application specific integrated circuits (ASICs) – learning without weight transport
A combination of these can provide biologically plausible and scalable learning systems
Acknowledgments
Konrad Kording (U Penn)
Kording lab
Ari Benjamin
David Rolnick
Roozbeh Farhoodi
Prashanth Prakash
Adrienne Fairhall (UW)
Fairhall lab
Rich Pang
Alison Duffy
Chet Moritz (UW)
Ivana Milovanovic (UW)
Cooper Mellema (UT Austin)
Eberhard Fetz (UW)
RDD as a way for a neuron to solve credit assignment
Lansdell and Kording, bioRxiv 2019
Is this plausible?
Consistent with:
current models of sub-threshold dependent plasticity
current models of neuromodulator dependent plasticity
How to test?
Over a fixed time window, a reward is administered when the neuron spikes
Stimuli are identified which place the neuron's input drive close to the spiking threshold.
RDD-based learning predicts increased synaptic changes for a set of stimuli that contains a high proportion of near-threshold inputs but keeps the overall firing rate constant.