Created
February 6, 2025 04:34
-
-
Save youkaichao/9d92875e584f3161001944c3348d1ce2 to your computer and use it in GitHub Desktop.
gloo vs. nccl
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.distributed as dist | |
# Backend under test: True benchmarks NCCL (GPU tensors), False benchmarks Gloo
# (CPU tensors).
use_nccl = False

# Rank and world size are taken from the environment set up by the launcher
# (e.g. torchrun).
dist.init_process_group(backend="nccl" if use_nccl else "gloo")
rank = dist.get_rank()

# Bind this process to a GPU only when CUDA is actually present, and wrap the
# rank by the number of visible devices rather than assuming an 8-GPU node.
# This keeps the Gloo path usable on CPU-only hosts and avoids an invalid
# device ordinal on machines with fewer than 8 GPUs.
if torch.cuda.is_available():
    torch.cuda.set_device(rank % torch.cuda.device_count())
def try_nccl():
    """Run one MAX all-reduce of a scalar zero held on the current CUDA device.

    Returns:
        The reduced value as a Python int. Calling ``.item()`` copies the
        result back to the host, so the call does not return until the
        collective has finished.
    """
    # Build a fresh 0-dim int64 tensor on the GPU for every call so that the
    # host-to-device transfer is part of the measured work.
    flag = torch.tensor(0, dtype=torch.long, device="cuda")
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return flag.item()
def try_gloo():
    """Run one MAX all-reduce of a scalar zero held in CPU memory.

    Returns:
        The reduced value as a Python int.
    """
    # Build a fresh 0-dim int64 CPU tensor for every call; the all-reduce
    # updates it in place.
    flag = torch.tensor(0, dtype=torch.long, device="cpu")
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return flag.item()
# Benchmark driver: time repeated scalar all-reduces on the selected backend
# and report the mean latency per call.
import time

# Pick the benchmark body that matches the backend chosen at init time.
func = try_nccl if use_nccl else try_gloo

WARMUP_ITERS = 10
TIMED_ITERS = 1000

# warm up: the first collectives pay one-time setup costs that should not be
# included in the measurement
for _ in range(WARMUP_ITERS):
    func()

# measure the time with a monotonic, high-resolution clock; time.time() can
# jump if the system clock is adjusted during the run
start = time.perf_counter()
for _ in range(TIMED_ITERS):
    func()
end = time.perf_counter()

per_iter = (end - start) / TIMED_ITERS * 1000  # seconds per call -> milliseconds
print(f"per iteration time: {per_iter:.2f} ms")

# Tear the process group down explicitly so backend resources are released
# before interpreter shutdown.
dist.destroy_process_group()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
test command:
torchrun --nproc-per-node=8 test.py
gloo:
nccl: