Verify parameter weights & gradients in PyTorch
import torch

# Assumes distributed helpers `gather`, `get_world_size`, and `print0`
# (thin wrappers around torch.distributed) are defined elsewhere.


def verify_ddp_weights_equal(model: torch.nn.Module, atol: float = 1e-5) -> None:
    # Unwrap the DistributedDataParallel container, if present.
    if hasattr(model, "module"):
        model = model.module

    world_size = get_world_size()
    for name, param in model.named_parameters():
        # Gather this parameter from every rank and compare each copy to rank 0's.
        gathered_param = gather(param).reshape((world_size, -1))
        absolute_diffs = (gathered_param[None, 0, :] - gathered_param).abs()
        rank_params_eq = (absolute_diffs < atol).all()
        assert rank_params_eq, f"❌ param [{name}] not equal - got max_absolute_diff={absolute_diffs.max()}"

        # Same check for the gradient; call this after backward() so .grad is populated.
        gathered_param_grad = gather(param.grad).reshape((world_size, -1))
        absolute_grad_diffs = (gathered_param_grad[None, 0, :] - gathered_param_grad).abs()
        rank_grad_params_eq = (absolute_grad_diffs < atol).all()
        assert rank_grad_params_eq, f"❌ param [{name}] grad not equal - got max_absolute_diff={absolute_grad_diffs.max()}"

    print0("Verified DDP parameter correctness ✅")
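
The snippet relies on three helpers that the gist itself does not define: gather, get_world_size, and print0. Below is a minimal sketch of one plausible implementation on top of torch.distributed; the exact behavior (stacking gathered copies along a new leading dim, printing only on rank 0) is an assumption, not the gist's actual code.

import torch
import torch.distributed as dist


def get_world_size() -> int:
    # Number of participating ranks (1 if torch.distributed is not initialized).
    return dist.get_world_size() if dist.is_initialized() else 1


def gather(t: torch.Tensor) -> torch.Tensor:
    # All-gather a tensor from every rank and stack the copies along a new
    # leading dim, giving shape (world_size, *t.shape).
    t = t.detach().contiguous()
    if get_world_size() == 1:
        return t.unsqueeze(0)
    tensors = [torch.empty_like(t) for _ in range(get_world_size())]
    dist.all_gather(tensors, t)
    return torch.stack(tensors, dim=0)


def print0(*args, **kwargs) -> None:
    # Print only on rank 0 to avoid one duplicate log line per process.
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)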
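
A hypothetical usage sketch: since gather is a collective operation, every rank must call verify_ddp_weights_equal, and it should run after backward() (once DDP has all-reduced the gradients and param.grad is populated) but before optimizer.step(). Here model, optimizer, and batch are placeholders, not part of the gist.

# Inside a DDP training step, on every rank:
loss = model(batch).mean()
loss.backward()                   # DDP averages gradients across ranks here
verify_ddp_weights_equal(model)   # raises AssertionError on any rank mismatch
optimizer.step()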