在多目标学习中,其中一个难题是如何平衡不同目标之间的 loss 及其学习速率,从而避免模型被其中一个目标所主导,导致其他目标不能得到充分学习。这里简要介绍几种方法。
直接将多个 Loss 加权平均,权重为超参数(通常初始化为等权重)。
公式:
$$\mathcal{L}{\text{total}} = \sum{i=1}^N w_i \mathcal{L}_i \quad \text{(默认 } w_i=1.0 \text{)}$$
适用场景:
- 任务重要性相近且损失量级相似
- 快速验证模型可行性
示例代码:
# 假设 loss1 和 loss2 是两个任务的损失
w1, w2 = 1.0, 1.0 # 可调整为超参数
total_loss = w1 * loss1 + w2 * loss2
不足:
- 模型容易被其中 loss 梯度较大的任务主导。
- 如果手动设置权重可以一定程度避免单 loss 主导,但不够灵活通用,需要先观察一下不同任务的 loss 及其梯度的数量级,然后才能得到权重。
通过模型自动学习每个任务的噪声水平(不确定性)作为权重。
核心思想:
- 任务噪声越大(不确定性越高),权重越小
- 通过最大化高斯似然推导权重
公式:
$$
\mathcal{L}{\text{total}} = \sum{i=1}^N \frac{1}{\sigma_i^2} \mathcal{L}_i + \log \sigma_i^2
$$
其中
代码实现:
class UncertaintyWeightedLoss(nn.Module):
def __init__(self, num_tasks):
super().__init__()
self.log_vars = nn.Parameter(torch.zeros(num_tasks)) # 学习 log(sigma^2)
def forward(self, losses):
total_loss = 0.0
for i, loss in enumerate(losses):
precision = torch.exp(-self.log_vars[i]) # 1/sigma^2
total_loss += precision * loss + self.log_vars[i]
return total_loss
# 使用示例
loss_fn = UncertaintyWeightedLoss(num_tasks=2)
total_loss = loss_fn([loss1, loss2])
动态调整权重,使得各任务的梯度量级相似。
步骤:
- 计算每个任务权重的梯度范数
- 调整权重使各任务的梯度范数趋近于某个共同目标(如平均范数)
代码框架:
class GradNormLoss(nn.Module):
def __init__(self, num_tasks: int, shared_parameters: Union[torch.Tensor, Sequence[torch.Tensor]], alpha: float = 1.5):
"""
Args:
num_tasks (int): 任务数量
shared_parameters (iterable): 共享参数的列表
alpha (float): 平衡强度系数,越大任务差异越大
"""
super().__init__()
self.num_tasks = num_tasks
self.shared_parameters = list(shared_parameters)
self.alpha = alpha
self.initial_losses = None
# 使用 log_weights 而不是 weights,以避免权重被学习成负值
# self.weights = nn.Parameter(torch.ones(num_tasks))
self.log_weights = nn.Parameter(torch.zeros(num_tasks))
self.weights = torch.exp(self.log_weights)
def forward(self, task_losses: torch.Tensor, training: bool = True):
"""
Args:
task_losses (Tensor): 各任务的损失值列表,长度为 num_tasks
Returns:
total_loss (Tensor): 加权后的总损失(需调用 backward())
"""
if self.initial_losses is None:
self.initial_losses = task_losses.detach().clone()
self.weights = torch.exp(self.log_weights)
# 计算加权损失
weighted_loss = (task_losses * self.weights).sum()
# 在 evaluate 模式下无法计算梯度损失
if not training:
return weighted_loss
# 计算梯度范数
shared_grads = []
for loss in task_losses:
# 计算当前任务对共享参数的梯度
grad = torch.autograd.grad(loss, self.shared_parameters, retain_graph=True, allow_unused=True)
# 拼接梯度并计算L2范数
grad_norm = torch.norm(torch.cat([g.view(-1) for g in grad if g is not None])) # 处理无梯度情况
shared_grads.append(grad_norm)
shared_grads = torch.stack(shared_grads) # [num_tasks]
# 计算目标梯度范数(全局参考值)
avg_grad = shared_grads.mean()
# 计算相对逆训练速率
ratios = task_losses.detach() / self.initial_losses
ratios = ratios / ratios.mean()
# 计算梯度损失(GradNorm Loss)
target_grads = avg_grad * (ratios ** self.alpha)
grad_loss = torch.sum(torch.abs(shared_grads - target_grads))
# 合并总损失
total_loss = weighted_loss + grad_loss
# 权重归一化(保持权重总和为 num_tasks)
self.weights.data = self.weights.data * self.num_tasks / self.weights.data.sum()
return total_loss
# 使用示例
...
# 模型参数和 GradNorm 权重都需要纳入梯度下降范围
criterion = MultiTaskLoss() # forward 方法返回每个 task 的 loss
gn = GradNormLoss(num_of_task=3, shared_parameters=model.linear2.parameters())
optimizer = torch.optim.Adam([
{'params': model.parameters()},
{'params': gn.parameters(), 'lr': 1e-3} # 可单独设置权重学习率,一般为模型学习率的 10 倍
], lr=1e-4)
for epoch in range(100):
for batch in dataloader:
optimizer.zero_grad()
# 前向传播
multi_pred_logits = model(batch)
losses = criterion(multi_pred_logits, multi_labels)
total_loss = gn(losses, training=True)
# 反向传播
total_loss.backward(retain_graph=True)
# 梯度裁剪,可选
nn.utils.clip_grad_norm_(list(model.parameters()) + list(gn.parameters()), 1.0)
# 更新参数
optimizer.step()
根据任务的学习速度调整权重:
- 学习速度慢的任务(损失下降慢)分配更高权重
- 参考前一时段的损失比值
公式:
$$
w_i^{(t)} = \frac{N \exp(r_i^{(t)} / T)}{\sum_j \exp(r_j^{(t)} / T)} \quad \text{(} r_i^{(t)} = \frac{\mathcal{L}_i^{(t-1)}}{\mathcal{L}_i^{(t-2)}} \text{)}
$$
其中
代码实现:
class DWA:
def __init__(self, num_tasks, temp=2.0):
self.num_tasks = num_tasks
self.temp = temp
self.prev_losses = None # 保存前两个时间步的损失
def compute_weights(self, current_losses):
if self.prev_losses is None: # 初始化为等权重
self.prev_losses = current_losses.detach()
return torch.ones(self.num_tasks)
# 计算各任务的损失变化速率 r_i
rates = current_losses.detach() / self.prev_losses.detach()
self.prev_losses = current_losses.detach()
# Softmax 分配权重
weights = F.softmax(rates / self.temp, dim=0) * self.num_tasks
return weights
# 使用示例
...
criterion = MultiTaskLoss() # forward 方法返回每个 task 的 loss
dwa = DWA(num_tasks=3)
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(100):
for batch in dataloader:
optimizer.zero_grad()
# 前向传播
multi_pred_logits = model(batch)
losses = criterion(multi_pred_logits, multi_labels)
weights = dwa.compute_weights(losses)
total_loss = (weights * losses).sum()
# 反向传播
total_loss.backward()
# 更新参数
optimizer.step()