PhoenixPeng's blog.

BFNs——Bayesian Flow Networks

2023/10/14

一、简要

  贝叶斯流网络(BFNs)是一种新型生成模型。它的输入和其他模型相比不再是数据本身,而是数据所服从分布的参数。这意味着不管丢给BFNs是连续型数据(图像)也好,还是离散型数据也好(文本)。它的输入都是连续,可微分。比如,对于离散的多类别数据,其服从类别分布,但该分布的参数依然是连续的实值。再由于网络的输入始终是数据分布的参数,因此天然地位于 probability simplex(概率单纯形) 上,从而具有原生的可微性,为语言建模等离散领域中基于梯度的样本引导和少步生成铺平了道路。

  综合来说,BFN 对于无论是对于连续数据(continuous data),离散化数据(discretised data) 亦或是 离散数据(discrete data) 都得心应手,能够天然地适配不同类型的数据与分布。因此BFNs这个网络不仅适应于图像生成任务,也能够适应于文本生成任务。并且在text8字符级语言建模任务上打败现在已知的离散扩散模型。

  BFN 结合了传统统计学与深度学习的优势,前者通过贝叶斯推断来体现,在数学上提供了保障性与可解释性;而后者则利用了神经网络的玄学能力——能够对高维空间中变量之间的复杂交互关系进行有效建模。

二、原理介绍

1.现有模型的建模方式

  作者认为,使用深度神经网络的现代生成模型是通过对图像所有像素的联合分布进行建模来生成高分辨率图像的,这些模型包括:autoregressive models, flow-based models, deep VAEs & diffusion models,它们成功的关键在于将联合分布的编码过程“分解”为一系列的步骤,从而避免了“维度灾难”——在高维度情况下,所有变量之间的直接交互会变得十分复杂和可怕。

  虽然在细节上,以上模型的做法不尽相同,但一种通用视角是将它们建模的过程均看作是“最优比特传输问题”:发送者Alice(有权访问某些数据) 向她的好朋友 Bob(作为消息的接收者,并且希望Bob接收到的消息的位数尽可能的少)传输多轮经过编码的消息(经过压缩的消息),而这些消息里面揭示了有关原数据的一些信息。Bob在接受消息之前,尝试猜测消息是什么:他猜测得越好,传输消息所需的位数就越少。收到消息后,Bob 使用刚刚获得的信息来改进对下一条消息的猜测策略,使得以后能够猜得更准。当经过多轮传输之后,Bob得到了最佳猜测策略,使得他能够猜测出原始数据的信息,对于生成模型而言它猜的越准,其建模分布越接近真实数据分布。loss 函数定义为传输所有这些消息所用的总比特(bits)数,并且规定 receiver 猜得越准,loss 就越小。

  例如,在自回归语言模型中,消息是文本被划分为的单词片段。 Bob 对第一条消息的预测的分布编码必然是无知的。传输成本是该先验下的负对数概率。然后鲍勃使用第一个单词片段来预测第二个单词片段;平均而言,第二个预测将比第一个预测稍微更明智,并且预期的传输成本将略低。重复该过程,每一步的预测都会得到改进。传输成本之和是完整文本序列的负对数概率,它是通过最大似然训练最小化的损失函数。这也是 Alice 使用算术编码将片段传输给 Bob 所需的最小位数 。因此,以最大似然拟合自回归模型与训练数据压缩之间存在直接的对应关系。

  自回归模型目前是语言领域的SOTA,但在图像生成领域却没有表现出什么出色的效果。与离散形式的语言不同,图像是连续形式的数据,并且像素之间不存在天然的顺序关系——对于图像中的像素,似乎没有什么理由要求一定要先生成某个像素后再生成另一个,而一句话中的各个词之间的顺序却是有严格意义的。

  反观扩散模型,在图像生成领域十分出色,以一种更自然地方式来生成图像——这是一种去噪的过程,从一张纯噪声图开始,通过不断减少噪声来生成图像,Bob 收到的每条消息都是之前消息的噪声版本,其中噪声的设计是为了使消息在预期中接近数据。每一步的传输成本是 Alice 从中提取消息的分布与 Bob 对该分布的预测之间的 KullbackLeibler 散度。

  作者认为,在图像生成方面,扩散相对于自回归的优越性在于,随着噪声水平的降低,扩散从粗略的图像细节进展到精细的图像细节——这是一种比一次一个点更自然的构建图像的方式。但是,对于离散形式的数据,DM加噪的过程中的噪声也是离散的,而模型在去噪过程中估计的噪声天然是连续形式的,于是影响了去噪效果,导致其通常打不过自回归模型。作者认为这很可惜,毕竟扩散模型”解耦“了变量(比如像素)数与生成步数,也就是不需要像自回归模型那样逐个变量生成。这项工作的一个关键动机是我们相信完全连续的传输过程(爱丽丝的消息平稳地改变鲍勃的信念)对于离散数据会更有效。

2.BFNs的建模方式

  本文介绍的贝叶斯流网络 (BFNs) 模型与扩散模型的不同之处在于,该网络对数据分布的参数进行操作,而不是对数据本身的噪声版本进行操作。

  1️⃣回到Bob和Alice,其中Bob会猜测原始数据分布是什么样的,于是乎Bob设置一个先验(prior),并假设数据服从这种分布(各维度上的变量独立同分布),作者将其命名为"input distribution"(输入分布)。对于连续型数据,prior 设为标准高斯分布;而对于离散型数据,prior 则设为各类别概率均等的类别分布

  2️⃣在每轮信息的传输过程中,它讲输入分布的参数(例如正态分布的平均值,类别分布的概率)送到神经网络当中,神经网络输出第二个分布的参数,作者名为"output distribution"(输出分布)

  3️⃣然后,Alice通过根据预定义的时间表向数据添加噪声来创建"Sender distribution"(发送者分布)并从中采样,采样结果就作为消息发送给 Bob,这些消息就是 Bob 的观测样本。 

  4️⃣Bob通过模仿Alice对输出分布使用的相同噪声分布进行加噪来创建"receiver distribution"(接收方分布)。他先从输出分布中采样出样本,然后再对样本加噪。由于以单个输出样本猜中 Alice 手上的原始数据的可能性较低,因此 Bob 采取了暴力大法 —— 他采样了输出分布的所有可能的结果(相当于穷尽了分布中所有类别的样本),并且对它们加噪后的结果加权求和,以此作为对观测样本的猜测结果。因此发送者分布是观测样本真实服从的分布,而接收者分布是 Bob 在观测到样本前自己猜测的观测样本可能会服从的分布。如果是训练时Alice是知道原始数据真实服从的分布,但是在建模采样生成时,Alice和Bob都是不知道原始数据真实服从的分布。所有才要让Bob去逐步学会如何猜测信息。

  5️⃣Alice 从发送者分布中挑选一个样本,并将其发送给 Bob,其成本等于接收者和发送者的 KL 散度。然后,Bob 使用该样本遵循贝叶斯推理规则来更新他的输入分布。

  6️⃣更新完成后,Bob 再次将输入分布的参数提供给网络,网络返回输出分布的参数。于是,Bob 猜得越准,则接收者分布对发送者分布就拟合得越好,从而 Alice 传输消息所用的比特数就越少,代表数据压缩率越高。

  7️⃣该过程重复 n 步骤,此时 Bob 可以足够准确地预测数据,从而最终Alice 直接将原始数据(不带噪声)发送给他,Bob也能够猜中,而不会产生任何噪音。

3.输入分布与输出分布的作用

  由于输入分布独立地接收有关数据中每个变量的信息(通过贝叶斯更新),因此是不包含上下文信息,例如图像中的相邻像素或文本中的相关单词;

  另一方面,输出分布是由神经网络产生的,该神经网络联合处理输入分布中的所有参数,使其能够访问所有可用的上下文。

  比如对于一幅图像中的所有像素,在它们共同组成该图像前,都拥有各自独立的个性,这种个性就由贝叶斯大法来负责研究;而它们“聚在一起”组成这幅图像时,就需要“相互协调分工”(像素 A 负责当前景,像素 B 负责做背景),于是它们之间就存在一些协作关系,相互关系通常比较复杂,这种关系就通过深度学习来捕捉。

  由此我们也可以直观地感受到,BFN 结合了贝叶斯推断和深度学习的优势:前者提供了一种在数学上最优且精细可控的方式来收集和归纳关于数据中各变量的独立信息;而后者则擅长于整合变量之间的相互关系和找出它们之间的交互规律。

三、BFNs与扩散模型

  虽然贝叶斯流网络与扩散模型类似,但是也存在如下区别:

  1️⃣BFN 体现的是从一种分布到另一种分布的函数,而不是像扩散模型那样从数据到分布。这个方法带来的优势就是输入的不管是连续数据还是离散数据,形式始终都是连续且可微的比如,对于离散的多类别数据,其服从类别分布,但该分布的参数依然是连续的实值。

后者在应对离散数据时,要么是在离散的数据空间与连续的 embedding 空间之间做映射,要么就是将连续扩散的过程约束到 probability simplex 中,限制了其生成效果。

  而 “连续性”(continuity) 可谓是 BFN 內在的固有屬性, 这得益于它的输入是概率分布的参数,能够很自然地应对离散形式的数据,也就无需对现有的系统做约束。比如还是对于离散的多类别数据,BFN 的输入是类别分布的参数,是连续的实值(输出也是),但最终采样出来的结果却能够自然地处于 probability simplex 中。

  在应对离散数据时,BFN 的这种天然基因还减少了参数的设计空间(无需设计离散数据空间与连续 embedding 空间之间的映射函数),能够直接对数据的负对数似然进行优化,而以往基于 diffusion models 的方法通常需要设计简化的 loss 函数或添加辅助 loss 项以稳定训练。

  2️⃣BFN 中的网络输入比变分扩散和其他连续扩散模型中的噪声要少得多。这是因为 BFN 的生成过程从固定先验的参数开始,而扩散模型的生成过程从纯噪声开始,没有考虑真实数据的分布。

  3️⃣BFN 无需像 diffusion models 那样设计一个前向(扩散)过程并且通过逆转它来实现生成,从而能够更方便地适配到不同类型的分布与数据

  总的来说,BFN 的天然优势可归纳为以下几点:

  1️⃣输入是连续可微的,能自然地应对离散数据;

  2️⃣灵活地适配各种类型的数据与分布;

  3️⃣为基于梯度指导的采样过程和少步生成提供了可能

  4️⃣能够直接优化似然函数,可以对数据的概率密度值进行估计;

  5️⃣less noisy inputs 使其在大数据集上收敛更快(暂未有实验支撑,纯属作者脑洞)。

四、贝叶斯流网络(BFNs)

1.输入分布

  给定D维数据\(\textcolor{blue}{\bf{x}}\) \(\textcolor{blue}{=(x^{(1)},x^{(2)},...,x^{(D)}) \in \mathcal{X}^D}\),设\(\textcolor{blue}{\theta = (\theta^{(1)},\theta^{(2)},...,\theta^{(D)})}\)为数据\(\textcolor{blue}{x}\)所服从的分布参数,则输入分布定义为:

  例如,\(\textcolor{blue}{θ^{(d)}}\)可以由分类分布的概率组成。根据上式,数据\(\textcolor{blue}{x}\)在每个维度上的变量\(\textcolor{blue}{x^{(d)}}\)仅由对应维度上的分布参数\(\textcolor{blue}{θ^{(d)}}\)决定,\(\textcolor{blue}{x}\)的概率密度即所有维度上的变量的联合概率密度\(\textcolor{blue}{p_I(x|\theta)}\), 等于各维度变量的概率密度\(\textcolor{blue}{p_I(x^{(d)}|\theta^{(d)})}\)的乘积, 说明每个维度上的变量都是独立的; 同时, \(\textcolor{blue}{\theta}\)在每个维度上的变量\(\textcolor{blue}{θ^{(d)}}\)代表的都是同一种分布。

注:这里的D维数据x和\(\theta\),不是来自原数据的信息以及分布参数,而\(\theta\)可以通过人为预先设定好。

2.发送者分布(sender distribution)

  发送者分布就是在原始数据上加噪后所得到的分布,也就是观测样本所服从的真实分布。

  但Alice的加噪方式与扩散模型不同的是: 不是直接对每一轮信息传输的噪声方差进行设置, 而是在每一轮对应设立一个称为“精度(accuracy)"的参数\(\textcolor{blue}{\alpha \in R^+}\), 来表示加噪后的数据\(\textcolor{blue}{y}\)与原始数据\(\textcolor{blue}{x}\)的相关(接近)程度。其定义为当 \(\textcolor{blue}{\alpha = 0}\) 时,发送者样本完全不包含有关 x 的信息,并且随着 \(\textcolor{blue}{\alpha}\) 的增加\(\textcolor{blue}{\alpha}\)越大, \(\textcolor{blue}{y}\)的噪声程度越低, 其所含的与\(\textcolor{blue}{x}\)相关的信息量就越多。

  给定\(\textcolor{blue}{\bf{y}}\) \(\textcolor{blue}{ = y^{(1)}, . . . , y^{(D)} \in \mathcal{Y}^D}\),发送者分布就定义为:

  在发送者分布中,各维度上的变量也是相互独立而不发生交互的。在整个游戏过程中,输入分布仅接收发送者分布馈入的信息(通过贝叶斯更新过程),从而保证了前者在各维度上的变量也没有交互。

3.输出分布

  输出分布是由 BFN 以输入分布的参数\(\textcolor{blue}{\theta}\)为输入而产生的,此外还加入了时间变量\(\textcolor{blue}{t}\)作为输入,以区分 Alice 和 Bob 的每一轮信息传输,这会更好地帮助 Bob 调控输出分布——因为 Alice 在每轮传输的信息方差都不一样,所以 Bob 也要灵活地对应调整,这样才能猜得更准。

  在数据传输过程中,输入参数\(\textcolor{blue}{\theta}\)与处理时间\(\textcolor{blue}{t}\)一起作为神经网络BFN(记为\(\textcolor{blue}{Ψ}\))的输入传递。然后网络发出输出向量\(\textcolor{blue}{Ψ(θ, t) = (Ψ^{(1)}(θ,t),...,Ψ^{(D)}(θ, t))}\)用于参数化输出分布,其定义为:

  输入和输出分布之间的主要区别在于,虽然每个\(\textcolor{blue}{x^{(d)}}\)仅由对应维度的\(\textcolor{blue}{Ψ^{(d)}(\theta,t)}\)决定,但后者却由所有维度的\(\textcolor{blue}{\theta^{(d)}}\)\(\textcolor{blue}{\theta}\)产生,因此 \(\textcolor{blue}{x}\)在每个维度上的变量都与其他维度发生了交互。因此,与输入分布不同,输出分布可以利用上下文信息,例如图像中的周围像素或文本中的相关单词。

4.接收者分布(receiver distribution)

  receiver分布是 Bob 在输出分布上加噪后的结果,以此作为对观测样本所服从的分布的估计,定义为:

  直观上,可以理解为Bob想要模仿Alice对原数据进行加噪一样,对输出分布所得到的\(\textcolor{blue}{\bf{x^\prime}}\)进行加噪,并模仿Alice同样的加噪形式,即receiver知道sender分布\(\textcolor{blue}{p_S(\cdot |x;\alpha)}\) 的形式,但不知道 \(\textcolor{blue}{x}\),就像前面的Bob无权访问原数据。因此Bob得到基于\(\textcolor{blue}{\bf{x^\prime}}\)加噪后所有可能的结果,并通过输出分布 \(\textcolor{blue}{p_O (x|θ,t)}\) 赋予\(\textcolor{blue}{\bf{x^\prime}}\)的概率进行加权求和,以此作为对观测样本的猜测结果,其中每个样本的权重就是它们的采样概率。由此可见,这个猜测结果是(Bob 构造的)噪声样本在输出分布上的期望

  结合以上四种分布的具体形式,可以对BFN的流程有个更加清晰的了解:

5.贝叶斯更新

  贝叶斯更新指的是根据贝叶斯推断来计算后验概率,以更新(校正)先验参数的过程。 放到这场游戏中,就是 Bob 在收到 Alice 传来的消息\(\textcolor{blue}{y}\)后利用该观测样本来更新输入分布的参数\(\textcolor{blue}{\theta}\) ,设一个“贝叶斯更新函数”来表示这个更新过程:

  然而, \(\textcolor{blue}{\theta}\)是个随机变量, 以上函数仅仅基于一个噪声样本计算出对应的一个值, 代表更新过程的其中一种可能性(不同的\(\textcolor{blue}{\theta,y}\)会得到不同的\(\textcolor{blue}{\bf{\theta^\prime}}\) , 它们的取值都有概率性), 更科学地应该是要计算出其服从的分布。

  于是, 通过边缘化\(\textcolor{blue}{y}\)来定义贝叶斯更新分布,即考虑了所有噪声样本的可能性:

  其中\(\textcolor{blue}{δ(·−a)}\)是以向量\(\textcolor{blue}{a}\)为中心的多元狄拉克\(\textcolor{blue}{δ}\) 分布。

  对于连续型和离散型数据,贝叶斯更新分布均具有“精度可加性”,即如果\(\textcolor{blue}{\alpha = \alpha_a + \alpha_b}\),那么:

  从上式可以得出,给定先验输入参数\(\theta_0\),以及n个发送者样本序列:\(\textcolor{blue}{\bf y_1,y_2,...,y_n}\),其对应的精度:\(\textcolor{blue}{\alpha_1,\alpha_2,...,\alpha_n}\),得到观测后更新最终输入参数\(\textcolor{blue}{\theta_n}\)的概率为:

  上述结论会对于连续和离散型数据的具体实现来进行说明。

  看着很复杂,但是总的一看,第一反应是感觉有点像扩散模型的前向过程:

6.精度表\(\beta(t)\)

  每一轮信息传输由时间变量\(\textcolor{blue}{t}\)表示, 而精度在每轮是不同的, 于是可将其看作是时间变量的函数:\(\textcolor{blue}{\alpha(t)>0 ,t \in [0,1]}\), 这样还可以同时兼容离散和连续时间的情况:在离散时间步的情况下,假设总步数为\(\textcolor{blue}{n}\) , 当前步为\(\textcolor{blue}{i(i=1,2,...,n)}\) , 则令\(\textcolor{blue}{t_i = \frac{i}{n}}\) ; 而在连续时间的情况下, \(\textcolor{blue}{t}\)可以直接从\(\textcolor{blue}{[0,1]}\)区间中均匀采样。

  现在将精度表 \(\textcolor{blue}{β(t)}\) 定义为:

  它是精度从起始时刻到当前时刻的积分, 是\(\textcolor{blue}{t}\)的单调递增函数。

  于是精度的计算方法为:\(\textcolor{blue}{\alpha(t) = \frac{dβ(t)}{dt}}\)

  也就是说, 无论是连续还是离散时间步的情况, 时间变量\(\textcolor{blue}{t}\)都是位于\(\textcolor{blue}{[0,1]}\)区间内连续的小数。 只不过在离散时间步下,我们只能取到\(\textcolor{blue}{t}\)在离散点上对应的取值: \(\textcolor{blue}{t_1,t_2,...,t_n}\), 相当于只能在连续曲线上抓取其中的一些点,而不能获得完整的曲线。

  这里的 accuracy schedule 类似于 diffusion models 中对噪声方差进行设置,因为它控制着发送者和接收者分布的方差,也就是噪声程度。 另外,也类似于 diffusion models 的噪声方差设置,在实际应用时,\(\textcolor{blue}{β(t)}\)是根据某种策略人为设定,而非计算积分而来。

7.贝叶斯流分布\(p_F(\cdot | x;t)\)

  贝叶斯流分布是当前时刻的贝叶斯更新分布的边缘分布,它对 BFN 的输入参数进行边缘化:

  贝叶斯更新分布对于\(\textcolor{blue}{\theta}\)\(\textcolor{blue}{t}\)的边缘化分布可以仅由\(\textcolor{blue}{\theta_0}\)\(\textcolor{blue}{\beta(t)}\)决定。 \(\textcolor{blue}{\theta_0}\) 、\(\textcolor{blue}{\beta(t)}\)都是人为设定的, 而数据\(\textcolor{blue}{\theta}\)是明确知道的, 于是这可以很方便地推导出每个时刻的贝叶斯流分布, 有点类似于 diffusion models 的扩散(前向)过程一一可以从原始图像\(\textcolor{blue}{x_0}\)推导出各时刻的噪声图像\(\textcolor{blue}{x_t}\)

8.损失函数\(L(x)\)

  假设 Alice 一共向 Bob 传输了\(\textcolor{blue}{n}\)条消息(噪声数据):\(\textcolor{blue}{ \bf y_{1}, . . . , y_{n}}\) , 最后再将原始数据\(\textcolor{blue}{x}\)传过去, 则 loss 函数包含两部分:一部分是传输\(\textcolor{blue}{n}\)条消息的所需的奈特数(nats), 记为\(\textcolor{blue}{L^n(x)}\); 另一部分是最后传输原始数据所需的奈特数, 记为\(\textcolor{blue}{L^r(x)}\) 。

  在 bits-back coding 的编码方案下, \(\textcolor{blue}{L^n(x)}\)等于发送者分布与接收者分布的 KL 散度\(\textcolor{blue}{D_{KL}(p_S|p_R)}\)

  由于有\(\textcolor{blue}{n}\)轮传输, 且每轮对应的\(\textcolor{blue}{\theta_i}\)是随机变量, 具有有概率性, 因此\(\textcolor{blue}{L^n(x)}\)应该是\(\textcolor{blue}{p(\theta_1,...,\theta_{n-1}) = \prod_{i=1}^n p_U(\theta_i |  \theta_{i-1},x;\alpha_i)}\)的期望:

  给定先验参数\(\textcolor{blue}{\theta_0}\)和准确度表\(\textcolor{blue}{β(t)}\),考虑\(\textcolor{blue}{n}\)个发送者样本的序列\(\textcolor{blue}{ \bf y_{1}, . . . , y_{n}}\)\(\textcolor{blue}{  t_{1}, . . . , t_{n}}\)的时候被采样,其中\(\textcolor{blue}{t_i = \frac{i}{n}}\),第\(\textcolor{blue}{i}\)步的发送方分布为\(\textcolor{blue}{p_S(\cdot  |x;\alpha_i)}\),其中

  第\(\textcolor{blue}{i}\)步的接收器分布为\(\textcolor{blue}{p_R(\cdot |\theta_{i-1},t_{i-1};\alpha_i)}\) ,输入参数序列\(\textcolor{blue}{\theta_{1},\theta_{2},...,\theta_{D}}\)是由下式递归计算:

  同理, 在 arithmetic coding 的编码方案下, \(\textcolor{blue}{L^r(x)}\)等价于重构 \(\textcolor{blue}{x}\), 即在最后时刻的参数\(\textcolor{blue}{\theta}\)下的负对数似然\(\textcolor{blue}{-\ln(p_O(x|\theta,1))}\) :

  注意本文中\(\textcolor{blue}{L^r(x)}\)并没有直接优化;然而,它是通过优化\(\textcolor{blue}{L^n(x)}\)来间接训练的,因为两者都是通过将输出分布与数据匹配来最小化的。此外,只要 \(\textcolor{blue}{β(1)}\)足够高,\(\textcolor{blue}{t=1}\)时的输入分布就会非常接近\(\textcolor{blue}{x}\),使得网络拟合 \(\textcolor{blue}{p_O(x|\theta;1)}\)变得微不足道。

  损失函数L(x)定义为传输数据所需的nats总数,即n步损失和重构损失之和:

  L(x) 可以看作VAE去理解:

  不妨将 Alice 传输的消息: \(\textcolor{blue}{ \bf y_{1}, . . . , y_{n}}\)看作是隐变量序列,将隐变量的后验分布\(\textcolor{blue}{q_\phi(z|x)}\)等同于 BFN 的发送者分布:

  既然先验是\(\textcolor{blue}{\theta}\) , 那么相应地, 隐变量的先验分布就选取接收者分布:

  然后, VAE 解码重构的概率分布就用输出分布来充当:

  于是,重构损失就是对应VAE的负对数似然\(\textcolor{blue}{-\ln(p_\theta(x|z))}\)

  最后,loss 函数就是发送者分布和接收者分布的 KL 散度加上在输出分布下的负对数似然:

9.离散时间损失\(L^n(x)\)

  在离散时间步和连续时间的情况下, \(\textcolor{blue}{L^n(x)}\)的形式会有所不同; 而\(\textcolor{blue}{L^r(x)}\)由于只发生在最后一步, 因此没有影响。

  根据蒙特卡洛采样, 我们可以将\(\textcolor{blue}{L^n(x)}\)中的\(\textcolor{blue}{n}\)步求和改写一下, 近似为:

  其中\(\textcolor{blue}{U\{1,n\}}\)是从 1 到 n 的整数上的均匀分布。

  结合精度的可加性, 对于贝叶斯更新分布的那部分期望可以转换为用贝叶斯流分布来表示:

  于是最终得到:

  这样就能通过蒙特卡洛采样近似\(\textcolor{blue}{L^n(x)}\),而无需计算\(\textcolor{blue}{n}\)步总和。

  

10.连续时间损失\(L^\infty (x)\)

  现在来推导下\(\textcolor{blue}{L^n(x)}\)在连续时间情况下的形式,当\(\textcolor{blue}{n \rightarrow \infty}\)时,假设:

  于是,结合上一节在离散时间步情况下的推导结论,有:

  其中 \(\textcolor{blue}{U(\epsilon,1)}\)是区间\(\textcolor{blue}{[\epsilon,1]}\)上的连续均匀分布。

  进一步,作者在 paper 中提出了对于发送者分布和接收者分布之间的 KL 散度的泛化形式:

1697511140145

  其中 \(\textcolor{blue}{g : X → Y}\) 是将原始数据所在空间映射到 Alice 传输的消息(即观测样本)空间的函数,\(\textcolor{blue}{P^{(d)}(θ, t)}\) 定义在观测样本空间上的具有有限期望和方差的单变量分布,* 表示两个概率分布的卷积,\(\textcolor{blue}{C}\)是标量常数。

  当\(\textcolor{blue}{\sigma^2 \rightarrow \infty}\)时, 对于具有有限期望\(\textcolor{blue}{E[P]}\)与方差\(\textcolor{blue}{Var[P]}\) 的连续型单变量概率分布\(\textcolor{blue}{P}\) , 作者还进一步证明了有以下结论:

  这样,就能够将发送者分布和接收者分布之间的 KL 散度转变为两个正态分布之间的 KL 散度,从而方便计算。

  根据统计学中的卷积定义,两个相互独立的随机变量之和 所服从的分布 就是 它们这两个分布的卷积:

\[\Large X \sim P,Y \sim N(0,\sigma^2) \Rightarrow Z = X + Y \sim P * N(0,\sigma^2)\]

  在这基础上, 如果我们能构造出\(\textcolor{blue}{Z=E[P]+R,R \rightarrow N(0,\sigma^2)}\)的形式,那么根据正态分布的性质, 就可以顺理成章地得到目标结论:

\[\Large Z \rightarrow N(E[P],\sigma^2) \Rightarrow P * N(0,\sigma^2) \rightarrow N(E[P],\sigma^2)\]

  下面就推导该结论

  构造一个随机变量序列:\(\textcolor{blue}{X_{0},X_{1},...,X_{n}}\),其中

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

11.采样生成

  1️⃣给定先验参数\(\textcolor{blue}{\theta_0}\),假定我们计划要将其更新n次后才拿来生成样本,那么先设定好对应的每一步精度\(\textcolor{blue}{\alpha_1,\alpha_2,...,\alpha_n}\)和相应的时间\(\textcolor{blue}{t_i = \frac{i}{n}}\)

  2️⃣经过训练后,由于接收者分布已经近乎完美的拟合发送者分布,因此每一步,就可以通过先从输出分布\(\textcolor{blue}{p_O (\cdot|θ_{i-1},t_{i-1})}\)采样\(\textcolor{blue}{\bf{x^\prime}}\),再从\(\textcolor{blue}{p_S(\cdot |\bf{x^\prime};\alpha_i)}\) 采样 y(意味着直接从接收者分布\(\textcolor{blue}{p_R (\cdot|θ_{i-1};t_{i-1},\alpha_i)}\)采样出y ),然后再利用贝叶斯更新:\(\textcolor{blue}{\theta_i = h(\theta_{i-1},y)}\),n 步采样过程递归更新 \(\textcolor{blue}{\theta_1,\theta_2,...,\theta_n}\)

  3️⃣待n次更新完毕,得到\(\textcolor{blue}{\theta_n}\),使用网络再运行一次得到输出分布\(\textcolor{blue}{p_O(\cdot|\theta_n,t=1)}\),再从该分布中采样生成样本\(\textcolor{blue}{\bf{x}}\)

五、连续数据

1.输入分布

  对于连续数据 \(\textcolor{blue}{\bf{x} \in R^D}\),它的输入分布是正态分布:

  则先验参数设置为标准正态分布:

2.发送者分布

  对于连续数据来说,给定\(\textcolor{blue}{\bf{x}}\) \(\textcolor{blue}{\in \mathcal{X}^D}\)以及精度\(\textcolor{blue}{\alpha}\),那么得到的观测样本\(\textcolor{blue}{\bf{y}}\) \(\textcolor{blue}{\in \mathcal{Y}^D}\),发送者分布就定义为:

  这可以看作发送者分布是以\(\textcolor{blue}{\bf{x}}\)为均值,精度\(\textcolor{blue}{\alpha}\)为方差的正态分布

3.贝叶斯更新函数

  在《Conjugate Bayesian analysis of the Gaussian distribution》文章当中,已经证明了这样的一个关系:当给定未知数据\(\textcolor{blue}{x}\)上的单变量高斯先验分布\(\textcolor{blue}{\mathcal{N}(\mu_a,\rho_a^{-1})}\)以及精度\(\textcolor{blue}{\alpha}\),在从正态分布\(\textcolor{blue}{\mathcal{N}(x,\alpha^{-1})}\)采样得到的观测样本\(\textcolor{blue}{y}\)之后,我们可以得到贝叶斯后验分布表示为\(\textcolor{blue}{\mathcal{N}(\mu_b,\rho_b^{-1})}\),其中:

  作者提出由于\(\textcolor{blue}{p_I(x|\theta)}\)\(\textcolor{blue}{p_S(y|x;\alpha)}\)分布都是具有对角协方差的正态分布,因此上面的式子可以应用于当给定输入分布参数\(\textcolor{blue}{\theta_{i-1} = \{\mu_{i-1},\rho_{i-1}\}}\)以及从发送者分布\(\textcolor{blue}{p_S(\cdot |x;\alpha I)=\mathcal{N}(x,\alpha^{-1}I)}\)采样得到的观测样本y,那么根据贝叶斯更新函数,我们可以得到:

  下图展示了连续数据的贝叶斯更新的过程。对于单变量数据 \(\textcolor{blue}{x = 0.7}\),初始输入分布参数\(\textcolor{blue}{\theta_{0} = \{\mu_{0}=0,\rho_{0}=1\}}\)更新为\(\textcolor{blue}{\theta_{1} = \{\mu_{1},\rho_{1}\}}\)\(\textcolor{blue}{\theta_{2} = \{\mu_{2},\rho_{2}\}}\) ,\(\textcolor{blue}{\theta_{3} = \{\mu_{3},\rho_{3}\}}\)发送者分布依据式49和式50,分别以精度2、4、6绘制观测样本\(\textcolor{blue}{y1、y2、y3}\)。我们可以看到输入均值\(\textcolor{blue}{ (μ1、μ2、μ3) }\)是如何一步步随机接近数据,同时输入精度平滑增加。

4.贝叶斯更新分布

  我们知道原先的贝叶斯更新分布是通过边缘化\(\textcolor{blue}{\bf{y} \sim \mathcal{N}(\bf{y}|x,\alpha^{-1}I)}\)所定义的:

  根据式50,我们可以把它写成这样的形式:

\[\huge \mu_i = \frac{\alpha}{\rho_i}\bf{y}+\frac{\mu_{i-1}\rho_{i-1}}{\rho_{i}}\]

  由正态分布的标准恒等式:

  我们可以得到

\[\huge \bf{y} \sim \mathcal{N}(\bf{y}|x,\alpha^{-1}I) \Rightarrow \frac{\alpha}{\rho_i}\bf{y}+\frac{\mu_{i-1}\rho_{i-1}}{\rho_{i}} \sim \mathcal{N}(\frac{\alpha x+\mu_{i-1}\rho_{i-1}}{\rho_{i}},\frac{\alpha}{\rho_{i}^2}I)\]

  从而可以得出:

  因此(\(\textcolor{blue}{μi}\)\(\textcolor{blue}{θi}\) 的唯一随机部分)

  下图展示了连续数据的贝叶斯更新分布。对于 \(\textcolor{blue}{x = 0.7}\),该图显示了方程52中输入平均值\(\textcolor{blue}{μ}\)的分布\(\textcolor{blue}{p(\mu|\theta_0,x;\alpha)}\)。 给定初始参数\(\textcolor{blue}{μ_0 = 0、ρ_0 = 1}\)和 11 个在 \(\textcolor{blue}{e^{−5}}\)\(\textcolor{blue}{e^5}\) 之间以对数线性间隔的\(\textcolor{blue}{\alpha}\)值。我们可以看到对于非常低的 alpha,分布是紧密集中在\(\textcolor{blue}{\mu_0}\)周围,然后平滑地进展到在\(\textcolor{blue}{x}\)周围紧密集中,以获得高 alpha。

5.精度可加性

  如果我们能从\(\textcolor{blue}{p(\cdot|\theta_{i-2},\bf x;\alpha_a)}\)得到\(\textcolor{blue}{\theta_{i-1} = \{\mu_{i-2},\rho_{i-2}\}}\),那么就有:

  定义:

  我们再一次借用正态分布的标准恒等式,其中\(\textcolor{blue}{a=\frac{\rho_{i-1}}{\rho_{i}}}\)\(\textcolor{blue}{b=\frac{\alpha_b \bf x}{\rho_{i}}}\)

  我们假设\(\textcolor{blue}{\theta_{i} = \{\mu_{i-1},\rho_{i-1}\}}\)能够从\(\textcolor{blue}{p(\cdot|\theta_{i-1},\bf x;\alpha_b)}\)得到,那么

  因此:

  其中:

  在这里应用高斯变量的另一个标准恒等式:

  我们可以得到:

  由此,我们可以得到(具体待证明)

6.精度表\(\beta(t)\)

  我们通过要求输入分布的预期熵随\(\textcolor{blue}{t}\)线性减小来导出连续数据的\(\textcolor{blue}{β(t)}\)。直观上,这意味着信息以恒定速率流入输入分布。定义:

  那么如果\(\textcolor{blue}{H(t)}\)\(\textcolor{blue}{t}\)线性减小,

  将\(\textcolor{blue}{\sigma_1}\)定义为\(\textcolor{blue}{t=1}\)时输入分布的标准差。我们将根据经验选择\(\textcolor{blue}{\sigma_1}\)以最小化损失;一般来说,它应该足够小以确保重建损失低,但又不能小到产生不必要的传输成本。回想一下\(\textcolor{blue}{t}\)时刻的精度\(\textcolor{blue}{\rho}\)\(\textcolor{blue}{1+\beta(t)}\),我们看到

  因此:

7.贝叶斯流分布\(p_F(\cdot | x;t)\)

  回到之前的贝叶斯流分布概率函数:

  因此,在等式中设置\(\textcolor{blue}{\theta_{i-1} = \theta_0 = \{0,1\}}\)\(\textcolor{blue}{α = β(t)}\)并由式53可以得到下面的结论:其中设置\(\textcolor{blue}{ρ_i = 1 + β(t)}\)

  其中,

8.输出分布

  作者遵循扩散模型的标准,输出分布是通过重新参数化高斯噪声向量\(\textcolor{blue}{ε ∼ N (0, I)}\)的预测来定义的,该向量用于生成作为网络输入传递的均值\(\textcolor{blue}{μ}\)。根据式77:

  依据重新参数化,我们可以得到:

  将上式的\(\textcolor{blue}{ε}\)代入网络输出的估计\(\textcolor{blue}{\hat{ε}(θ, t)}\),并将其转换为\(\textcolor{blue}{x}\)的估计\(\textcolor{blue}{\hat{x}(θ, t)}\)

  给定\(\textcolor{blue}{\hat{x}(θ, t)}\),输出分布为

9.接收者分布

  依据下式:

\[\huge p_O(\bf x|\theta;t) = δ(x-\hat{x}(\theta,t))\]

\[\huge p_S(\bf y|x;\alpha)=\mathcal{N}(y|x,\alpha^{-1}I)\]

\[\huge p_R(\bf y | \theta; t, \alpha) = E_{ p_O (x^′|\theta;t)} p_S( y | x^′; \alpha)\]

  可以推出:

10.重建损失\(L^r(x)\)

  真正连续的数据需要无限的精度来重建,这使得重建损失成为问题。然而,可以合理地假设数据要么被精细离散,要么包含一些噪声。或者,如果我们假设 x 上存在正态分布的测量噪声,且具有固定的各向同性方差 σ2,则重建损失的噪声版本可以定义为\(\textcolor{blue}{N(x、σ^2I)}\)\(\textcolor{blue}{t=1}\)时的输出分布之间的预期 KL 散度:

  噪声不会直接影响训练,因为重建损失没有优化。然而,σ 的值对应为 σ1 选择的值设置了一个自然上限:没有必要将数据传输到比最初测量的精度更高的精度。根据经验,我们发现当 σ1 < σ/2 时,重建损失非常小。

11.离散时间损失\(L^n(x)\)

12.连续时间损失\(L^\infty (x)\)

13.伪代码

  1️⃣给定输入参数\(\textcolor{blue}{\theta = \{\mu,\rho\}}\),其中\(\textcolor{blue}{\rho}\)完全由\(\textcolor{blue}{t}\)决定。下面是定义了一个对于连续数据网络输出的参数进行重新参数化采样样本的函数。

  2️⃣算法1是计算连续数据的\(\textcolor{blue}{n}\)步损失\(\textcolor{blue}{L^n(x)}\),其中\(\textcolor{blue}{n}\)步传输,\(\textcolor{blue}{\sigma_1}\)都是预先设定好的超参数,由此也能够进一步得出\(\textcolor{blue}{t}\)以及\(\textcolor{blue}{\gamma}\),把\(\textcolor{blue}{\gamma}\)按照式77进一步计算出均值\(\textcolor{blue}{\mu}\)。接下来把\(\textcolor{blue}{\mu,t,\gamma}\)送入网络中计算参数,返回采样得到的估计\(\textcolor{blue}{\hat{x}(θ, t)}\),最后将得到的网络估计样本与原始数据进行损失计算。

  3️⃣算法2是计算连续时间的连续时间损失\(\textcolor{blue}{L^{\infty}(x)}\),具体做法和算法1类似,只是不需要\(\textcolor{blue}{n}\)步传输作为超参数。

  4️⃣算法 3 中给出了样本生成过程,其中\(\textcolor{blue}{n}\)步传输,\(\textcolor{blue}{\sigma_1}\)都是预先设定好的超参数,预先设定的先验输入参数为\(\textcolor{blue}{μ_0 = 0、ρ_0 = 1}\),在第\(\textcolor{blue}{i}\)步当中我们用已经训练好的网络估计出\(\textcolor{blue}{\hat{x}(θ, t)}\),并计算在第\(\textcolor{blue}{i}\)步下的精度\(\textcolor{blue}{\alpha}\)。然后根据接收者分布采样出观测样本\(\textcolor{blue}{y}\),利用贝叶斯更新函数对\(\textcolor{blue}{μ、ρ}\)进行更新,重复上述过程一直到n步。最后得到进行n步更新好的\(\textcolor{blue}{\mu}\)送入到网络中得到最终的生成样本。

六、离散数据

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

  

七、离散化数据

八、实验

  作者根据以下生成基准评估了贝叶斯流网络 (BFN):CIFAR-10、动态二值化 MNIST和 text8(长度为 256 个字符序列,大小 27 字母)。网络都使用连续时间损失\(\textcolor{blue}{L^\infty (x)}\)进行训练,离散时间损失\(\textcolor{blue}{L^n (x)}\)仅用于测试,并且具有不同的\(\textcolor{blue}{n}\)值。整个过程中使用了标准网络架构和训练算法,以便与现有方法进行直接比较。由于本文的重点是概率建模而不是图像生成,因此未计算 FID 分数。但是为所有实验都提供了生成数据的示例。

1.动态二值化 MNIST

  设置:该网络架构基于为扩散模型引入的 U-Net 。实验使用 600 张随机选择的训练图像(训练集的 1%)用作验证集。模型可学习参数总数约为25M。

  结果:

  1️⃣图 12:MNIST 真实数据和生成数据。经过 100 个步骤生成的样本。

  2️⃣下图分别展示了MNIST 输入和输出分布。对于两个测试集图像,该图显示了 t = 0 和 t = 1/3 之间均匀间隔的 20 个步骤的白色像素概率。我们能够注意输入概率最初是均匀的,而输出分布最初预测多个数字的叠加,与训练集上的每像素边缘先验紧密匹配。另请注意,输出分布的噪声比输入分布小得多,并且随着接收到新信息,输出分布的变化更加显着。这凸显了网络使用上下文来解决输入分布中的歧义和噪声。

2.CIFAR-10

  设置:网络架构本质上与变分扩散模型(VDM )所使用的相同。总共大约有 3100 万个可学习参数。所有 CIFAR-10 实验均使用由 500 个随机选择的训练图像组成的验证集(训练集的 1%)来进行。

  结果:

  1️⃣表 1 显示,性能最佳的 BFN 为 256 个 bin 数据提供了 2.66 BPD,这接近最先进的 2.64 BPD。最明显的性能基准(考虑到共享网络架构和损失函数的相似性)是 2.65 BPD 的 VDM 结果。然而,这需要 10M 权重更新才能实现,并且由于时间限制,作者只训练 BFN 进行 5M 权重更新。 5M权重更新后验证性能仍在提高,10M 更新后性能会提高多少仍不清楚。

BPD(bits-per-dimension,越低越好),它是计算每个维度的比特数上的负对数似然。计算以 e 为基数的负对数似然,应用基数的变化将以对数为基数 e 转换为以 2 为基数,然后除以像素数。

  2️⃣下图显示CIFAR-10 真实数据和生成数据。使用经过离散损失训练的网络,通过 4,000 个步骤生成样本。

  3️⃣下图展示了CIFAR-10 输入和输出分布。对于两个测试集图像,该图显示了在 t = 0 和 t = 0.25 之间均匀间隔的输入和输出分布的图像变换。

3.text8

  数据:text8 数据集源自 enwik9 维基百科数据集的子集,通过删除标点符号并将文本限制为小写拉丁字母和空格,给出大小为 27 的字母表。为了清楚起见,我们在图中用下划线表示空格字符

  设置:网络架构是一个类似于 Radford 等人使用的小型模型 Transformer。 不同之处在于它使用GELU激活函数并且深度增加到24层。 Transformer 的输入和输出被连接起来,然后投影回输出大小以产生最终输出。使用 90M/5M/5M 连续字符的标准训练/验证/测试分割。模型可学习参数总数约为170M。

  结果

  1️⃣表 4 显示,BFN 在 text8 测试集上产生了 1.41 BPC,这比目前在文献中发现的所有离散扩散模型都要好,并且接近最佳阶不可知模型 MAC(1.40 BPC)。

  2️⃣然而我们注意到,标准自回归基线和离散流模型在 1.23 BPC 下的表现都要好得多。表 5 显示,对于 n 的减少,性能相当稳健,只需 100 个步骤即可达到 1.43 BPC。通过离散时间损失训练可能会改善这个结果。

BPC:当计算基于字符长度单位的混淆度 (Perplexity)时,\(Perplexity = 2^{BPC}\)

  3️⃣下图展示的是text8输入和输出分布的另一种可视化,其中字符大小按其概率成比例缩放。

九、总结

  

  

  

  

  本文参考文章:

  BTNs是怎么玩转生成即压缩的?详解结合贝叶斯统计和深度学习的生成模型 — Bayesian Flow Networks(一)

  贝叶斯流网络

CATALOG
  1. 1. 一、简要
  2. 2. 二、原理介绍
    1. 2.1. 1.现有模型的建模方式
    2. 2.2. 2.BFNs的建模方式
    3. 2.3. 3.输入分布与输出分布的作用
  3. 3. 三、BFNs与扩散模型
  4. 4. 四、贝叶斯流网络(BFNs)
    1. 4.1. 1.输入分布
    2. 4.2. 2.发送者分布(sender distribution)
    3. 4.3. 3.输出分布
    4. 4.4. 4.接收者分布(receiver distribution)
    5. 4.5. 5.贝叶斯更新
    6. 4.6. 6.精度表\(\beta(t)\)
    7. 4.7. 7.贝叶斯流分布\(p_F(\cdot | x;t)\)
    8. 4.8. 8.损失函数\(L(x)\)
    9. 4.9. 9.离散时间损失\(L^n(x)\)
    10. 4.10. 10.连续时间损失\(L^\infty (x)\)
    11. 4.11. 11.采样生成
  5. 5. 五、连续数据
    1. 5.1. 1.输入分布
    2. 5.2. 2.发送者分布
    3. 5.3. 3.贝叶斯更新函数
    4. 5.4. 4.贝叶斯更新分布
    5. 5.5. 5.精度可加性
    6. 5.6. 6.精度表\(\beta(t)\)
    7. 5.7. 7.贝叶斯流分布\(p_F(\cdot | x;t)\)
    8. 5.8. 8.输出分布
    9. 5.9. 9.接收者分布
    10. 5.10. 10.重建损失\(L^r(x)\)
    11. 5.11. 11.离散时间损失\(L^n(x)\)
    12. 5.12. 12.连续时间损失\(L^\infty (x)\)
    13. 5.13. 13.伪代码
  6. 6. 六、离散数据
  7. 7. 七、离散化数据
  8. 8. 八、实验
    1. 8.1. 1.动态二值化 MNIST
    2. 8.2. 2.CIFAR-10
    3. 8.3. 3.text8
  9. 9. 九、总结