Prioritized Experience Replay

One of the possible improvements, already acknowledged in the original research [2], lies in the way experience is used. When we treat all samples the same, we ignore the fact that we can learn more from some transitions than from others. Prioritized Experience Replay [3] (PER) is one strategy that tries to leverage this fact by changing the sampling distribution.

The main idea is that we prefer transitions that do not fit well with our current estimate of the Q function, because these are the transitions we can learn the most from. This reflects a simple intuition from the real world: if we encounter a situation that really differs from our expectation, we think about it over and over and change our model until it fits.

We can define the error of a sample $S = (s, a, r, s')$ as the distance between $Q(s, a)$ and its target $T(S)$:
$$error = |Q(s, a) - T(S)|$$
For the DDQN described above, $T$ would be:
$$T(S) = r + \gamma \tilde{Q}(s', \operatorname{argmax}_a Q(s', a))$$
We will store this error in the agent’s memory along with every sample and update it with each learning step.
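To make the definition concrete, here is a minimal sketch of the error computation for a single sample. It assumes the Q-values have already been obtained from the two networks and are passed in as plain lists; the function name `ddqn_error` and its parameters are made up for this illustration.

```python
def ddqn_error(q_s, q_next_online, q_next_target, a, r, gamma=0.99):
    """Return |Q(s, a) - T(S)| for one sample S = (s, a, r, s')."""
    # argmax_a Q(s', a): the action is chosen by the online network
    best_a = max(range(len(q_next_online)), key=lambda i: q_next_online[i])
    # T(S) = r + gamma * Q~(s', best_a): the value comes from the target network
    target = r + gamma * q_next_target[best_a]
    return abs(q_s[a] - target)


# Made-up Q-values for a two-action problem
err = ddqn_error(q_s=[1.0, 2.0], q_next_online=[0.5, 3.0],
                 q_next_target=[4.0, 1.0], a=0, r=1.0, gamma=0.5)
print(err)  # 0.5
```

Here the online network prefers action 1 in $s'$, the target network values it at 1.0, so the target is $1 + 0.5 \cdot 1 = 1.5$ and the error is $|1 - 1.5| = 0.5$.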

One of the possible approaches to PER is proportional prioritization. The error is first converted to priority using this formula:
$$p = (error + \epsilon)^\alpha$$
Epsilon $\epsilon$ is a small positive constant that ensures that no transition has zero priority.
Alpha, $0 \leq \alpha \leq 1$, controls the difference between high and low error. It determines how much prioritization is used. With $\alpha = 0$ every priority equals 1, so we would get the uniform case.
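The effect of $\alpha$ is easy to see in a few lines of code. This is only a sketch: the default values `epsilon=0.01` and `alpha=0.6` are illustrative assumptions, not values prescribed by the text.

```python
def priority(error, epsilon=0.01, alpha=0.6):
    """Convert an error into a priority: p = (error + epsilon) ** alpha."""
    return (error + epsilon) ** alpha


errors = [0.0, 0.5, 2.0]
for alpha in (0.0, 0.5, 1.0):
    # alpha = 0 flattens all priorities to 1, alpha = 1 keeps the raw errors
    print(alpha, [round(priority(e, alpha=alpha), 3) for e in errors])
```

With `alpha=0.0` all three priorities come out equal, and with `alpha=1.0` they are the errors shifted by `epsilon`.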

Priority is translated to probability of being chosen for replay. A sample $i$ has a probability of being picked during the experience replay determined by a formula:
$$P_i = \frac{p_i}{\sum_k p_k}$$
The algorithm is simple: during each learning step we draw a batch of samples from this probability distribution and train our network on it. We only need an efficient way of storing these priorities and sampling from them.
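Before worrying about efficiency, the sampling rule can be sketched directly with the standard library. The helper names are hypothetical, and drawing with replacement (which is what `random.choices` does) is an assumption of this sketch.

```python
import random


def sampling_probabilities(priorities):
    """P_i = p_i / sum_k p_k"""
    total = sum(priorities)
    return [p / total for p in priorities]


def sample_batch(samples, priorities, batch_size, rng=random):
    # random.choices draws with replacement, weighted by the given priorities
    return rng.choices(samples, weights=priorities, k=batch_size)


priorities = [1.0, 3.0, 6.0]
probs = sampling_probabilities(priorities)  # roughly [0.1, 0.3, 0.6]
batch = sample_batch(["s0", "s1", "s2"], priorities, batch_size=4,
                     rng=random.Random(0))
```

This version does linear work for every batch, which is what the rest of the section improves on.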

Initialization and new transitions

The original paper says that new transitions arrive without a known error [3], but I argue that with the definition given above, we can compute it with a simple forward pass as each transition arrives. It’s also effective, because high-value transitions are discovered immediately.

Another thing is the initialization. Remember that before the learning itself, we fill the memory using a random agent. But this agent does not use any neural network, so how could we estimate any error? We can use the fact that an untrained neural network is likely to return a value around zero for every input. In this case the error formula becomes very simple:
$$error = |Q(s, a) - T(S)| = |Q(s, a) - r - \gamma \tilde{Q}(s', \operatorname{argmax}_a Q(s', a))| \approx | r |$$
The error in this case is simply the reward experienced in a given sample. Indeed, the transitions where the agent experienced any reward intuitively seem to be very promising.
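As a small illustration, the initial error for transitions collected by the random agent could be computed like this. The tuple layout `(s, a, r, s_next)` and the function name are assumptions made for the example.

```python
def initial_error(sample):
    """Error assigned to a transition collected by the random agent."""
    s, a, r, s_next = sample
    # Assumes the untrained networks output roughly zero for every input,
    # so |Q(s, a) - T(S)| reduces to |r|.
    return abs(r)


# Made-up transitions: states are just labels here
memory = [("s0", 0, 0.0, "s1"), ("s1", 1, -1.0, "s2"), ("s2", 0, 1.0, "s3")]
errors = [initial_error(S) for S in memory]
print(errors)  # [0.0, 1.0, 1.0]
```

Transitions with zero reward get zero error, and it is the $\epsilon$ in the priority formula that keeps them from never being sampled.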

Efficient implementation

So how do we store the experience and effectively sample from it?

A naive implementation would be to have all samples in an array sorted according to their priorities. A random number $s$, $0 \leq s \leq \sum_k p_k$, would be picked and the array would be walked left to right, summing up the priorities of the visited elements until the running sum reaches $s$; the element where this happens is chosen. This will select a sample with the desired probability distribution.

[Figure: Sorted experience]
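The naive walk over the array can be sketched as follows. The function name `sample_linear` is hypothetical, and the final `return` is an extra guard against floating-point round-off that the description does not mention.

```python
import random


def sample_linear(priorities, s):
    """Walk left to right until the running sum of priorities reaches s."""
    running = 0.0
    for i, p in enumerate(priorities):
        running += p
        if running >= s:
            return i
    return len(priorities) - 1  # guard: round-off left s slightly above the total


priorities = [6.0, 3.0, 1.0]  # sorted, as in the naive version
s = random.uniform(0, sum(priorities))
idx = sample_linear(priorities, s)  # index 0 is returned about 60% of the time
```

Every call touches up to $n$ elements, regardless of the order in which the priorities are stored.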

But this would have a terrible efficiency: $O(n \log n)$ for insertion and update, and $O(n)$ for sampling.

A first important observation is that we don’t actually have to store the array sorted. An unsorted array would do just as well. Elements with higher priorities are still picked with higher probability.

[Figure: Unsorted experience]

This removes the need for sorting, improving the algorithm to $O(1)$ for insertion and update.

But $O(n)$ for sampling is still too high. We can further improve our algorithm by using a different data structure. We can store our samples in an unsorted sum tree, a binary tree data structure where each parent’s value is the sum of its children. The samples themselves are stored in the leaf nodes.

Updating a leaf node involves propagating the value difference up the tree, which takes $O(\log n)$. Sampling follows the same idea as in the array case, but also takes only $O(\log n)$. For a value $s$, $0 \leq s \leq \sum_k p_k$, we use the following algorithm (pseudo code):

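def retrieve(n, s):
    if n.left is None:  # leaf node: this is the sampled transition
        return n
    if n.left.val >= s:  # s falls inside the left subtree
        return retrieve(n.left, s)
    # otherwise skip the left subtree's total and search the right one
    return retrieve(n.right, s - n.left.val)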
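To try the retrieval by hand, here is a small self-contained sketch with an explicit node class. The `Node` class, the priorities and the labels are all made up for this example, and the tree is not the one shown in the picture.

```python
class Node:
    def __init__(self, val=0.0, left=None, right=None, data=None):
        self.left, self.right, self.data = left, right, data
        # A leaf keeps its own priority; an internal node sums its children.
        self.val = val if left is None else left.val + right.val


def retrieve(n, s):
    if n.left is None:  # leaf node reached
        return n
    if n.left.val >= s:  # s falls inside the left subtree
        return retrieve(n.left, s)
    return retrieve(n.right, s - n.left.val)


leaves = [Node(val=p, data=name)
          for p, name in [(3.0, "A"), (10.0, "B"), (12.0, "C"), (4.0, "D")]]
root = Node(left=Node(left=leaves[0], right=leaves[1]),
            right=Node(left=leaves[2], right=leaves[3]))

print(root.val)                 # 29.0
print(retrieve(root, 20).data)  # C
```

For $s = 20$ the left subtree sums to 13, which is less than 20, so we go right with $s = 7$; there the left child holds 12, which covers 7, so leaf C is returned.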

The following picture illustrates sampling from a tree with $s = 24$:

[Figure: Sampling from sum tree]

With this efficient implementation we can use large memory sizes, with hundreds of thousands or millions of samples.

If the capacity is known in advance, this sum tree data structure can be backed by an array. A reference implementation in 50 lines of code is available on GitHub.
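To give an idea of what an array-backed version might look like, here is a minimal sketch. It is not the reference implementation mentioned above: the class and method names are assumptions, and overwriting the oldest sample once the capacity is reached is a design choice made only for this example.

```python
class SumTree:
    """Sum tree over a fixed number of leaves, stored in one flat list."""

    def __init__(self, capacity):
        self.capacity = capacity
        # Indices 0..capacity-2 are internal nodes, the rest are leaves.
        self.tree = [0.0] * (2 * capacity - 1)
        self.data = [None] * capacity
        self.write = 0  # next leaf slot to (over)write

    def total(self):
        return self.tree[0]  # the root holds the sum of all priorities

    def update(self, idx, p):
        # Set a leaf's priority and propagate the difference up to the root.
        change = p - self.tree[idx]
        self.tree[idx] = p
        while idx != 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def add(self, p, sample):
        self.data[self.write] = sample
        self.update(self.write + self.capacity - 1, p)
        self.write = (self.write + 1) % self.capacity

    def get(self, s):
        # Iterative version of retrieve(): descend until a leaf is reached.
        idx = 0
        while 2 * idx + 1 < len(self.tree):
            left = 2 * idx + 1
            if self.tree[left] >= s:
                idx = left
            else:
                s -= self.tree[left]
                idx = left + 1
        return idx, self.tree[idx], self.data[idx - self.capacity + 1]


tree = SumTree(4)
for p, name in [(3.0, "A"), (10.0, "B"), (12.0, "C"), (4.0, "D")]:
    tree.add(p, name)
print(tree.total())   # 29.0
print(tree.get(20))   # (5, 12.0, 'C')
```

The children of node `i` live at `2*i + 1` and `2*i + 2`, so no pointers are needed, and both `update` and `get` only walk one path between the root and a leaf.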