在分类问题中,FocalLoss 是一种解决标签分布不均的方法,并使模型专注于学习困难样本。例如,在医学影像检测中,负样本数量显著大于正样本。FocalLoss 原始论文中只讨论了二分类任务,缺少针对多分类问题的相应实现,这里记录一下实现方法。
class MultiClassFocalLoss(nn.Module):
def __init__(self, gamma: float = 2.0, alpha: torch.Tensor = None, reduction: str = 'mean', ignore_index: int = -100):
"""
Args:
gamma (float): 调制因子,越大越关注困难样本 (γ >= 0)。
alpha (Tensor): 类别权重,形状为 [C],C为类别数。若为None,则无类别权重。
reduction (str): 损失聚合方式,'mean' 或 'sum'。
ignore_index (int): 忽略的类别标签索引。
"""
super().__init__()
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
self.ignore_index = ignore_index
def forward(self, logits: torch.Tensor, targets: torch.Tensor):
"""
Args:
logits (Tensor): 模型输出,形状为 [B, C] 或 [B, C, D1, D2...](如图像分割)
targets (Tensor): 真实标签,形状为 [B] 或 [B, D1, D2...],元素为类别索引。
Returns:
loss (Tensor): 计算后的损失值
"""
# 处理高维输入(如图像分割)
if logits.dim() > 2:
logits = logits.view(logits.size(0), logits.size(1), -1) # [B, C, D1*D2...]
logits = logits.transpose(1, 2) # [B, D1*D2..., C]
logits = logits.contiguous().view(-1, logits.size(-1)) # [B*D1*D2..., C]
targets = targets.view(-1) # [B*D1*D2...]
# 过滤被忽略的索引
valid_mask = targets != self.ignore_index
targets = targets[valid_mask]
logits = logits[valid_mask]
if targets.numel() == 0: # 全部被忽略
return torch.tensor(0.0).to(logits.device)
log_softmax = F.log_softmax(logits, dim=-1) # [B*D1*D2..., C]
ce_loss = -log_softmax.gather(1, targets.view(-1, 1)).squeeze() # 交叉熵损失 # [B*D1*D2...]
# 计算 p_t = softmax 后的真实类别概率
pt = torch.exp(-ce_loss) # p_t = exp(-ce_loss)
# 计算 Focal Loss 的调制因子
focal_term = (1 - pt) ** self.gamma
# 应用类别权重 alpha
if self.alpha is not None:
alpha = self.alpha.to(logits.device)[targets] # 获取每个样本的 alpha
focal_loss = alpha * focal_term * ce_loss
else:
focal_loss = focal_term * ce_loss
# 聚合损失
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
@staticmethod
def compute_class_weights(class_rates: torch.Tensor, beta: float = 1.0):
"""
Args:
class_rates (Tensor): 形状为 [C] 的类别频率
beta (float): 平滑系数,当 beta 为 1 时类别权重为其频率的倒数,越大对大类别的抑制越重
Returns:
alpha (Tensor): 形状为 [C] 的类别权重
"""
alpha = 1.0 / class_rates ** beta
alpha = alpha / alpha.mean() # 归一化
return alpha
然而,针对上述实现方案,有人持不同的看法,例如这篇知乎文章 认为将 FocalLoss 简单地从二分类扩展到多分类是错误的,主张把多分类问题看作多个二分类问题。下面是参照这种想法的实现。
class FocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2, reduction='mean', ignore_index=-100):
"""
Args:
alpha (Tensor, optional): 类别权重张量(如 [0.2, 0.3, 0.5]),默认 None 表示等权重
gamma (float): 聚焦参数,抑制易分样本的损失贡献
reduction (str): 损失聚合方式,'mean' 或 'sum'
ignore_index (int): 需要忽略的标签索引(默认 -100,与 PyTorch 一致)
"""
super().__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
self.ignore_index = ignore_index
def forward(self, inputs, targets):
"""
Args:
inputs (Tensor): 模型输出的 logits,形状为 [batch_size, num_classes]
targets (Tensor): 目标标签,形状为 [batch_size]
"""
# 生成忽略 mask
mask = targets != self.ignore_index
targets = targets[mask] # 过滤忽略的标签
inputs = inputs[mask, :] # 过滤对应位置的 logits
# 如果没有有效样本(如全部被忽略),返回 0
if targets.numel() == 0:
return torch.tensor(0.0, device=inputs.device)
# FocalLoss 应用于多分类任务,相当于拆解成多个二分类任务
targets = F.one_hot(targets, num_classes=inputs.size(1)).float() # [batch_size, num_classes]
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') # [batch_size, num_classes]
# 每个样本被正确二分类的概率
p = torch.sigmoid(inputs) # [batch_size, num_classes]
# 注意,当 inputs 值较大时,sigmoid 会输出 0/1 极端值,当 gamma < 1 时,后续计算梯度时会出现 1/0 = inf 的情况,导致模型无法训练。需要进行平滑操作:
smooth_factor = 0.001 # 平滑因子
smooth_p = p * (1 - smooth_factor) + smooth_factor / 2
p_t = smooth_p * targets + (1 - smooth_p) * (1 - targets) # [batch_size, num_classes]
# 计算 Focal Loss
loss = ce_loss * ((1 - p_t) ** self.gamma) # [batch_size, num_classes]
if self.alpha is not None and self.alpha > 0:
alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets) # [batch_size, num_classes]
focal_loss = alpha_t * loss # [batch_size, num_classes]
else:
focal_loss = loss
# 聚合损失
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
return focal_loss
# 使用示例
# 与使用 nn.CrossEntropyLoss 的方法完全一致,略