12 Example 1
We will simulate a two neuron network in which neuron 1 projects to neuron 2, the first neuron is driven by an external input, and the connection weight between neuron 1 and neuron 2 is subject to reinforcement learning.
We will use the following
python
code based on the previous lecture.The key additions to the code from previous lecture are mostly found in the
update_weight_rl()
function.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
def init_arrays():
r_predicted = np.zeros(n_trials)
r_obtained = np.zeros(n_trials)
delta = np.zeros(n_trials)
r_obtained[:n_trials // 4] = 1
r_obtained[(2 * n_trials // 4):(3 * n_trials // 4)] = 1
v1 = np.zeros(n)
u1 = np.zeros(n)
g1 = np.zeros(n)
spike1 = np.zeros(n)
v1[0] = vr
v2 = np.zeros(n)
u2 = np.zeros(n)
g2 = np.zeros(n)
spike2 = np.zeros(n)
v2[0] = vr
w_01 = 0.4 * np.ones(n_trials)
w_12 = 0.4 * np.ones(n_trials)
g_record = np.zeros((n_trials, n))
v1_record = np.zeros((n_trials, n))
g1_record = np.zeros((n_trials, n))
v2_record = np.zeros((n_trials, n))
g2_record = np.zeros((n_trials, n))
return {
'r_predicted': r_predicted,
'r_obtained': r_obtained,
'delta': delta,
'v1': v1,
'u1': u1,
'g1': g1,
'spike1': spike1,
'v2': v2,
'u2': u2,
'g2': g2,
'spike2': spike2,
'w_01': w_01,
'w_12': w_12,
'g_record': g_record,
'v1_record': v1_record,
'g1_record': g1_record,
'v2_record': v2_record,
'g2_record': g2_record
}
def plot_results():
fig = plt.figure(figsize=(8, 8))
gs = gridspec.GridSpec(5, 2)
ax00 = fig.add_subplot(gs[0, 0])
ax10 = fig.add_subplot(gs[1, 0])
ax20 = fig.add_subplot(gs[2, 0])
ax01 = fig.add_subplot(gs[0, 1])
ax11 = fig.add_subplot(gs[1, 1])
ax21 = fig.add_subplot(gs[2, 1])
ax3 = fig.add_subplot(gs[3, :])
ax4 = fig.add_subplot(gs[4, :])
ax1 = ax00
ax2 = ax00.twinx()
ax2.plot(t, g_record[0], 'C0', label='external g')
ax2.legend(loc='lower right')
ax1.set_title('Trial 1')
ax1.set_xlabel('time (t)')
ax1 = ax10
ax2 = ax10.twinx()
ax1.plot(t, v1_record[0], 'C0', label='v1')
ax2.plot(t, g1_record[0], 'C1', label='g1')
ax1.legend(loc='upper right')
ax2.legend(loc='lower right')
ax1.set_xlabel('time (t)')
ax1 = ax20
ax2 = ax20.twinx()
ax1.plot(t, v2_record[0], 'C0', label='v2')
ax2.plot(t, g2_record[0], 'C1', label='g2')
ax1.legend(loc='upper right')
ax2.legend(loc='lower right')
ax1.set_xlabel('time (t)')
ax1 = ax01
ax2 = ax01.twinx()
ax2.plot(t, g_record[0], 'C0', label='external g')
ax2.legend(loc='lower right')
ax1.set_title('Trial n')
ax1.set_xlabel('time (t)')
ax1 = ax11
ax2 = ax11.twinx()
ax1.plot(t, v1_record[-2], 'C0', label='v1')
ax2.plot(t, g1_record[-2], 'C1', label='g1')
ax1.legend(loc='upper right')
ax2.legend(loc='lower right')
ax1.set_xlabel('time (t)')
ax1 = ax21
ax2 = ax21.twinx()
ax1.plot(t, v2_record[-2], 'C0', label='v2')
ax2.plot(t, g2_record[-2], 'C1', label='g2')
ax1.legend(loc='upper right')
ax2.legend(loc='lower right')
ax1.set_xlabel('time (t)')
ax3.plot(np.arange(0, n_trials, 1), r_obtained, label='r_obtained')
ax3.plot(np.arange(0, n_trials, 1), r_predicted, label='r_predicted')
ax3.plot(np.arange(0, n_trials, 1), delta, label='delta')
ax3.set_xlabel('Trial')
ax3.set_ylabel('')
ax3.legend()
ax4.plot(np.arange(0, n_trials, 1), w_12)
ax4.set_xlabel('Trial')
ax4.set_ylabel('Synaptic Weight (w)')
plt.tight_layout()
plt.show()
def simulate_network(update_weight_func):
global trl, r_obtained, r_predicted
for j in range(n_trials - 1):
trl = j
for i in range(1, n):
dt = t[i] - t[i - 1]
# external input
dgdt = (-g[i - 1] + psp_amp * spike[i - 1]) / psp_decay
g[i] = g[i - 1] + dgdt * dt
# neuron 1
dvdt1 = (k * (v1[i - 1] - vr) *
(v1[i - 1] - vt) - u1[i - 1] + w_01[trl] * g[i - 1]) / C
dudt1 = a * (b * (v1[i - 1] - vr) - u1[i - 1])
dgdt1 = (-g1[i - 1] + psp_amp * spike1[i - 1]) / psp_decay
v1[i] = v1[i - 1] + dvdt1 * dt
u1[i] = u1[i - 1] + dudt1 * dt
g1[i] = g1[i - 1] + dgdt1 * dt
if v1[i] >= vpeak:
v1[i - 1] = vpeak
v1[i] = c
u1[i] = u1[i] + d
spike1[i] = 1
# neuron 2
dvdt2 = (k * (v2[i - 1] - vr) *
(v2[i - 1] - vt) - u2[i - 1] + w_12[trl] * g1[i - 1]) / C
dudt2 = a * (b * (v2[i - 1] - vr) - u2[i - 1])
dgdt2 = (-g2[i - 1] + psp_amp * spike2[i - 1]) / psp_decay
v2[i] = v2[i - 1] + dvdt2 * dt
u2[i] = u2[i - 1] + dudt2 * dt
g2[i] = g2[i - 1] + dgdt2 * dt
if v2[i] >= vpeak:
v2[i - 1] = vpeak
v2[i] = c
u2[i] = u2[i] + d
spike2[i] = 1
# update synaptic weights
delta_w = update_weight_func()
w_12[trl + 1] = w_12[trl] + delta_w
# store trial info
g_record[trl, :] = g
v1_record[trl, :] = v1
g1_record[trl, :] = g1
v2_record[trl, :] = v2
g2_record[trl, :] = g2
plot_results()
def update_weight_rl():
global trl, r_obtained, r_predicted
delta[trl] = r_obtained[trl] - r_predicted[trl]
r_predicted[trl + 1] = r_predicted[trl] + gamma * delta[trl]
pre = g1.sum()
post = g2.sum()
delta_w = alpha * pre * post * delta[trl]
return delta_w
n_trials = 100
trl = 0
tau = 0.1
T = 100
t = np.arange(0, T, tau)
n = t.shape[0]
C = 50
vr = -80
vt = -25
vpeak = 40
k = 1
a = 0.01
b = -20
c = -55
d = 150
psp_amp = 1e5
psp_decay = 10
g = np.zeros(n)
spike = np.zeros(n)
spike[200:800:20] = 1
alpha = 3e-14
beta = 3e-14
gamma = 0.1
array_dict = init_arrays()
r_predicted = array_dict['r_predicted']
r_obtained = array_dict['r_obtained']
delta = array_dict['delta']
v1 = array_dict['v1']
u1 = array_dict['u1']
g1 = array_dict['g1']
spike1 = array_dict['spike1']
v2 = array_dict['v2']
u2 = array_dict['u2']
g2 = array_dict['g2']
spike2 = array_dict['spike2']
w_01 = array_dict['w_01']
w_12 = array_dict['w_12']
g_record = array_dict['g_record']
v1_record = array_dict['v1_record']
g1_record = array_dict['g1_record']
v2_record = array_dict['v2_record']
g2_record = array_dict['g2_record']
update_weight_func = update_weight_rl
simulate_network(update_weight_func)