[半监督学习] Interpolation consistency training for semi-supervised learning

[半监督学习] Interpolation consistency training for semi-supervised learning提出插值一致性训练 Interpolatio consistency training ICT 这是一种简单且效率高的算法 用于在半监督学习范式中训练深度神经网络 在分类问题中 ICT 将决策边界移动到数据分布的低密度区域 论文地址 Interpolatio consistency training for semi supervised learning 算法代码

大家好,我是讯享网,很高兴认识大家。

提出插值一致性训练(Interpolation consistency training, ICT), 这是一种简单且效率高的算法, 用于在半监督学习范式中训练深度神经网络. 在分类问题中, ICT 将决策边界移动到数据分布的低密度区域.

论文地址: Interpolation consistency training for semi-supervised learning
算法代码: https://github.com/vikasverma1077/ICT
会议: IJCAI 2019
任务: 分类

低密度分离假设(聚类假设)启发了许多一致性正则化半监督学习技术, 包括 Π \Pi Π-model, Temporal ensembling, Mean-Teacher, VAT. 另外, 在 ICT 这篇论文之后又出现了 UDA, 其效果比 ICT 还要好.

有研究表明: 对抗性扰动训练会损害泛化性能. 为了克服这个问题, 便提出了插值一致性训练(ICT), 简单来说, ICT 通过在未标记点 u 1 u_1 u1 , u 2 u_2 u2 的插值 α u 1 + ( 1 − α u 2 ) \alpha u_1 + (1-\alpha u_2) αu1+(1αu2) 上的一致性预测 f ( α u 1 + ( 1 − α u 2 ) ) = α f ( u 1 ) + ( 1 − α ) f ( u 2 ) f(\alpha u_1 + (1-\alpha u_2))=\alpha f(u_1)+(1-\alpha)f(u_2) f(αu1+(1αu2))=αf(u1)+(1α)f(u2) 来规范半监督学习.

与监督学习相比, ICT 的决策边界穿越低密度区域, 这将更好地反映未标记数据的结构. 对比结果如下图所示:
在这里插入图片描述
讯享网

插值一致性训练(ICT)

根据论文 Mixup: Beyond empirical risk minimization. In International conference on learning
representations. 给出 mixup 式子:
M i x λ ( a , b ) = λ a + ( 1 − λ ) b (1) {\rm Mix}_{\lambda}(a,b)=\lambda a + (1-\lambda)b \tag{1} Mixλ(a,b)=λa+(1λ)b(1)
ICT 训练分类器 f θ f_\theta fθ 以在未标记点的插值中提供一致性预测:
f θ ( M i x λ ( u j , u k ) ) ≈ M i x λ ( f θ ′ ( u j ) , f θ ′ ( u k ) ) (2) f_\theta({\rm Mix}_{\lambda}(u_j,u_k))\approx {\rm Mix}_\lambda(f_{\theta'}(u_j),f_{\theta'}(u_k)) \tag{2} fθ(Mixλ(uj,uk))Mixλ(fθ(uj),fθ(uk))(2)
其中 θ ′ \theta' θ θ \theta θ 的滑动平均. 那么为什么未标记样本之间的插值为半监督训练提供了良好的一致性扰动呢? 有如下解释.

应该应用一致性正则化的最有用的样本是靠近决策边界的样本. 向这种低边距未标记样本 u j u_j uj 添加一个小的扰动 δ \delta δ 可能会将 u j + δ u_j + \delta uj+δ 推到决策边界的另一侧. 这将违反低密度分离假设, 就使 u j + δ u_j + \delta uj+δ 成为应用一致性正则化的好位置.

回到低边距未标记点 u j u_j uj, 如何找到一个扰动 δ \delta δ, 使得 u j u_j uj u j + δ u_j +\delta uj+δ 位于决策边界的相对两侧? 而使用随机扰动是一种低效的策略, 因为接近决策边界的方向子集只是环境空间的一小部分. 那么, 可以考虑对第二个随机选择的未标记样本 u k u_k uk 进行插值 u j + δ = M i x λ ( u j , u k ) u_j + \delta = {\rm Mix}_\lambda(u_j , u_k) uj+δ=Mixλ(uj,uk). 而两个未标记的样本 u j u_j uj u k u_k uk 存在下面三种情况:

  • (1) 位于同一蔟中
  • (2) 位于不同的蔟中, 但属于同一类别
  • (3) 位于不同的蔟上, 属于不同的类别

由聚类假设, (1)的概率随类别数的增加而降低. 如果假设每个类别的聚类数是平衡的, 则(2)的概率较低;最后, (3)的概率最高. 然后, 假设 ( u j , u k ) (u_j,u_k) (uj,uk) 之一位于决策边界附近(它是执行一致性的一个很好的候选者), 则由于(3)的概率很高, 朝向 u k u_k uk 的插值可能指向低密度区域, 其次是另一类的聚类. 由于这是移动决策的不错的方向, 因此对于基于一致性的正则化, 插值是一个很好的扰动.

到目前为止, 随机未标记样本之间的插值可能会落在低密度区域, 因此, 这种插值可以应用基于一致性的正则化. 但是应该如何标记这些插值呢? 与单个未标记样本 u j u_j uj 的随机或对抗性扰动不同, ICT 涉及两个未标记示例 ( u j , u k ) (u_j,u_k) (uj,uk). 直观地说, 我们希望将决策边界尽可能地推离类别边界, 因为众所周知, 具有大边距的决策边界可以更好地泛化.

在监督学习环境中, mixup 是实现大边距决策边界的一种方法. 在 mixup 中, 通过强制预测模型在样本之间线性变化, 将决策边界推离类别边界, 通过式(2)来完成. 在这里, 通过训练模型 f θ f_\theta fθ 来预测 M i x λ ( u j , u k ) {\rm Mix}_λ(u_j,u_k) Mixλ(uj,uk) 的"假标签" M i x λ ( f θ ′ ( u j ) , f θ ′ ( u k ) ) {\rm Mix}_\lambda (f_{\theta'}(u_j),f_{\theta'}(u_k)) Mixλ(fθ(uj),fθ(uk)) 来将 mixup 扩展到半监督学习. 其中 θ ′ \theta' θ θ \theta θ 的滑动平均, 同 Mean-Teacher 中 Teacher model 的 θ ′ \theta' θ 的计算一样.

ICT 模型如下图所示:
在这里插入图片描述

  • 从联合分布 P X Y ( X , Y ) P_{XY}(X,Y) PXY(X,Y) 提取标记样本 ( x i , y i ) (x_i,y_i) (xi,yi) 记为 D L \mathcal{D}_L DL.
  • 从边缘分布 P X ( X ) = P X Y ( X , Y ) P Y ∣ X ( Y ∣ X ) P_X(X)=\frac{P_{XY}(X,Y)}{P_{Y\vert X}(Y \vert X)} PX(X)=PYX(YX)PXY(X,Y) 提取未标记样本 u j u_j uj, u k u_k uk 记为 D u l \mathcal{D}_{ul} Dul.
  • 学习目标是训练一个模型 f θ f_\theta fθ, 能够从 X X X 预测 Y Y Y. 通过使用随机梯度下降, 在每次迭代 t t t 时, 更新参数 θ \theta θ 以最小化损失函数 L = L S + w ( t ) L U S L=L_S+w(t)L_{US} L=LS+w(t)LUS. 其中 L S L_S LS 为标记样本上的交叉熵损失, L U S L_{US} LUS 为新的插值上的一致性正则化损失. 两个损失都是在 minibatch 上进行计算的, 每次迭代后 w ( t ) w(t) w(t) 都会 ramp up, 以增加 L U S L_{US} LUS 的重要性.

为了计算 L U S L_{US} LUS, 对两个小批量未标记点 u j u_j uj u k u_k uk 进行采样, 并计算它们的假标签 y ^ j = f θ ′ ( u j ) \hat{y}_j=f_{\theta'}(u_j) y^j=fθ(uj) y ^ k = f θ ′ ( u k ) \hat{y}_k=f_{\theta'}(u_k) y^k=fθ(uk), 然后, 计算插值 u m = M i x λ ( u j , u k ) u_m={\rm Mix}_λ(u_j,u_k) um=Mixλ(uj,uk), 以及该位置的模型预测 y ^ m = f θ ′ ( u m ) \hat{y}_m=f_{\theta'}(u_m) y^m=fθ(um). 接着, 更新参数 θ \theta θ 以使预测 y ^ m \hat{y}_m y^m 更接近于假标签的插值 M i x λ ( y ^ j , y ^ k ) {\rm Mix}_λ(\hat{y}_j,\hat{y}_k) Mixλ(y^j,y^k). 预测 y ^ m \hat{y}_m y^m M i x λ ( y ^ j , y ^ k ) {\rm Mix}_λ(\hat{y}_j,\hat{y}_k) Mixλ(y^j,y^k) 之间的差异可以使用任何损失来衡量, 在本文实验中, 使用的是均方误差. 对于式(1), 在每次更新时, 从 B e t a ( α , α ) {\rm Beta}(\alpha,\alpha) Beta(α,α) 中随机抽取一个 λ \lambda λ.

综上, L U L L_{UL} LUL 可被写为:
L U L = E u j , u k E λ l ( f θ ( M i x λ ( u j , u k ) ) , M i x λ ( f θ ′ ( u j ) , f θ ′ ( u k ) ) ) (3) \mathcal{L}_{UL}=\underset{u_j,u_k}{\mathbb{E}}\underset{\lambda}{\mathbb{E}} \mathcal{l}(f_\theta({\rm Mix}_{\lambda}(u_j,u_k)),{\rm Mix}_\lambda(f_{\theta'}(u_j),f_{\theta'}(u_k))) \tag{3} LUL=uj,ukEλEl(fθ(Mixλ(uj,uk)),Mixλ(fθ(uj),fθ(uk)))(3)

ICT 算法流程如下:
在这里插入图片描述

小讯
上一篇 2025-04-06 07:33
下一篇 2025-01-08 10:02

相关推荐

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容,请联系我们,一经查实,本站将立刻删除。
如需转载请保留出处:https://51itzy.com/kjqy/15819.html