Skip to content

Instantly share code, notes, and snippets.

@hzhu212
Last active March 18, 2025 02:05
Show Gist options
  • Save hzhu212/2e7d651b575d56790f108ef94a01a4cf to your computer and use it in GitHub Desktop.
Save hzhu212/2e7d651b575d56790f108ef94a01a4cf to your computer and use it in GitHub Desktop.
多目标学习/multi-task learning

在多目标学习中,其中一个难题是如何平衡不同目标之间的 loss 及其学习速率,从而避免模型被其中一个目标所主导,导致其他目标不能得到充分学习。这里简要介绍几种方法。

1. 简单平均法

直接将多个 Loss 加权平均,权重为超参数(通常初始化为等权重)。

公式

$$\mathcal{L}{\text{total}} = \sum{i=1}^N w_i \mathcal{L}_i \quad \text{(默认 } w_i=1.0 \text{)}$$

适用场景

  • 任务重要性相近且损失量级相似
  • 快速验证模型可行性

示例代码

# 假设 loss1 和 loss2 是两个任务的损失
w1, w2 = 1.0, 1.0  # 可调整为超参数
total_loss = w1 * loss1 + w2 * loss2

不足

  • 模型容易被其中 loss 梯度较大的任务主导。
  • 如果手动设置权重可以一定程度避免单 loss 主导,但不够灵活通用,需要先观察一下不同任务的 loss 及其梯度的数量级,然后才能得到权重。

2. 基于不确定性加权(Uncertainty Weighting)

通过模型自动学习每个任务的噪声水平(不确定性)作为权重。

核心思想

  • 任务噪声越大(不确定性越高),权重越小
  • 通过最大化高斯似然推导权重

公式
$$ \mathcal{L}{\text{total}} = \sum{i=1}^N \frac{1}{\sigma_i^2} \mathcal{L}_i + \log \sigma_i^2 $$

其中 $\sigma_i$ 是可学习的参数,初始化为0。

代码实现

class UncertaintyWeightedLoss(nn.Module):
    def __init__(self, num_tasks):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))  # 学习 log(sigma^2)

    def forward(self, losses):
        total_loss = 0.0
        for i, loss in enumerate(losses):
            precision = torch.exp(-self.log_vars[i])  # 1/sigma^2
            total_loss += precision * loss + self.log_vars[i]
        return total_loss

# 使用示例
loss_fn = UncertaintyWeightedLoss(num_tasks=2)
total_loss = loss_fn([loss1, loss2])

3. 梯度归一化(GradNorm)

动态调整权重,使得各任务的梯度量级相似。

步骤

  1. 计算每个任务权重的梯度范数
  2. 调整权重使各任务的梯度范数趋近于某个共同目标(如平均范数)

代码框架

class GradNormLoss(nn.Module):
    def __init__(self, num_tasks: int, shared_parameters: Union[torch.Tensor, Sequence[torch.Tensor]], alpha: float = 1.5):
        """
        Args:
            num_tasks (int): 任务数量
            shared_parameters (iterable): 共享参数的列表
            alpha (float): 平衡强度系数,越大任务差异越大
        """
        super().__init__()
        self.num_tasks = num_tasks
        self.shared_parameters = list(shared_parameters)
        self.alpha = alpha
        self.initial_losses = None
        # 使用 log_weights 而不是 weights,以避免权重被学习成负值
        # self.weights = nn.Parameter(torch.ones(num_tasks))
        self.log_weights = nn.Parameter(torch.zeros(num_tasks))
        self.weights = torch.exp(self.log_weights)

    def forward(self, task_losses: torch.Tensor, training: bool = True):
        """
        Args:
            task_losses (Tensor): 各任务的损失值列表,长度为 num_tasks
        Returns:
            total_loss (Tensor): 加权后的总损失(需调用 backward())
        """
        if self.initial_losses is None:
            self.initial_losses = task_losses.detach().clone()

        self.weights = torch.exp(self.log_weights)

        # 计算加权损失
        weighted_loss = (task_losses * self.weights).sum()

        # 在 evaluate 模式下无法计算梯度损失
        if not training:
            return weighted_loss
        
        # 计算梯度范数
        shared_grads = []
        for loss in task_losses:
            # 计算当前任务对共享参数的梯度
            grad = torch.autograd.grad(loss, self.shared_parameters, retain_graph=True, allow_unused=True)
            # 拼接梯度并计算L2范数
            grad_norm = torch.norm(torch.cat([g.view(-1) for g in grad if g is not None]))  # 处理无梯度情况
            shared_grads.append(grad_norm)
        shared_grads = torch.stack(shared_grads)  # [num_tasks]
        
        # 计算目标梯度范数(全局参考值)
        avg_grad = shared_grads.mean()
        
        # 计算相对逆训练速率
        ratios = task_losses.detach() / self.initial_losses
        ratios = ratios / ratios.mean()
        
        # 计算梯度损失(GradNorm Loss)
        target_grads = avg_grad * (ratios ** self.alpha)
        grad_loss = torch.sum(torch.abs(shared_grads - target_grads))
        
        # 合并总损失
        total_loss = weighted_loss + grad_loss
        
        # 权重归一化(保持权重总和为 num_tasks)
        self.weights.data = self.weights.data * self.num_tasks / self.weights.data.sum()

        return total_loss
  
        
# 使用示例
...
# 模型参数和 GradNorm 权重都需要纳入梯度下降范围
criterion = MultiTaskLoss() # forward 方法返回每个 task 的 loss
gn = GradNormLoss(num_of_task=3, shared_parameters=model.linear2.parameters())
optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': gn.parameters(), 'lr': 1e-3}  # 可单独设置权重学习率,一般为模型学习率的 10 倍
], lr=1e-4)

for epoch in range(100):
    for batch in dataloader:
        optimizer.zero_grad()
        # 前向传播
        multi_pred_logits = model(batch)
        losses = criterion(multi_pred_logits, multi_labels)
        total_loss = gn(losses, training=True)
        
        # 反向传播
        total_loss.backward(retain_graph=True)
        
        # 梯度裁剪,可选
        nn.utils.clip_grad_norm_(list(model.parameters()) + list(gn.parameters()), 1.0)
        
        # 更新参数
        optimizer.step()

4. 动态任务优先级(DWA, Dynamic Weight Average)

根据任务的学习速度调整权重:

  • 学习速度慢的任务(损失下降慢)分配更高权重
  • 参考前一时段的损失比值

公式
$$ w_i^{(t)} = \frac{N \exp(r_i^{(t)} / T)}{\sum_j \exp(r_j^{(t)} / T)} \quad \text{(} r_i^{(t)} = \frac{\mathcal{L}_i^{(t-1)}}{\mathcal{L}_i^{(t-2)}} \text{)} $$ 其中 $T$ 是温度参数,控制权重分布的平滑程度。

代码实现

class DWA:
    def __init__(self, num_tasks, temp=2.0):
        self.num_tasks = num_tasks
        self.temp = temp
        self.prev_losses = None  # 保存前两个时间步的损失

    def compute_weights(self, current_losses):
        if self.prev_losses is None:  # 初始化为等权重
            self.prev_losses = current_losses.detach()
            return torch.ones(self.num_tasks)
        
        # 计算各任务的损失变化速率 r_i
        rates = current_losses.detach() / self.prev_losses.detach()
        self.prev_losses = current_losses.detach()
        
        # Softmax 分配权重
        weights = F.softmax(rates / self.temp, dim=0) * self.num_tasks
        return weights
        
        
# 使用示例
...
criterion = MultiTaskLoss() # forward 方法返回每个 task 的 loss
dwa = DWA(num_tasks=3)
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(100):
    for batch in dataloader:
        optimizer.zero_grad()
        # 前向传播
        multi_pred_logits = model(batch)
        losses = criterion(multi_pred_logits, multi_labels)
        weights = dwa.compute_weights(losses)
        total_loss = (weights * losses).sum()
        
        # 反向传播
        total_loss.backward()

        # 更新参数
        optimizer.step()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment