On using Huber loss in (Deep) Q-learning

I’ve recently been working on a problem where I put a plain DQN to use. The problem is very simple, deterministic and partially observable, and its states are quite low-dimensional. However, the agent can’t tell some states apart, so the environment is effectively stochastic from the agent’s point of view.

Because the problem was quite simple, I expected the network to learn a very good representation of the Q-function over the whole state space.

And I was surprised that this vanilla DQN totally failed on this problem. Not in the sense that it was too difficult; on the contrary – the algorithm converged and was highly certain about all the Q-values it found. But these Q-values were totally wrong. I couldn’t get my head around it, until I tracked it down to a simple cause: the Pseudo-Huber loss.

Edit: Based on the discussion, the original Huber loss with an appropriate δ parameter is correct to use. The following article however still holds for the L1 and Pseudo-Huber losses.
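For reference, the original Huber loss (as usually defined) is quadratic for small errors and linear only beyond the threshold δ:

L_{\delta}(e) = \begin{cases} \frac{1}{2}e^2 & \text{for } |e| \le \delta \\ \delta(|e| - \frac{1}{2}\delta) & \text{otherwise} \end{cases}

If δ is chosen large enough that the TD errors stay inside the quadratic region, its minimum coincides with the MSE minimum discussed below, i.e. with the mean of the targets.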

Correct Q-values

Let’s first look at the Q-function we are trying to estimate. The optimal Q-function satisfies the Bellman equation:

Q(s,a) = E_{s'}[ r + \max_{a'}\;Q(s',a') ]

The expectation is over all possible next states, weighted by their probability of occurrence – basically a weighted mean. In a deterministic and fully observable environment there is always only one possible next state s’, so the expectation is trivial. However, in a stochastic environment the agent can experience multiple different transitions even when it performs the same action in the same state s.

Let’s consider a simple example. Imagine an environment with only one state s and one possible action, which terminates the episode. The transition gives a reward of +1 in 2/3 of cases (with probability 66.6%) and -1 in 1/3 of cases (33.3%).

The correct Q value is obviously 1/3:

Q(s,a) = E_{s'}[ r ] = \frac{2}{3} * 1 + \frac{1}{3} * (-1) = \frac{1}{3}
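A quick way to sanity-check this value is to sample the reward distribution directly and take the plain mean – a minimal sketch in Python, with the probabilities and rewards taken from the example above:

import numpy as np

rng = np.random.default_rng(0)

# Terminal transition from the example: reward +1 with probability 2/3, -1 with 1/3.
rewards = rng.choice([1.0, -1.0], size=100_000, p=[2/3, 1/3])

# The expectation, and therefore the correct Q-value, is the plain mean of the samples.
print(rewards.mean())  # ~0.333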

Learned Q-values

When we learn the approximated Q-function, we fit a value x such that it minimizes a chosen cost with respect to some target values.

In our example, the only action is terminal, so we can use the observed rewards directly as targets y. When we sample y from the distribution described above, we expect that (in the limit) 2/3 of the rewards will be +1 and 1/3 will be -1. In the following formulas, we can aggregate these proportions as weights.

Let’s look at three choices of the cost function: L2 (MSE), L1 (MAE) and Pseudo-Huber loss.

L2 (MSE) is defined as:

MSE = \frac{1}{n}\sum_{i=1}^{n}(x - y_i)^2

If we use the targets from our example, the formula becomes:

MSE = \frac{1}{3}[ 2(x - 1)^2 + (x + 1)^2 ]

[Plot: the weighted MSE from the example as a function of x]

This function has a minimum at x=1/3 (check here), which is the mean of the targets. This is not a coincidence: the minimum of the MSE is always at the mean of the targets. Notice that it corresponds to the correct Q-value computed above.
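This can be verified numerically; a small sketch, assuming SciPy is available (minimize_scalar is just one convenient way to find the minimum):

from scipy.optimize import minimize_scalar

# Weighted MSE from the example: two targets of +1 and one target of -1.
mse = lambda x: (2 * (x - 1) ** 2 + (x + 1) ** 2) / 3

print(minimize_scalar(mse).x)  # ~0.3333, the mean of the targets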

L1 (MAE) replaces the squared difference with the absolute value:

MAE = \frac{1}{n}\sum_{i=1}^{n}|x - y_i|

In our example, it becomes:

MAE = \frac{1}{3}[ 2 |x - 1| + |x + 1| ]

[Plot: the weighted MAE from the example as a function of x]

The solution is different: x=1 (check here). The value that minimizes the MAE is the median of the targets – here the sorted targets are (-1, +1, +1), so the median is +1.
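The same check for the MAE; here a simple grid search avoids any issues with the kink at the minimum (a sketch with an arbitrarily chosen grid):

import numpy as np

# Weighted MAE from the example: two targets of +1 and one target of -1.
mae = lambda x: (2 * np.abs(x - 1) + np.abs(x + 1)) / 3

xs = np.linspace(-2, 2, 4001)
print(xs[np.argmin(mae(xs))])  # ~1.0, the median of the targets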

The Pseudo-Huber loss (here written with δ = 1) is:

L_{Huber} = \frac{1}{n}\sum_{i=1}^{n}[\sqrt{1 + (x - y_i)^2} - 1]

With our targets, it becomes:

L_{Huber} = \frac{1}{3}[ 2 (\sqrt{1 + (x - 1)^2} - 1) + (\sqrt{1 + (x + 1)^2} - 1) ]

[Plot: the weighted Pseudo-Huber loss from the example as a function of x]

The minimum of this formula lies somewhere between the L1 and L2 solutions. For the given data it’s x=0.538 (check here). Note that the Pseudo-Huber loss has a parameter δ which influences the solution; I arbitrarily chose δ=1.
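And the same check for the Pseudo-Huber loss; the pseudo_huber function and its delta argument below are my own names for illustration, with the loss written in its general form so the effect of δ is visible:

from scipy.optimize import minimize_scalar

def pseudo_huber(x, delta=1.0):
    # Weighted Pseudo-Huber loss from the example (two +1 targets, one -1 target),
    # written in its general form: delta^2 * (sqrt(1 + (e/delta)^2) - 1).
    phi = lambda e: delta ** 2 * ((1 + (e / delta) ** 2) ** 0.5 - 1)
    return (2 * phi(x - 1) + phi(x + 1)) / 3

print(minimize_scalar(pseudo_huber).x)                         # ~0.538 with delta = 1
print(minimize_scalar(lambda x: pseudo_huber(x, delta=10)).x)  # ~0.34, close to the MSE solution of 1/3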

The problem

Now we are starting to see the problem. The L1 or Pseudo-Huber loss simply can’t (by definition) find the correct Q-values. One could argue that in this simple example it does not matter, because the episode terminates anyway. But the fundamental problem is that the algorithm converges to incorrect Q-values in stochastic environments.

Now I wonder whether the Atari domain is deterministic or not. And even if it were stochastic, it’s a question how much this matters. The problem might be diminished by the complexity of the environment. Does it really matter whether a creature I am trying to shoot dodges to the left with probability 70% and to the right with 30%? I don’t know, and it’s an interesting question.

What I’m trying to say is that you shouldn’t use arbitrary losses in deep Q-learning without thinking about the consequences. The problem only shows itself in stochastic environments, or in partially observable environments where the agent can’t tell states apart – when perceptual aliasing occurs.

Another question is what the added value of the Huber loss actually is. Is it favoured just because it’s robust to outliers? Or does it help to mitigate the exploding-gradient problem? These are all questions deserving serious research.

Analogy to tabular Q-learning

We can demonstrate the same problem in tabular Q-learning as well. There, the update rule looks as follows:

Q(s,a) \leftarrow Q(s,a) + \alpha * error

where the error is:

error = r + \max_{a'}\;Q(s',a') - Q(s,a)

One can argue that this error is essentially a gradient originating from an L2 loss. If we wanted L1-like behaviour instead, the error would become the sign of the TD error:

error = \begin{cases}  +1 & \text{when } r + \max_{a'}\;Q(s',a') - Q(s,a) > 0 \\ -1 & \text{when } r + \max_{a'}\;Q(s',a') - Q(s,a) < 0  \end{cases}

I hypothesise that with this L1-like error, tabular Q-learning would also converge to median values, which are obviously incorrect in stochastic environments.
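A toy simulation of this hypothesis on the one-state example from above – a sketch only; the sign-based rule is the hypothetical L1-like update described above, not anything standard:

import numpy as np

rng = np.random.default_rng(0)
alpha = 0.01
q_l2, q_l1 = 0.0, 0.0  # Q(s,a) estimates under the two update rules

for _ in range(50_000):
    # Terminal transition from the example: reward +1 with probability 2/3, -1 with 1/3.
    r = rng.choice([1.0, -1.0], p=[2/3, 1/3])
    q_l2 += alpha * (r - q_l2)         # usual (L2-like) TD update
    q_l1 += alpha * np.sign(r - q_l1)  # hypothetical L1-like (sign of error) update

print(q_l2)  # fluctuates around ~0.33, the mean
print(q_l1)  # fluctuates around ~1.0, the median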

Discussion

If you’d like to participate in a discussion of this problem, I set up a reddit thread here. I invite you to add your comments there.
