当前位置:首页> AI教程> 深度学习模型设计与端到端训练

深度学习模型设计与端到端训练

释放双眼,带上耳机,听听看~!
本文介绍了深度学习模型设计中的端到端训练方法,并讨论了embedding函数和神经网络在模型设计中的应用。

模型设计

  • First, we must define an embedding function that maps discrete text into a continuous space.
  • Second, we require a rounding method to map vectors in embedding space back to words.

深度学习模型设计与端到端训练

端到端训练

为了将离散文本应用到连续扩散模型上,这里设置了一个embedding函数EMB(wi)mathrm E_{mathrm{MB}}(w_i)将每个词映射到词向量Rdmathbb{R}^d。对于长度为nn的序列wmathbf w

EMB(w)=[EMB(w1),…,EMB(wn)]∈Rndmathrm E_{mathrm{MB}}(mathbf{w})=left[mathrm E_{mathrm{MB}}left(w_1right), ldots, mathrm E_{mathrm{MB}}left(w_nright)right] in mathbb{R}^{n d}

作者在实验中发现,使用预训练的word embedding效果不如使用随机高斯噪声初始化之后训练出来的embedding,所以加了这样一个网络,实现离散词序列wmathbf wx0mathbf x_0的马尔科夫转换,参数化为:

qϕ(x0∣w)=N(EMB(w),σ0I)q_phileft(mathbf{x}_0 mid mathbf{w}right)=mathcal{N} (mathrm E_{mathrm{MB}} (mathbf{w}), sigma_0 I)

与之对应的反向过程添加了一个可训练的舍入步骤,参数化为:

pθ(w∣x0)=∏i=1npθ(wi∣xi)p_thetaleft(mathbf{w} mid mathbf{x}_0right)=prod_{i=1}^n p_thetaleft(w_i mid x_iright)

其中pθ(wi∣xi)p_thetaleft(w_i mid x_iright)是一个softmax分布。

因为增加了一个嵌入步骤和一个舍入步骤,添加的这两个网络是和扩散模型进行联合训练的,因此要对训练目标函数改进:

Lvlbe2e(w)=Eqϕ(x0∣w)[Lvlb(x0)+log⁡qϕ(x0∣w)−log⁡pθ(w∣x0)]],Lsimple e2e(w)=Eqϕ(x0:T∣w)[Lsimple (x0)+∥EMB(w)−μθ(x1,1)∥2−log⁡pθ(w∣x0)].begin{aligned} mathcal{L}_{mathrm{vlb}}^{mathrm{e2e}}(mathbf{w}) & left.=underset{q_phileft(mathbf{x}_0 mid mathbf{w}right)}{mathbb{E}}left[mathcal{L}_{mathrm{vlb}}left(mathbf{x}_0right)+log q_phileft(mathbf{x}_0 mid mathbf{w}right)-log p_thetaleft(mathbf{w} mid mathbf{x}_0right)right]right],

mathcal{L}_{text {simple }}^{mathrm{e2e}}(mathbf{w}) & =underset{q_phileft(mathbf{x}_{0: T} mid mathbf{w}right)}{mathbb{E}}left[mathcal{L}_{text {simple }}left(mathbf{x}_0right)+left|mathrm E_{mathrm{MB}} (mathbf{w})-mu_thetaleft(mathbf{x}_1, 1right)right|^2-log p_thetaleft(mathbf{w} mid mathbf{x}_0right)right] .end{aligned}

Lsimplee2e(w)mathcal{L}_{mathrm{simple}}^{mathrm{e} 2 mathrm{e}}(mathbf{w})Lvlbe2e(w)mathcal{L}_{mathrm{vlb}}^{mathrm{e} 2 mathrm{e}}(mathbf{w})的简化版,是根据DDPM设计的目标函数。

DDPM的目标函数:

Lvlb(x0)=Eq(x1:T∣x0)[log⁡q(xT∣x0)pθ(xT)+∑t=2Tlog⁡q(xt−1∣x0,xt)pθ(xt−1∣xt)−log⁡pθ(x0∣x1)]mathcal{L}_{mathrm{vlb}}left(mathbf{x}_0right)=underset{qleft(mathbf{x}_{1: T} mid mathbf{x}_0right)}{mathbb{E}}left[log frac{qleft(mathbf{x}_T mid mathbf{x}_0right)}{p_thetaleft(mathbf{x}_Tright)}+sum_{t=2}^T log frac{qleft(mathrm{x}_{t-1} mid mathbf{x}_0, mathbf{x}_tright)}{p_thetaleft(mathbf{x}_{t-1} mid mathbf{x}_tright)}-log p_thetaleft(mathbf{x}_0 mid mathbf{x}_1right)right]
Lsimple (x0)=∑t=1TEq(xt∣x0)∥μθ(xt,t)−μ^(xt,x0)∥2mathcal{L}_{text {simple }}left(mathbf{x}_0right)=sum_{t=1}^T underset{qleft(mathbf{x}_t mid mathbf{x}_0right)}{mathbb{E}}left|mu_thetaleft(mathbf{x}_t, tright)-hat{mu}left(mathbf{x}_t, mathbf{x}_0right)right|^2

可视化之后可以看到学到的embedding是有意义的。

深度学习模型设计与端到端训练

减少舍入误差

embedding是将离散文本映射到连续的x0mathbf x_0上,那与之对应的反向过程就应该是将模型预测出来的x0mathbf x_0转换回离散的文本。

舍入步骤是根据argmax⁡pθ(w∣x0)=∏i=1npθ(wi∣xi)operatorname{argmax} p_thetaleft(mathbf{w} mid mathbf{x}_0right)=prod_{i=1}^n p_thetaleft(w_i mid x_iright)选择每个位置上可能性最大的词。理想状态下通过这个舍入步骤就可以将模型输出的x0mathbf x_0映射回离散文本,因为去噪步骤应该能让x0mathbf x_0恰好回到某个单词的embedding上。

然而实际情况是不行的,模型输出是不会精确到某个单词的embedding。

作者认为造成上述问题的原因,是在目标函数中Lsimplee2e(x0)mathcal{L}_{mathrm{simple}}^{mathrm{e} 2 mathrm{e}}(mathbf{mathbf x_0})x0mathbf x_0结构的建模不够重视。

Lsimple (x0)=∑t=1TExt∥μθ(xt,t)−μ^(xt,x0)∥2mathcal{L}_{text {simple }}(mathbf{x_0}) = sum_{t=1}^T mathbb{E}_{mathbf{x}_t}left|mu_thetaleft(mathrm{x}_t, tright)-hat{mu}left(mathrm{x}_t, mathbf{x}_0right)right|^2

其中的μθ(xt,t)mu_thetaleft(mathrm{x}_t, tright)网络直接去预测时间步tt去噪的pθ(xt−1∣xt)p_thetaleft(mathbf{x}_{t-1} mid mathbf{x}_tright)的均值。x0mathbf x_0对单词的约束只会出现在tt接近于0的项中。因此需要对目标函数进行调整,去强调x0mathbf x_0

作者对Lsimple mathcal{L}_{text {simple }}进行调整,强调模型在目标函数的每一项中都去显式地建模x0mathbf x_0
作者推导出一个类似于Lsimple mathcal{L}_{text {simple }}的使用x0mathbf x_0参数化的公式:

Lx0-simple e2e(x0)=∑t=1TExt∥fθ(xt,t)−x0∥2mathcal{L}_{mathbf{x}_0 text {-simple }}^{mathrm{e2e}}left(mathbf{x}_0right)=sum_{t=1}^T mathbb{E}_{mathbf{x}_t}left|f_thetaleft(mathbf{x}_t, tright)-mathbf{x}_0right|^2

其中网络fθ(xt,t)f_thetaleft(mathbf{x}_t, tright)直接去预测x0mathbf x_0,这样会让神经网络去预测每一项的x0mathbf x_0
使用修改之后的目标去训练模型,模型很快就会学到让x0mathbf x_0能回到以word embedding为中心。

clamping trick

作者将其称为clamping trick,在clamping trick中,模型将xtmathbf x_t降噪为xt−1mathbf x_{t-1}的生成过程:

  1. 通过fθ(xt,t)f_theta (mathbf x_t,t)估计出一个x0mathbf x_0

  2. 在这个估计的条件相爱对xt−1mathbf x_{t-1}进行采样

  3. xt−1=αˉfθ(xt,t)+1−αˉϵmathbf{x}_{t-1}=sqrt{bar{alpha}} f_thetaleft(mathbf{x}_t, tright)+sqrt{1-bar{alpha}} epsilon
    其中αˉt=∏s=0t(1−βs)bar{alpha}_t=prod_{s=0}^tleft(1-beta_sright)ϵ∼N(0,I)epsilon sim mathcal{N}(0, I)

    这一步就是用的DDPM中的那个,因为都是高斯核,所以xtmathbf x_t可以由x0mathbf x_0一步得到。

    xt=αtxt−1+1−αtϵt−1∗=αt(αt−1xt−2+1−αt−1ϵt−2∗)+1−αtϵt−1∗=αtαt−1xt−2+αt−αtαt−1ϵt−2∗+1−αtϵt−1∗=αtαt−1xt−2+αt−αtαt−12+1−αt2ϵt−2=αtαt−1xt−2+αt−αtαt−1+1−αtϵt−2=αtαt−1xt−2+1−αtαt−1ϵt−2=…=∏i=1tαix0+1−∏i=1tαiϵ0=αˉtx0+1−αˉtϵ0begin{aligned} boldsymbol{x}_t & =sqrt{alpha_t} x_{t-1}+sqrt{1-alpha_t} epsilon_{t-1}^* & =sqrt{alpha_t}left(sqrt{alpha_{t-1}} x_{t-2}+sqrt{1-alpha_{t-1}} epsilon_{t-2}^*right)+sqrt{1-alpha_t} epsilon_{t-1}^* & =sqrt{alpha_t alpha_{t-1}} x_{t-2}+sqrt{alpha_t-alpha_t alpha_{t-1}} epsilon_{t-2}^*+sqrt{1-alpha_t} epsilon_{t-1}^* & =sqrt{alpha_t alpha_{t-1}} x_{t-2}+sqrt{{sqrt{alpha_t-alpha_t alpha_{t-1}}}^2+sqrt{1-alpha_t^2}} epsilon_{t-2} & =sqrt{alpha_t alpha_{t-1}} x_{t-2}+sqrt{alpha_t-alpha_t alpha_{t-1}+1-alpha_t} epsilon_{t-2} & =sqrt{alpha_t alpha_{t-1}} x_{t-2}+sqrt{1-alpha_t alpha_{t-1}} epsilon_{t-2} & =ldots & =sqrt{prod_{i=1}^t alpha_i} x_0+sqrt{1-prod_{i=1}^t alpha_i} epsilon_0 & =sqrt{bar{alpha}_t} x_0+sqrt{1-bar{alpha}_t} epsilon_0end{aligned}

clamping trick 会将网络fθ(xt,t)f_theta (mathbf x_t,t)的预测结果映射到接近的word embedding 序列上。
现在采样步骤就变为了:

xt−1=αˉ⋅Clamp⁡(fθ(xt,t))+1−αˉϵmathbf{x}_{t-1}=sqrt{bar{alpha}} cdot operatorname{Clamp}left(f_thetaleft(mathbf{x}_t, tright)right)+sqrt{1-bar{alpha}} epsilon

clamping trick 迫使扩散模型降噪过程中每一步都去计算一个word embedding,使向量预测更为准确,以此减少舍入误差。

作者在这里提示将开始使用clamping trick的起始位置设置为超参数。具体原因看论文P5

论文信息

深度学习模型设计与端到端训练

论文地址:[2205.14217] Diffusion-LM Improves Controllable Text Generation (arxiv.org)

代码地址:XiangLi1999/Diffusion-LM: Diffusion-LM (github.com)

本文正在参加「金石计划」

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

原神介绍:探秘提瓦特的幻想世界

2023-12-15 9:37:14

AI教程

学术机构命名实体归一化系统sCool详解

2023-12-15 9:43:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索