@hzhu212
Last active March 18, 2025 01:46
In classification problems, Focal Loss is a technique for dealing with imbalanced label distributions while making the model focus on hard examples. In medical image detection, for instance, negative samples vastly outnumber positive ones. The original Focal Loss paper only discusses binary classification and gives no corresponding implementation for the multi-class case, so this note records two ways to implement it.
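For reference, the binary focal loss defined in the original paper (Lin et al., "Focal Loss for Dense Object Detection", 2017) is FL(p_t) = -α_t (1 − p_t)^γ log(p_t), where p_t is the predicted probability of the true class; the two implementations below are different ways of generalizing it to C classes.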

Method 1

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiClassFocalLoss(nn.Module):
    def __init__(self, gamma: float = 2.0, alpha: torch.Tensor = None, reduction: str = 'mean', ignore_index: int = -100):
        """
        Args:
            gamma (float): focusing parameter (gamma >= 0); larger values put more weight on hard examples.
            alpha (Tensor): per-class weights of shape [C], where C is the number of classes. None means no class weighting.
            reduction (str): how to aggregate the loss, 'mean' or 'sum'.
            ignore_index (int): label index to ignore.
        """
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, logits: torch.Tensor, targets: torch.Tensor):
        """
        Args:
            logits (Tensor): model output of shape [B, C] or [B, C, D1, D2, ...] (e.g. image segmentation)
            targets (Tensor): ground-truth labels of shape [B] or [B, D1, D2, ...], holding class indices.
        Returns:
            loss (Tensor): the computed loss
        """
        # Handle high-dimensional input (e.g. image segmentation)
        if logits.dim() > 2:
            logits = logits.view(logits.size(0), logits.size(1), -1)  # [B, C, D1*D2...]
            logits = logits.transpose(1, 2)  # [B, D1*D2..., C]
            logits = logits.contiguous().view(-1, logits.size(-1))  # [B*D1*D2..., C]
            targets = targets.view(-1)  # [B*D1*D2...]

        # Filter out ignored indices
        valid_mask = targets != self.ignore_index
        targets = targets[valid_mask]
        logits = logits[valid_mask]

        if targets.numel() == 0:  # everything was ignored
            return torch.tensor(0.0, device=logits.device)

        log_softmax = F.log_softmax(logits, dim=-1)  # [B*D1*D2..., C]
        ce_loss = -log_softmax.gather(1, targets.view(-1, 1)).squeeze(1)  # cross-entropy loss, [B*D1*D2...]

        # p_t = softmax probability of the true class
        pt = torch.exp(-ce_loss)  # p_t = exp(-ce_loss)

        # Focal Loss modulating factor
        focal_term = (1 - pt) ** self.gamma

        # Apply the class weights alpha
        if self.alpha is not None:
            alpha = self.alpha.to(logits.device)[targets]  # pick the alpha for each sample's true class
            focal_loss = alpha * focal_term * ce_loss
        else:
            focal_loss = focal_term * ce_loss

        # Aggregate the loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

    @staticmethod
    def compute_class_weights(class_rates: torch.Tensor, beta: float = 1.0):
        """
        Args:
            class_rates (Tensor): class frequencies of shape [C]
            beta (float): smoothing exponent; with beta = 1 each class weight is the inverse of its frequency,
                and larger values suppress the majority classes more strongly
        Returns:
            alpha (Tensor): class weights of shape [C]
        """
        alpha = 1.0 / class_rates ** beta
        alpha = alpha / alpha.mean()  # normalize
        return alpha
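A minimal usage sketch for method 1; the three-class frequencies, batch size, and random tensors below are made up purely for illustration:

# Hypothetical frequencies for an imbalanced 3-class problem
class_rates = torch.tensor([0.7, 0.2, 0.1])
alpha = MultiClassFocalLoss.compute_class_weights(class_rates, beta=1.0)

criterion = MultiClassFocalLoss(gamma=2.0, alpha=alpha)
logits = torch.randn(8, 3, requires_grad=True)  # [B, C]; in practice these come from the model
targets = torch.randint(0, 3, (8,))             # [B]
loss = criterion(logits, targets)
loss.backward()

# For segmentation, logits of shape [B, C, H, W] with targets of shape [B, H, W]
# go through the same forward() unchanged.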

Method 2

However, some disagree with the above approach. For example, this Zhihu article argues that naively extending Focal Loss from binary to multi-class classification is wrong, and that the multi-class problem should instead be decomposed into multiple binary classification problems. The implementation below follows that idea.

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, reduction='mean', ignore_index=-100):
        """
        Args:
            alpha (float, optional): weight of the positive class in each binary sub-problem
                (0.25 as in the original paper); None or 0 disables the weighting
            gamma (float): focusing parameter, suppresses the loss contribution of easy examples
            reduction (str): how to aggregate the loss, 'mean' or 'sum'
            ignore_index (int): label index to ignore (default -100, consistent with PyTorch)
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        """
        Args:
            inputs (Tensor): model logits of shape [batch_size, num_classes]
            targets (Tensor): target labels of shape [batch_size]
        """
        # Build the ignore mask
        mask = targets != self.ignore_index
        targets = targets[mask]   # drop ignored labels
        inputs = inputs[mask, :]  # drop the corresponding logits

        # If there are no valid samples (e.g. everything was ignored), return 0
        if targets.numel() == 0:
            return torch.tensor(0.0, device=inputs.device)

        # Apply Focal Loss to the multi-class task by decomposing it into multiple binary tasks
        targets = F.one_hot(targets, num_classes=inputs.size(1)).float()  # [batch_size, num_classes]
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')  # [batch_size, num_classes]

        # Probability of a correct decision in each binary sub-problem
        p = torch.sigmoid(inputs)  # [batch_size, num_classes]
        # Note: for large logits, sigmoid saturates to exactly 0 or 1. With gamma < 1 the gradient then
        # contains a 1/0 = inf term and training breaks down, so the probabilities are smoothed:
        smooth_factor = 0.001  # smoothing factor
        smooth_p = p * (1 - smooth_factor) + smooth_factor / 2
        p_t = smooth_p * targets + (1 - smooth_p) * (1 - targets)  # [batch_size, num_classes]

        # Compute the Focal Loss
        loss = ce_loss * ((1 - p_t) ** self.gamma)  # [batch_size, num_classes]

        if self.alpha is not None and self.alpha > 0:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)  # [batch_size, num_classes]
            focal_loss = alpha_t * loss  # [batch_size, num_classes]
        else:
            focal_loss = loss

        # Aggregate the loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss
        
        
# Usage example
# Exactly the same as with nn.CrossEntropyLoss, omitted here
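For concreteness, a minimal sketch of that usage; the batch size, class count, and random tensors are made up for illustration:

criterion = FocalLoss(alpha=0.25, gamma=2)
inputs = torch.randn(8, 3, requires_grad=True)  # [batch_size, num_classes] logits from the model
targets = torch.randint(0, 3, (8,))             # [batch_size] class indices
loss = criterion(inputs, targets)
loss.backward()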