一、介绍
概率序列建模:生成不同长度的序列(仅生成单个token或通过自回归采样生成"无限数量"的token)。在各种机器学习应用中起着至关重要的作用,包括自然语言处理,视频预测。
当前的对下一个token预测模型是通过Teacher Forcing进行训练的,
Teacher Forcing
RNN最初的训练方式:上一个state的输出作为下个state的输入
训练迭代过程早期的RNN预测能力非常弱,几乎不能给出好的生成结果。如果某一个单词产生了错误的结果,必然会影响后面一片单词的学习。从而导致学习速度变慢,难以收敛。
teacher-forcing 在训练网络过程中,每次不使用上一个state的输出作为下一个state的输入,而是直接使用训练数据的标准答案(ground truth)的对应上一项作为下一个state的输入。
给定如下输入序列:
[START] Mary had a little lamb whose fleece was white as snow [END]
free-running:如果一开始生成"a",之后作为输入来生成下一个单词,模型就偏离正轨。因为生成的错误结果,会导致后续的学习都受到不好的影响,导致学习速度变慢,模型也变得不稳定。
而使用teacher-forcing,模型生成一个"a",可以在计算了error之后,丢弃这个输出,把"Marry"作为后续的输入。该模型将更正模型训练过程中的统计属性,更快地学会生成正确的序列。
其中模型根据真实的先前tokens的历史信息从而预测下一个token。这导致了两个局限性:
1️⃣teacher-forcing过于依赖ground truth数据,在训练过程中,模型会有较好的效果,但是在测试的时候因为不能得到ground truth的支持,所以如果目前生成的序列在训练过程中有很大不同,模型就会变得脆弱。
2️⃣当前的下一个token预测模型在连续数据上很容易变得不稳定。例如,当试图在自动回归生成比训练时帧数更长的视频时,帧到帧预测中的微小错误会累积并且使得模型发散。
全序列扩散似乎提供了一种解决方案。通常用于视频生成和长期规划,人们通过扩散它们的串联来直接模拟固定数量的token的联合分布,其中所有token的噪声水平都是相同的。它们能依据指导采样到理想的序列。它们还擅长生成连续信号,如视频。然而,全序列扩散是通过非因果、未屏蔽的架构普遍参数化的。这不仅将采样限制为与训练相同的完整序列(而不是可变长度生成)之外,限制了子序列生成的可能性(图1)。
图 1:Diffusion Forcing 的能力。如今,不同的应用,如语言建模或视频生成,根据其各自的独特能力,依赖于自回归下一个token预测或全序列扩散。本文所提出的扩散强迫是一种新颖的序列生成模型,具有两种模型类型的关键优势。
在本文中,作者介绍了一种训练和采样范式——扩散强制(Diffusion Forcing, DF)。在该范式中,每个token都与一个随机、独立的噪声水平相关联,并且可以根据shedules(shedules当中的每个token是任意的、独立的),通过共享的下一个token或下几个token预测模型进行去噪。本文的方法源于这样的观察:对token加入噪声是一种部分掩码形式——零噪声意味着token未被遮掩,完全噪声则完全掩盖token。因此,DF迫使模型学习“去掩码”不同噪声程度的token的任何集合(见图2)。同时,通过将预测参数化为下一token预测模型的组合,该系统能够灵活地生成不同长度的序列,并且能够根据指导推广到新的轨迹(见图1)。
此外作者还将 DF 实现为因果扩散强制(Causal Diffusion Forcing, CDF),在这种方法中,未来的token通过因果架构依据过去的token来生成序列。在采样过程中,CDF 逐步将一系列高斯噪声帧去噪成干净的样本,不同的帧在每个去噪步骤中可能具有不同的噪声水平。与下一token预测模型类似,CDF 可以生成可变长度的序列;与下一token预测不同的是,它能够稳定地从紧接的下一个token生成到数千个将来的token——即使是连续token。此外,像完整序列扩散一样,它接受指引以生成高回报的结果。通过协同利用因果性、灵活的范围和可变噪声计划,CDF 实现了一种新的能力,蒙特卡洛树引导(Monte Carlo Tree Guidance, MCTG),与非因果完整序列扩散模型相比,这种方法显著改进了高回报生成的采样。图1概述了这些能力。
二、扩散强制(Diffusion Forcing)
将噪声处理为部分掩码:将任何token集合(无论是否顺序)视为按 \(\textcolor{DarkCyan}{t}\) 索引的有序集。然后,使用Teacher Forcing训练下一个token预测模型可以被解释为在时间 \(\textcolor{DarkCyan}{t}\) 屏蔽每个token \(\textcolor{DarkCyan}{x_t}\),并从过去的 \(\textcolor{DarkCyan}{x_{1:t-1}}\)进行预测。仅限于序列,该做法称为沿时间轴的掩蔽。作者还可以将全序列前向扩散,即逐渐向数据 \(\textcolor{DarkCyan}{x^0_{1:T} ≡ x_{1:T}}\) 中增加噪声,作为部分掩蔽的一种形式,将其称为沿噪声轴的掩蔽。事实上,在 \(\textcolor{DarkCyan}{K}\) 步噪声之后,\(\textcolor{DarkCyan}{x^{K}_{1:T}}\)是纯白噪声,没有关于原始数据的信息。
作者沿着掩蔽的两个轴建立了一个统一的视图(见图2)。 \(\textcolor{DarkCyan}{x_{1:T}}\) 表示标记序列,其中下标表示时间轴。如上所述,\(\textcolor{DarkCyan}{x_t^{k_t}}\)表示在前向扩散过程下具有噪声水平\(\textcolor{DarkCyan}{k_t}\)的\(\textcolor{DarkCyan}{x_t}\); \(\textcolor{DarkCyan}{x_t^0 = x}\) 是无噪声token,\(\textcolor{DarkCyan}{x_t^K}\) 是白噪声 \(\textcolor{DarkCyan}{N(0,I)}\)。因此,\(\textcolor{DarkCyan}{(x_t^{k_t})_{1≤t≤T}}\) 表示一系列噪声观测值,其中每个token具有不同的噪声水平 \(\textcolor{DarkCyan}{k_t}\),这可以看作是通过噪声对每个token施加的部分掩蔽程度。
图 2:方法概述。扩散强制训练因果序列神经网络(例如 RNN 或 masked transformer)对灵活长度序列进行去噪,其中序列的每一帧都可以具有不同的噪声水平。相比之下,在语言建模中常见的下一个token预测模型被训练为从真值序列中预测单个下一个token(教师强迫),而在视频生成中常见的全序列扩散则训练非因果架构以相同的噪声水平同时对序列中的所有帧进行去噪。因此,扩散强迫交错了序列的时间轴和扩散的噪声轴,统一了两种替代方案的优势,并实现了全新的功能。
扩散强制:不同token的不同噪声节点。扩散强迫 (DF) 是一个用于训练和采样噪声token \(\textcolor{CornflowerBlue}{(x_t^{k_t})_{1≤t≤T}}\) 的任意序列长度的框架,其中至关重要的是,每个token的噪声水平 \(\textcolor{CornflowerBlue}{k_t}\) 可能随时间步长而变化。对于时间序列数据,通过因果架构(其中 \(\textcolor{CornflowerBlue}{x^{k_t}_t}\) 仅依赖于过去的噪声标记)实例化扩散强迫,称之为因果扩散强迫(CDF)。为简单起见,作者专注于使用普通递归神经网络(RNN)的最小实现。
具有权重为\(\textcolor{CornflowerBlue}{\theta}\) 的 RNN 通过动态 \(\textcolor{CornflowerBlue}{z_t ∼ p_θ(z_t|z_{t−1}, x^{k_t}_t, k_t)}\) 与循环层一起演化,保持潜在 \(\textcolor{CornflowerBlue}{z_t}\) 捕捉过去token的影响,这些。当进行传入的噪声观测 \(\textcolor{CornflowerBlue}{x^{k_t}_t}\)时,隐藏状态以马尔可夫式 \(\textcolor{CornflowerBlue}{z_t ∼ p_θ(z_t|z_{t−1}, x^{k_t}_t, k_t)}\) 进行更新。当 \(\textcolor{CornflowerBlue}{k_t = 0}\) 时,这是贝叶斯滤波的后验更新;而当 \(\textcolor{CornflowerBlue}{k_t = K}\) 时(\(\textcolor{CornflowerBlue}{x_t^K}\) 是纯噪声,因此没有信息),这相当于在贝叶斯滤波中建模“先验分布” \(\textcolor{CornflowerBlue}{p_θ(z_t|z_{t−1})}\)。给定潜在 \(\textcolor{CornflowerBlue}{z_t}\),观测模型 \(\textcolor{CornflowerBlue}{p_θ(x_t^0|z_t)}\) 预测 \(\textcolor{CornflowerBlue}{x_t}\);该单元具有与标准条件扩散模型相同的输入输出行为,使用条件变量 \(\textcolor{CornflowerBlue}{z_{t-1}}\) 和噪声令牌 \(\textcolor{CornflowerBlue}{x^{k_t}_t}\) 作为输入来预测未被加噪的 \(\textcolor{CornflowerBlue}{x_t = x_t^0}\),从而间接地通过仿射重新参数化预测噪声 \(\textcolor{CornflowerBlue}{ε^k_t}\) 。因此,可以直接使用传统的扩散训练目标来训练(因果)扩散强迫。根据噪声预测 \(\textcolor{CornflowerBlue}{ε_θ(z^{t−1}, x^{k_t}_t, k_t)}\) 对上述单位进行参数化。然后,通过最小化损失来找到参数 \(\textcolor{CornflowerBlue}{θ}\)
其中,从 \(\textcolor{CornflowerBlue}{[K]^T}\) 中均匀采样 \(\textcolor{CornflowerBlue}{k_{1:T}}\),从训练数据中均匀采样 \(\textcolor{CornflowerBlue}{x_{1:T}}\),根据前向扩散过程,\(\textcolor{CornflowerBlue}{ε_t ∼ N(0,σ^2_{k_t}I)}\)(参见伪代码的算法 1)。重要的是,损失 (3.1) 捕获了贝叶斯滤波和条件扩散的基本元素。
1.扩散强制采样和能力
采样在算法 2 中描述,通过在 2D \(M × T\) 网格 \(\textcolor{blue}{\mathcal{K} \in [K]^{M×T}}\) 上规定噪声时间表来定义; 列对应于时间步长 \(\textcolor{blue}{t}\),由 \(\textcolor{blue}{m}\) 索引的行确定噪声级别。\(\textcolor{blue}{\mathcal{K}_{m,t}}\) 表示行 \(\textcolor{blue}{m}\) 的时间步长 \(\textcolor{blue}{t}\) 标记所需的噪声水平。要生成长度为 \(\textcolor{blue}{T}\) 的整个序列,将标记 \(\textcolor{blue}{x_{1:T}}\) 初始化为白噪声,对应于噪声水平 \(\textcolor{blue}{k = K}\)。作者逐行迭代网格,从左到右对列进行去噪,以达到 \(\textcolor{blue}{\mathcal{K}}\) 规定的噪声水平。到最后一行 \(\textcolor{blue}{m = 0}\),token是干净的,即它们的噪声水平为 \(\textcolor{blue}{\mathcal{K_{0,t} ≡ 0}}\)。
稳定自回归生成:对于高维、连续的序列(如视频),已知自回归架构会发散,尤其是在采样超过训练范围时。相比之下,扩散强迫可以通过使用与略微“加噪的tokens”相关的前潜在来更新潜在源,从而稳定地推出长序列,甚至可以在训练序列长度之外推出一些小的噪声级别 \(\textcolor{blue}{0 < k ≪ K} \)。
让未来充满不确定性:从一系列白噪声tokens\(\textcolor{blue}{[x_1^K, x_2^K, x_3^K]^⊤}\) 开始,我们可以对第一个token进行完全去噪,对第二个token进行部分去噪,得到 \(\textcolor{blue}{[x^0_1, x^{K/2}_2, x_3^K]^⊤}\),然后得到 \(\textcolor{blue}{[x^0_1, x^0_2, x^{K/2}_3 ]^⊤}\),最后将所有token完全去噪到 \(\textcolor{blue}{[x_1^0, x_2^0, x_3^0]^⊤}\)。这种“锯齿形”采样方案将噪声水平解释为不确定性,直观地将近期的未来编码为比遥远的未来更确定。
长期指导:在算法 2 的第 10 行中,可以向部分扩散的轨迹 \(\textcolor{blue}{x_{1:T}}\)添加引导,如第 2 节所示。由于未来token对过去的依赖性,未来token的引导梯度可能会在时间上向后传播。扩散强迫的独特优势在于,由于我们可以在不完全扩散过去的情况下扩散未来的token,因此梯度指导了过去token的采样,从而在尊重因果关系的同时实现了长期引导。
2.用于灵活序列决策的扩散强迫
扩散强制提供的功能激发了新的序列决策框架(SDM),该框架在机器人和自主代理中具有关键应用。考虑一个由具有动作 \(\textcolor{blue}{p(s_{t+1}|s_t, a_t)}\)、观测 \(\textcolor{blue}{p(o_t|s_t)}\) 和奖励 \(\textcolor{blue}{p(r_t|s_t, a_t)}\) 的环境定义的马尔可夫决策过程。目标是训练策略 \(\textcolor{blue}{π(a_t|o_{1:t})}\),使得轨迹 \(\textcolor{blue}{E[\sum^T_{t=}r_t]}\) 的预期累积奖励最大化。我们分配token \(\textcolor{blue}{x_t = [a_t, r_t, o_{t+1}]}\)。轨迹是一个序列 \(\textcolor{blue}{x_{1:T}}\),可能是可变长度的;训练按照算法 1 进行。在执行的每一步 \(\textcolor{blue}{t}\) 中,过去的(无噪声)token \(\textcolor{blue}{x_{1:t-1}}\) 由一个潜在的 \(\textcolor{blue}{z_{t-1}}\) 汇总。基于这个潜在因素,我们通过算法 2 对一个计划 \(\textcolor{blue}{\hat{x}_{t:t+H}}\)进行采样,其中 \(\textcolor{blue}{\hat{x}_t = [\hat{a}_t, \hat{r}_t, \hat{o}_{t+1}]^⊤}\) 包含预测的行动、奖励和观察结果。\(\textcolor{blue}{H}\)是前瞻窗口,类似于模型预测控制中的未来预测。在采取计划的行动 \(\textcolor{blue}{\hat{a}_t}\) 后,环境产生奖励 \(\textcolor{blue}{\hat{r}_t}\) 和下一个观察值 \(\textcolor{blue}{\hat{o}_{t+1}}\),产生下一个token \(\textcolor{blue}{x_t = [\hat{a}_t, r_t, o_{t+1}]^⊤}\)。潜在值根据后验 \(\textcolor{blue}{p_θ(z_t|z_{t−1}, x_t, 0)}\) 进行更新。该框架支持作为策略和规划器的功能:
灵活的规划范围。扩散强迫 (a) 可以部署在可变范围的任务上,因为每个新动作都是按顺序选择的,并且 (b) 其前瞻窗口 \(\textcolor{blue}{H}\) 可以缩短以降低延迟(使用扩散强迫作为策略),或者延长以执行长期规划(通过下面描述的指导),而无需重新训练或修改架构。请注意,(a)对于具有全轨迹生成视野的Diffuser 这样的全序列扩散模型是不可能的,而扩散策略需要固定的、小的预测尺寸,排除了(b)。
蒙特卡洛树指导(MCTG)。因果扩散强迫使我们能够通过指导未来 \(\textcolor{blue}{x_{t+1:T}}\)的整个分布来影响token \(\textcolor{blue}{x_t^k}\) 的生成。我们可以绘制多个样本并平均它们的引导梯度,而不是绘制单个轨迹样本来计算此引导梯度。我们称之为蒙特卡洛树指导,其中“树”来自这样一个事实,即当前 \(\textcolor{blue}{x_t^k}\) 的去噪受到未来许多路径的梯度的影响,\(\textcolor{blue}{x_t^k}\)随后受到预期奖励的指导,而不是一个特定的结果。当与采样计划相结合时,MCTG 的效果会增强,这些计划在对下一个代币进行去噪时保持未来代币的高噪声水平。
三、实验
评估了扩散强制作为生成序列模型在视频和时间序列预测、规划和模仿学习中的各种应用中的优点。
1.视频预测:一致、稳定的序列生成和无限扩展
作者训练了因果扩散强迫的卷积RNN实现,用于在Minecraft游戏和DMLab导航的视频上进行视频生成建模。在采样时,执行自回归推出,并采用第 3.1 节中提出的稳定功能。作者设计的两个基线,它们都利用了相同的精确RNN架构:一个是用Teacher Forcing训练的下一帧扩散基线,另一个是因果全序列扩散模型。图 3 显示了由扩散强迫生成的推出的定性结果,以及从两个数据集的不可见帧开始的基线。
图 3:视频生成。在经过测试的方法中,扩散强制生成在时间上是唯一一致的,即使在训练范围之后进行扩展也不会发散。
扩散强制能够成功并且稳定的扩展视频长度,甚至远远超出其训练范围(例如 1000 帧),而教师强制和全序列扩散的基线会迅速发散。此外,在训练范围内,我们观察到全序列扩散受到帧到帧不连续性的影响,其中视频序列急剧跳跃,而扩散强迫推出在一致的 3D 环境中显示自我运动。这凸显了扩散强制能够在不产生复合误差的情况下稳定高维序列的推出。
2.扩散规划:MCTG、因果不确定性、灵活的地平线控制。
决策独特地受益于 Diffusion Forcing 的能力。作者在标准的离线RL基准D4RL 中评估了作者提出的决策框架。具体来说,在一组具有稀疏奖励的 2D 迷宫环境中对扩散强制进行了基准测试。智能体的任务是从随机的起始位置开始达到指定的目标位置。该基准提供了一个通过迷宫随机行走的数据集(因此是随机的)。
作者他们将提出的决策框架 与最先进的离线 RL 方法和最近推出的扩散规划框架 Diffuser 进行对比。见图 4 的定性和定量结果:DF 在所有 6 种环境中都优于 Diffuser 和所有基线。
图 4:用于规划的扩散强迫。(返回顶部)在采样过程中,扩散强迫允许每个时间步长在不同的噪声计划中进行去噪,使我们能够在引导规划期间考虑因果不确定性。扩散强迫使遥远的未来比近期的未来更具不确定性,而扩散器在采样过程中使它们处于相同的噪声水平。(底部)从数量上讲,扩散强迫在运行中实现了最高的平均奖励。Diffuser 在执行实际生成的动作时会严重失败,需要手工制作的 PD 控制器(用星号表示)并忽略生成的动作。
蒙特卡洛树指导的好处。RL 问题的典型目标是找到能够最大化预期未来奖励的行动,作者通过 MCTG 实现了这些奖励。全序列扩散模型(如 Diffuser)不支持抽样以最大化预期奖励,为了理解 MCTG 的重要性,在表 1 中对其进行了消融。删除 MCTG 指导会降低该方法的性能,尽管即使在那时,扩散强迫仍然具有竞争力。
对因果关系进行建模的好处。与纯粹的生成建模不同,顺序决策会采取行动并接收反馈。由于不确定性的增加,眼前的下一步行动比遥远的将来的行动更重要。尽管 Diffuser 和后续模型经过训练以生成动作-奖励-状态元组序列 \(\textcolor{blue}{[a_t, r_t, o_t]}\),但直接执行动作将导致轨迹明显偏离生成的状态。换言之,生成的状态和动作在因果关系上彼此不一致。为了解决这个缺点,Diffuser 的实现忽略了生成的动作,而是依赖于手工制作的 PD 控制器来从生成的状态推断动作。在表 1 中,我们看到 Diffuser 在直接执行生成的操作时性能急剧下降。相比之下,Diffusion Forcing 的原始动作生成是自洽的,甚至优于通过将 Diffuser 的状态预测与手工制作的 PD 控制器相结合而选择的动作。
灵活视野的好处。许多 RL 任务具有固定的跨度,需要随着代理在任务中取得进展而缩小计划跨度。扩散强迫通过设计实现了这一点,而像扩散器这样的全序列模型即使进行了调整也表现不佳。
3.可控的顺序组合生成
由于扩散强迫对序列的任何子集的联合分布进行建模,因此可以利用这一独特的属性来实现组合行为。即,扩散强迫可以从轨迹子集的分布中采样,并将这些子轨迹组合成新的轨迹。
如 7 所示,考虑一个 2D 方形平面上的轨迹数据集,其中所有轨迹从一个角开始,结束于另一个角,形成一个十字形。当不需要组合行为时,可以通过允许 HMM 模型的完全内存来让模型复制十字形分布。当人们想要组合时,例如生成 V 形轨迹,将两个子轨迹拼接在一起,则可以让模型使用 MPC 生成具有无内存上下文的较短计划。
图 7:给定一个轨迹数据集 (a),扩散强制模型(Diffusion Forcing)对任意长度的所有子序列进行联合分布建模。在采样时,我们可以通过全范围采样扩散强制模型(b)从轨迹分布中进行采样,或者通过忽略之前的状态(c)来恢复马尔可夫动态。
4.机器人技术:长视野模仿学习和强大的视觉运动控制
最后,作者还说明了扩散强制(DF)为真实世界机器人的视觉运动控制开辟了新的机会。模仿学习是机器人操作中一种流行的技术,人们从专家的演示中学习观察到行动的映射。然而,记忆力的缺乏往往会阻止模仿学习完成长期任务。DF 不仅缓解了这一缺点,而且还提供了一种使模仿学习鲁棒的方法。
用记忆模仿学习。通过远程操作 Franka 机器人来收集视频和动作的数据集。在选定的任务中,需要使用第三个插槽交换苹果和橙子的位置。图示见图 5。水果的初始位置是随机的,因此有两种可能的目标状态。如图 5 所示,当一个水果位于第三个插槽中时,无法从当前观察中推断出所需的结果 - 策略必须记住初始配置以确定要移动哪个水果。与常见的行为克隆方法相比,DF 自然地将内存纳入其潜在状态。我们发现 DF 实现了 80% 的成功率,而扩散策略 [10](一种没有内存的最先进的模仿学习算法)失败了。
图 5:在真实机器人任务中,机器人手臂被要求使用第三个插槽交换两个水果的插槽。由于水果在开始时是在随机插槽中输入的,因此如果不了解水果的初始放置位置,就无法从单个观察中确定下一步。如(a)和(b)所示,上层观测值是相同的,但下图所示的期望结果可能会有所不同,因此任务需要记住初始配置。此外,如(c)所示,生成动作的同一模型也仅从单帧合成逼真的视频。
对缺失或嘈杂观测值的鲁棒性。因为它融合了贝叶斯滤波的原理,所以扩散强迫可以执行模仿学习,同时对嘈杂或缺失的观察结果具有鲁棒性。我们通过增加视觉干扰,甚至在执行过程中完全遮挡摄像头来证明这一点。DF 允许我们通过使用 k > 0 轻松地将这些观测值指示为“噪声”,在这种情况下,DF 严重依赖其先前的模型来预测动作。因此,成功率仅降低 4% 至 76%。相比之下,下一帧扩散模型基线的成功率为 48%:它必须将扰动观测值视为基本事实,并遭受分布外误差。
通过视频进行预训练的潜力。最后,在生成动作的同时,图 4 说明了扩散强迫能够生成机器人执行任务的视频,该视频仅给定初始帧,统一扩散策略/模仿学习和视频生成建模,并为在未标记视频上进行预训练铺平了道路。
四、总结
在本文中,作者引入了扩散强制,这是一种新的训练范式,其中训练模型以对具有独立、每个token噪声水平的标记集进行去噪。应用于时间序列数据,作者展示了使用扩散强制训练的下一个token预测模型如何结合下一个token模型和全序列扩散模型的优点。并且引入了新的采样和指导方案,当应用于顺序决策中的任务时,可以显著提高绩效。
局限性。目前的因果实现基于小型 RNN,用于更高分辨率视频或更复杂分布的应用可能需要大型 transformer 模型。