NLP类别不均衡问题之loss大集合
数据派THU
共 3041字,需浏览 7分钟
·
2022-04-16 22:28
来源:PaperWeekly 本文约2300字,建议阅读9分钟
本文主要讨论了类别不均衡问题的解决办法,可分为数据层面的重采样及模型 loss 方面的改进。
数据层面:重采样,使得参与迭代计算的数据是均衡的;
模型层面:重加权,修改模型的 loss,在 loss 计算上,加大对少样本的 loss 奖励。
欠采样;
过采样;
SMOTE。
2. 模型层面的重加权
重加权主要指的是在 loss 计算阶段,通过设计 loss,调整类别的权值对 loss 的贡献。比较经典的 loss 改进应该是 Focal Loss, GHM Loss, Dice Loss。
def __init__(self, num_class, alpha=None, gamma=2, reduction='mean'):
super(MultiFocalLoss, self).__init__()
self.gamma = gamma
......
def forward(self, logit, target):
alpha = self.alpha.to(logit.device)
prob = F.softmax(logit, dim=1)
ori_shp = target.shape
target = target.view(-1, 1)
prob = prob.gather(1, target).view(-1) + self.smooth # avoid nan
logpt = torch.log(prob)
alpha_weight = alpha[target.squeeze().long()]
loss = -alpha_weight * torch.pow(torch.sub(1.0, prob), self.gamma) * logpt
if self.reduction == 'mean':
loss = loss.mean()
return loss
上面的 Focal Loss 注重了对 hard example 的学习,但不是所有的 hard example 都值得关注,有一些 hard example 很可能是离群点,这种离群点当然是不应该让模型关注的。
class GHM_Loss(nn.Module):
def __init__(self, bins, alpha):
super(GHM_Loss, self).__init__()
self._bins = bins
self._alpha = alpha
self._last_bin_count = None
def _g2bin(self, g):
# split to n bins
return torch.floor(g * (self._bins - 0.0001)).long()
def forward(self, x, target):
# compute value g
g = torch.abs(self._custom_loss_grad(x, target)).detach()
bin_idx = self._g2bin(g)
bin_count = torch.zeros((self._bins))
for i in range(self._bins):
# 计算落入bins的梯度模长数量
bin_count[i] = (bin_idx == i).sum().item()
N = (x.size(0) * x.size(1))
if self._last_bin_count is None:
self._last_bin_count = bin_count
else:
bin_count = self._alpha * self._last_bin_count + (1 - self._alpha) * bin_count
self._last_bin_count = bin_count
nonempty_bins = (bin_count > 0).sum().item()
gd = bin_count * nonempty_bins
gd = torch.clamp(gd, min=0.0001)
beta = N / gd # 计算好样本的gd值
# 借由binary_cross_entropy_with_logits,gd值当作参数传入
return F.binary_cross_entropy_with_logits(x, target, weight=beta[bin_idx])
class DSCLoss(torch.nn.Module):
def __init__(self, alpha: float = 1.0, smooth: float = 1.0, reduction: str = "mean"):
super().__init__()
self.alpha = alpha
self.smooth = smooth
self.reduction = reduction
def forward(self, logits, targets):
probs = torch.softmax(logits, dim=1)
probs = torch.gather(probs, dim=1, index=targets.unsqueeze(1))
probs_with_factor = ((1 - probs) ** self.alpha) * probs
loss = 1 - (2 * probs_with_factor + self.smooth) / (probs_with_factor + 1 + self.smooth)
if self.reduction == "mean":
return loss.mean()
总结
本文主要讨论了类别不均衡问题的解决办法,可分为数据层面的重采样及模型 loss 方面的改进,如 focal loss, dice loss 等。最后说一下实践下来的经验,由于不同数据集的数据分布特点各有不同,dice loss 以及 GHM loss 会出现些抖动、不稳定的情况。当不想挨个实践的时候,首推 focal loss,dice loss。
以上所有 Loss 的代码仅为逻辑参考,完整的代码及相关参考论文都在:
https://github.com/shuxinyin/NLP-Loss-Pytorch
编辑:王菁
校对:杨学俊
评论