Nature DQN
DeepMind, 2015
Paper
summary
- use a CNN to approximate the action-value function ($Q$ function)
- use a replay buffer (experience replay) to decorrelate the training data
- use a separate target network to compute the TD target
formulas
- TD target: $r+\gamma\max_{a'} Q(s',a';\theta^-)$
- TD error: $\Delta=r+\gamma\max_{a'} Q(s',a';\theta^-)-Q(s,a;\theta)$
- update rule: $\theta_{t+1}=\theta_t+\alpha\,\Delta\,\nabla_{\theta}Q(s,a;\theta_t)$
- loss function: $L(\theta)=\Delta^2$
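These quantities map almost one-to-one onto tensor operations. A minimal PyTorch sketch, using random stand-ins for the network outputs and rewards (all names, shapes, and values here are illustrative, not from the paper):

import torch

batch, num_actions, gamma = 32, 4, 0.99

# stand-ins for what would come from the networks and the replay buffer
q_sa   = torch.randn(batch)               # Q(s, a; theta) for the actions actually taken
next_q = torch.randn(batch, num_actions)  # Q(s', a'; theta^-) for every action a'
r      = torch.randn(batch)               # rewards
done   = torch.zeros(batch)               # 1 if s' is terminal, else 0

td_target = r + gamma * next_q.max(dim=1).values * (1 - done)  # r + gamma * max_a' Q(s', a'; theta^-)
td_error  = td_target - q_sa                                   # Delta
loss      = td_error.pow(2).mean()                             # loss = Delta^2, averaged over the batch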
details
RL is known to be unstable or even to diverge when a nonlinear function approximator such as a neural network is used to represent the action-value function ($Q$ function). There are several causes:
- the correlations present in the sequence of observations
- small updates to $Q$ may significantly change the policy and therefore the data distribution
- the correlations between $Q$ and the target values $r+\gamma\max_{a'}Q(s',a')$
To address these issues, the paper proposes two key ideas:
- experience replay: inspired by a biological mechanism, it randomizes over the data, removing correlations in the observation sequence and smoothing over changes in the data distribution (a minimal buffer sketch follows this list)
- periodic target updates: iteratively adjust $Q$ towards target values that are only periodically updated, reducing the correlations with the target
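A minimal replay buffer along these lines might look as follows (a sketch with an illustrative capacity, not the paper's exact implementation; converting the sampled batch to tensors is left out):

import random
from collections import deque

class ReplayBuffer:
    """Fixed-size buffer that stores transitions and samples them uniformly at random."""
    def __init__(self, capacity=100_000):
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, s2, done):
        self.buffer.append((s, a, r, s2, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)  # uniform sampling breaks temporal correlations
        s, a, r, s2, done = zip(*batch)
        return s, a, r, s2, done

    def __len__(self):
        return len(self.buffer)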
loss function: $L_i(\theta_i)=\mathbb{E}_{(s,a,r,s')\sim U(D)}\big[\Delta^2\big]=\mathbb{E}_{(s,a,r,s')\sim U(D)}\Big[\big(r+\gamma\max_{a'}Q(s',a';\theta_i^-)-Q(s,a;\theta_i)\big)^2\Big]$, where transitions are drawn uniformly from the replay memory $D$, $\theta_i$ are the parameters of the Q-network, and $\theta_i^-$ are the parameters of the target Q-network, which is synchronized with $\theta_i$ every $C$ steps (sketched below) and held fixed between individual updates.
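A sketch of the periodic synchronization, using small `nn.Linear` stand-ins for the two networks and an illustrative sync period `C` (the per-step gradient update is elided):

import torch.nn as nn

C = 1000  # how often (in gradient steps) to refresh the target network; the value is illustrative

q_net, target_net = nn.Linear(4, 2), nn.Linear(4, 2)  # stand-ins for the Q-network and target network
target_net.load_state_dict(q_net.state_dict())        # initial synchronization: theta^- <- theta

for step in range(1, 5001):
    # ... sample a minibatch, compute the loss against target_net, take one optimizer step on q_net ...
    if step % C == 0:
        target_net.load_state_dict(q_net.state_dict())  # theta^- <- theta, then held fixed for the next C steps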
disadvantages
DQN can overestimate action values under certain conditions: the max operator uses the same values both to select and to evaluate an action, so noise in the estimates is turned into a positive bias (a small numerical illustration follows).
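A quick numerical illustration of this bias (not from the paper): even if every action's true value is 0, the max over noisy estimates is positive in expectation.

import torch

torch.manual_seed(0)
num_actions, num_trials = 4, 100_000

# suppose the true value of every action is 0 and the estimates are just zero-mean noise
noisy_q = 0.5 * torch.randn(num_trials, num_actions)

print(noisy_q.mean().item())                    # ~0.0: each individual estimate is unbiased
print(noisy_q.max(dim=1).values.mean().item())  # ~0.5: the max over noisy estimates is biased upward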
pseudocode
code
- PyTorch
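The `DQN` class used below is not defined in the snippet; a minimal stand-in might be a small MLP like the one sketched here (the paper itself uses a CNN over stacked Atari frames):

import torch.nn as nn

class DQN(nn.Module):
    """Maps a state vector to one Q-value per action."""
    def __init__(self, num_states, num_actions):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(num_states, 128),
            nn.ReLU(),
            nn.Linear(128, num_actions),
        )

    def forward(self, x):
        return self.net(x)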
import torch.optim as optim

# create the Q-network and the target network, then synchronize their parameters
current_model, target_model = DQN(num_states, num_actions), DQN(num_states, num_actions)

def update_model(current_model, target_model):
    # theta^- <- theta: copy the Q-network parameters into the target network
    target_model.load_state_dict(current_model.state_dict())

update_model(current_model, target_model)
optimizer = optim.Adam(current_model.parameters())

# sample a minibatch of transitions (s, a, r, s', terminal flag) from the replay buffer;
# a should be an int64 tensor and t a 0/1 float tensor for the operations below
s, a, r, s2, t = replay_buffer.sample(batchsize)

q_values = current_model(s)       # Q(s, .; theta)
next_q_values = target_model(s2)  # Q(s', .; theta^-)

# pick out Q(s, a; theta) for the action actually taken in each transition
q_value = q_values.gather(1, a.unsqueeze(1)).squeeze(1)
# TD target: r + gamma * max_a' Q(s', a'; theta^-); terminal transitions (t == 1) get no bootstrap term
next_q_value = next_q_values.max(1)[0]
expected_q_value = r + gamma * next_q_value * (1 - t)

# squared TD error; the target is detached so gradients flow only through Q(s, a; theta)
loss = (q_value - expected_q_value.detach()).pow(2).mean()

# one gradient step on the Q-network parameters theta
optimizer.zero_grad()
loss.backward()
optimizer.step()
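The snippet above covers only the parameter update; in the full algorithm actions are chosen ε-greedily with respect to `current_model` and the resulting transitions are pushed into the replay buffer. A self-contained sketch of that selection step (the helper name and the toy linear model are illustrative, not from the original code):

import random
import torch

def epsilon_greedy(model, state, epsilon, num_actions):
    """With probability epsilon pick a random action, otherwise the greedy action argmax_a Q(s, a; theta)."""
    if random.random() < epsilon:
        return random.randrange(num_actions)
    with torch.no_grad():                  # no gradients are needed for action selection
        q = model(state.unsqueeze(0))      # add a batch dimension
        return int(q.argmax(dim=1).item())

# toy usage with a linear stand-in for the Q-network
model = torch.nn.Linear(4, 2)
action = epsilon_greedy(model, torch.zeros(4), epsilon=0.1, num_actions=2)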
Full version can be found here 😀.
In the next post, several variants of DQN will be introduced.