2025年虚拟对抗训练(VAT)原理和代码解析

虚拟对抗训练(VAT)原理和代码解析虚拟对抗训练 VAT 原理和代码解析 微软在 ACL20 发表了一篇 A dversarial Training for Large Neural L angu age M odels 对应的代码 ALUM 这是一篇首次在大规模语料做对抗式训练的语言模型研究 提出了 ALUM 通用的对抗式训练的算法

大家好,我是讯享网,很高兴认识大家。

虚拟对抗训练(VAT)原理和代码解析

微软在ACL20发表了一篇Adversarial Training for Large Neural Language Models,对应的代码ALUM,这是一篇首次在大规模语料做对抗式训练的语言模型研究,提出了ALUM通用的对抗式训练的算法,并且在当前预训练模型上取得SOTA。此研究目的是解决当前的预训练模型(文中用BERT和ROBERT)泛化性和鲁棒性不足的,并且当前对抗训练虽然可以增强鲁棒性,但会损害泛化性的问题。作者还指出ALUM可以在预训练和下游任务都可以使用。

预备知识

此模型是一种半监督学习的模型,相比于其他对抗式学习不同之处,例如FGSM、FGM、PGD等,对于ALUM是加入了无标签数据去优化模型参数。所以了解其他的对抗学习之后,再看看论文发现原理不会很难,以下列出几点需要提前掌握的知识点。

对抗式学习(FGSM、FGM、PGD等)

强烈推荐看【炼丹技巧】功守道:NLP中的对抗训练 + PyTorch实现这篇blog。简单说对抗式训练是做防御和攻击的训练过程,即在输入 x x x上加入一个扰动 r r r r r r是利用模型Loss对于 x x x的梯度加上正则化得到,然后利用加入扰动的x进入模型,再进行一次训练。

DL散度Loss

DL散度是量化两种概率分布P和Q之间差异的方式 :
D ( p ∣ ∣ q ) = ∑ p ( x i ) ∗ ( l o g ( p ( x i ) − l o g q ( x i ) ) D(p||q)=\sum p(x_i)*(log(p(x_i)-logq(x_i)) D(pq)=p(xi)(log(p(xi)logq(xi))
在论文中 p p p是实际样本输入预训练模型输出 l o g i t s logits logits q q q是指对抗样本输入预训练模型后输出 a d v _ l o g i t s adv\_logits adv_logits,所以这里得到模型其中的一部分Loss。

p = torch.tensor([[0.7, 0.2, 0.1], [0.2,0.2, 0.6], [0.3, 0.2, 0.5]]) q = torch.tensor([[0.6, 0.3, 0.1], [0.2,0.2, 0.6], [0.3, 0.1, 0.6]]) torch.nn.functional.kl_div(q.log_softmax(dim=-1), p.softmax(dim=-1) 

讯享网

模型过程

论文中给出了具体的算法过程,如下:
在这里插入图片描述
讯享网
大概说一下具体的参数和步骤,首先先说参数:

参数 描述 取值
T epoch -
K 要做多少次扰动更新 理论越多效果越好,但是就expensive,论文中 K=1
∏ \prod 正则化方法 L0,L1,L2中选择
α \alpha α 增强对抗学习的比例 预训练为10,下游任务为1
η \eta η 扰动的学习率 1 × 1 0 − 3 1 × 10^{−3} 1×103
τ \tau τ 全局学习率 1 × 1 0 − 5 1 × 10^{−5} 1×105
θ \theta θ 模型参数 -

对于模型算法过程确实是不复杂,所以打算按照图片中的行号一行行说明:

  1. 循环epoch
  2. 循环数据集,每次产生一个batch_size大小的数据
  3. 生成一个扰动 δ \delta δ, δ \delta δ服从均作为0,方差为1
  4. 循环K次,理论K越大效果越好,实际使用K=1,减少计算量
  5. 计算实际输入的输出和对抗样本的实际输入的DL散度Loss,并计算梯度
  6. 扰动正则化
  7. 循环K次结束
  8. 计算模型的Loss(带标签数据losss+虚拟对抗Loss)计算梯度更新参数,α是增强对抗学习的比例,预训练设置为10,下游任务设置为1。

下图展示了虚拟对抗训练的过程,我们可以看出对抗样本是由扰动加到输入的Embed空间上得到,然后原始输入和对抗样本分别计算得到两个输出,原始输出与标签计算得到Loss,对抗样本需要和原始输出计算得到Adv Loss,最后我们需要的是最小化Loss,最大化Adv Loss,最后我们的目标是:
m i n θ E ( x , y ) D [ l ( f ( x ; θ ) , y ) + α m a x δ l ( f ( x + δ ; θ ) , f ( x ; θ ) ) ] min_{\theta}E_{(x,y) D}[l(f(x;\theta),y)+αmax_{δ}l(f(x+δ;\theta),f(x; θ))] minθE(x,y)D[l(f(x;θ),y)+αmaxδl(f(x+δ;θ),f(x;θ))]
虚拟对抗训练的模型流程

代码干货

代码已经开源,项目是以robert进行了实验,我们只需要关心 adv_masked_lm.py 和 adv_masked_lm_task.py 这两个文件。

adv_masked_lm.py:虚拟对抗训练代码
adv_masked_lm_task.py:训练mlm模型,其中包括超参数的设置
在这里插入图片描述

虚拟对抗训练代码

本人使用中是剥离出adv_masked_lm.py,方便能在torch中使用。

讯享网import torch import torch.nn.functional as F def kl(inputs, targets, reduction="sum"): """ 计算kl散度 inputs:tensor,logits targets:tensor,logits """ loss = F.kl_div(F.log_softmax(inputs, dim=-1), F.softmax(targets, dim=-1), reduction=reduction) return loss def adv_project(grad, norm_type='inf', eps=1e-6): """ L0,L1,L2正则,对于扰动计算 """ if norm_type == 'l2': direction = grad / (torch.norm(grad, dim=-1, keepdim=True) + eps) elif norm_type == 'l1': direction = grad.sign() else: direction = grad / (grad.abs().max(-1, keepdim=True)[0] + eps) return direction def virtual_adversarial_training(model, hidden_status, token_type_ids, attention_mask, logits): """ 虚拟对抗式训练 model: nn.Module, 模型 hidden_status:tensor,input的embedded表示 token_type_ids:tensor,bert中的token_type_ids,A B 句子 attention_mask:tensor,bert中的attention_mask,对paddding mask logits:tensor,input的输出 """ embed = hidden_status # 初始扰动 r noise = embed.data.new(embed.size()).normal_(0, 1) * 1e-5 noise.requires_grad_() # x + r new_embed = embed.data.detach() + noise adv_output = model(inputs_embeds=new_embed, token_type_ids=token_type_ids, attention_mask=attention_mask) adv_logits = adv_output[0] adv_loss = kl(adv_logits, logits.detach(), reduction="batchmean") delta_grad, = torch.autograd.grad(adv_loss, noise, only_inputs=True) norm = delta_grad.norm() # 梯度消失,退出 if torch.isnan(norm) or torch.isinf(norm): return None # line 6 inner sum noise = noise + delta_grad * 1e-3 # line 6 projection noise = adv_project(noise, norm_type='l2', eps=1e-6) new_embed = embed.data.detach() + noise new_embed = new_embed.detach() # 在进行一次训练 adv_output = model(inputs_embeds=new_embed, token_type_ids=token_type_ids, attention_mask=attention_mask) adv_logits = adv_output[0] adv_loss_f = kl(adv_logits, logits.detach()) adv_loss_b = kl(logits, adv_logits.detach()) # 在预训练时设置为10,下游任务设置为1 adv_loss = (adv_loss_f + adv_loss_b) * 1 return adv_loss 

使用方法

以下是使用nezha-bert训练的调用代码:

for input_ids, token_type_ids, attention_mask, output_ids, _ in tqdm(train_loader): step += 1 input_ids = input_ids.long().to(device) token_type_ids = token_type_ids.long().to(device) attention_mask = attention_mask.long().to(device) output_ids = output_ids.long().to(device) optimizer.zero_grad() # 混合精度计算,训练速度接近提高了1/2 with autocast(): output = model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=output_ids) loss = output[0] if args.use_adv == 'vat': logits = output[1] hidden_status = output[2][0] adv_loss = virtual_adversarial_training(model, hidden_status, token_type_ids, attention_mask, logits) if adv_loss: train_adv_loss += adv_loss loss = adv_loss * 10 + loss train_loss += loss loss.backward() optimizer.step() 

实验和结论

本人再使用nezha和vat的训练相似度计算模型auc能达到97.2%,acc达到91%,对比了没用使用vat,最终auc没能上97%,acc在89%左右,且训练过程中loss波动较大,所以证明了vat在训练过程中是有效的。

小讯
上一篇 2025-01-28 17:56
下一篇 2025-02-24 21:09

相关推荐

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容,请联系我们,一经查实,本站将立刻删除。
如需转载请保留出处:https://51itzy.com/kjqy/56494.html