Abracadabra

TD-VAE [ICLR 2019]

简介

【笔记版】

今天要讲的是ICLR2019中DeepMind的一个工作,TD-VAE,一个序列生成模型。通过引入强化学习中时序差分以及变分自动编码器,来实现从当前时间步到未来时间步的预测。这里值得注意的是,TD-VAE并不是一个固定时间步的序列生成模型(当然如果训练时喂的训练数据是一个时间间隔固定的序列数据,那么训练出的模型就是固定时间步的序列生成模型),即其生成的数据时间间隔不是一个固定的时间步,而是随机的。如果想生成数据的时间间隔可控,那么可以在前向模型的建模中显式地将时间步作为变量即可。

这篇论文的作者认为,一个序列生成模型需要具备以下三点属性:

  • 这个模型应该学习一个数据的抽象状态表示并且在状态空间中进行预测,而不是在观察空间进行预测。
  • 这个模型应该学习一个置信状态,这个状态需要包含目前为止智能体对于周围环境的所有感知信息。置信状态相当于状态表示的隐变量。
  • 这个模型应该表现出时序抽象,既能够直接预测多个时间步之后的状态,也能够只通过两个独立的时间点进行训练而不需要中间所有时间点的信息。

优化目标

TD-VAE的目标便是优化以下对数条件似然:
$$
\log p(x_t|x_{<t})
$$
这里假设$x_t$可以通过该时间步以及上一个时间步的状态表示$z_t$和$z_{t-1}$推断得出,类似于VAE中损失函数的推导过程,这里同样引入ELBO,具体推导过程如下图:

推导过程

推导过程

推导过程

推导过程

最后的损失函数包含以下几个部分:

损失函数1

然后我们把两个连续时间步的状态表示换为两个任意时刻的状态表示:

损失函数2

这实质上是如下VAE的损失函数:

VAE

其中$t2>t1$。整个损失函数可以直观地解释为以下四个部分组成:

直观解释1

直观解释2

训练时的计算图如下所示:

计算图

最后在三个不同任务上的实验结果:

直观解释1

直观解释1

直观解释1