This report provides a definitive, actionable guide for migrating the Hierarchical Reasoning Model (HRM) to ROCm, specifically targeting AMD MI300X GPUs. Earlier uncertainties and conditional "if" instructions have been resolved into concrete steps for developers.
## 1. `README.md`

**Current CUDA Dependencies:**

The `README.md` explicitly outlines the installation of CUDA and PyTorch with CUDA support, along with FlashAttention, which is a CUDA-dependent library.
### Prerequisites ⚙️
Ensure PyTorch and CUDA are installed. The repo needs CUDA extensions to be built. If not present, run the following commands:
```shell
# Install CUDA 12.6
CUDA_URL=https://developer.download.nvidia.com/compute/cuda/12.6.3/local_installers/cuda_12.6.3_560.35.05_linux.run
wget -q --show-progress --progress=bar:force:noscroll -O cuda_installer.run $CUDA_URL
sudo sh cuda_installer.run --silent --toolkit --override
export CUDA_HOME=/usr/local/cuda-12.6
# Install PyTorch with CUDA 12.6
PYTORCH_INDEX_URL=https://download.pytorch.org/whl/cu126
pip3 install torch torchvision torchaudio --index-url $PYTORCH_INDEX_URL
# Additional packages for building extensions
pip3 install packaging ninja wheel setuptools setuptools-scm
```

Then install FlashAttention. For Hopper GPUs, install FlashAttention 3:

```shell
git clone [email protected]:Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```

For Ampere or earlier GPUs, install FlashAttention 2:

```shell
pip3 install flash-attn
```
**Required Changes for ROCm on MI300X:**
1. **CUDA Installation Removal:** Remove all instructions related to CUDA installation (`wget` and `sudo sh cuda_installer.run`).
2. **PyTorch for ROCm:** Update the PyTorch installation command to use a ROCm-compatible wheel. MI300X requires a recent ROCm release (ROCm 6.0 or later), so install a PyTorch build published for that ROCm version. The `PYTORCH_INDEX_URL` needs to point to the ROCm build of PyTorch, e.g. `https://download.pytorch.org/whl/rocm6.2` for ROCm 6.2.
3. **FlashAttention Replacement:** Remove the FlashAttention installation instructions. For MI300X, the recommended approach is to use the official PyTorch `torch.nn.functional.scaled_dot_product_attention` (SDPA) which leverages optimized ROCm kernels for FlashAttention-like performance. This is the most stable and officially supported method for MI300X.
* **No separate FlashAttention library installation is required.** The necessary optimizations are integrated directly into PyTorch when built with ROCm support.
4. **Environment Variables:** Change `export CUDA_HOME=/usr/local/cuda-12.6` to `export ROCM_PATH=/opt/rocm-x.y.z` (replace `x.y.z` with the ROCm version installed on your system; `/opt/rocm` is typically a symlink to the active version).
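After installing the ROCm wheels, a quick sanity check (illustrative, not part of the repository) confirms that the ROCm build of PyTorch sees the MI300X; `torch.version.hip` is only populated on ROCm builds:

```python
import torch

# torch.version.hip is a version string on ROCm builds and None on CUDA builds.
print("HIP runtime:", torch.version.hip)

# ROCm GPUs are exposed through the regular "cuda" device type.
print("GPU available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device name:", torch.cuda.get_device_name(0))  # should list the MI300X
```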
## 2. `requirements.txt`
**Current CUDA Dependencies:**
The `requirements.txt` file lists `torch`, but the specific CUDA dependency is handled by the `index-url` during PyTorch installation, as seen in `README.md`. There are no explicit CUDA-specific libraries listed here.
```
torch
adam-atan2
einops
tqdm
coolname
pydantic
argdantic
wandb
omegaconf
hydra-core
huggingface_hub
```
**Required Changes for ROCm on MI300X:**
No direct changes are needed in `requirements.txt` itself. The change will be in how `torch` is installed (as described for `README.md`). Since `torch.nn.functional.scaled_dot_product_attention` is part of PyTorch, no additional entries are needed for FlashAttention.
## 3. `pretrain.py`
**Current CUDA Dependencies:**
This file contains explicit references to CUDA for device placement and distributed training.
* **`torch.device("cuda")` and `.cuda()` calls:** These are used to move tensors and models to the CUDA device.
* **`dist.init_process_group(backend="nccl")`:** NCCL (NVIDIA Collective Communications Library) is a CUDA-specific backend for distributed training.
* **`torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))`:** Sets the CUDA device for distributed training.
* **`map_location="cuda"`:** Used when loading checkpoints, explicitly mapping them to CUDA.
```python
with torch.device("cuda"):
    model: nn.Module = model_cls(model_cfg)
    model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__)  # type: ignore
    if "DISABLE_COMPILE" not in os.environ:
        model = torch.compile(model, dynamic=False)  # type: ignore

# Broadcast parameters from rank 0
if world_size > 1:
    with torch.no_grad():
        for param in list(model.parameters()) + list(model.buffers()):
            dist.broadcast(param, src=0)

# To device
batch = {k: v.cuda() for k, v in batch.items()}

# Init carry if it is None
if train_state.carry is None:
    with torch.device("cuda"):
        train_state.carry = train_state.model.initial_carry(batch)  # type: ignore

# Allreduce
if world_size > 1:
    for param in train_state.model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad)

# Initialize distributed training if in distributed environment (e.g. torchrun)
if "LOCAL_RANK" in os.environ:
    # Initialize distributed, default device and dtype
    dist.init_process_group(backend="nccl")

    RANK = dist.get_rank()
    WORLD_SIZE = dist.get_world_size()

    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
```
**Required Changes for ROCm on MI300X:**

Modern PyTorch with ROCm support is designed to be largely source-compatible with existing CUDA code: the ROCm build exposes the GPU through the familiar `torch.cuda` and `"cuda"` device APIs. As a result, most of this file runs on MI300X without modification.

* **Device Placement:** No changes are required for `torch.device("cuda")` or `.cuda()` calls; they automatically target the MI300X GPU on a ROCm build. The ROCm build deliberately reuses the `"cuda"` device type, so there is no need (and no supported path) to switch to a separate `"hip"` device string.
* **Distributed Backend:** `backend="nccl"` can stay as-is. On ROCm, PyTorch's NCCL backend is backed by RCCL (ROCm Collective Communications Library), AMD's equivalent of NCCL, so multi-MI300X collectives use the optimized RCCL kernels without a code change.
* **Set Device:** `torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))` also works unchanged; the `torch.cuda` API is mapped to HIP on ROCm builds.
* **`map_location`:** No changes are required for `map_location="cuda"` when loading checkpoints. It will correctly map to the MI300X GPU.
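To make the "no change needed" point concrete, here is a sketch of the distributed-initialization block from `pretrain.py` as it runs on ROCm, with the unchanged calls annotated (this assumes a ROCm build of PyTorch and a `torchrun` launch):

```python
import os

import torch
import torch.distributed as dist

RANK = 0
WORLD_SIZE = 1

if "LOCAL_RANK" in os.environ:
    # The backend name stays "nccl"; PyTorch's ROCm build dispatches these
    # collectives to RCCL under the hood.
    dist.init_process_group(backend="nccl")

    RANK = dist.get_rank()
    WORLD_SIZE = dist.get_world_size()

    # torch.cuda.* maps to HIP on ROCm builds, so this selects the local MI300X.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
```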
## 4. `evaluate.py`

**Current CUDA Dependencies:**

Similar to `pretrain.py`, this file also contains explicit references to CUDA for distributed training and loading checkpoints.

* **`dist.init_process_group(backend="nccl")`:** NCCL backend for distributed training.
* **`torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))`:** Sets the CUDA device.
* **`map_location="cuda"`:** Used when loading checkpoints.
```python
RANK = 0
WORLD_SIZE = 1

# Initialize distributed training if in distributed environment (e.g. torchrun)
if "LOCAL_RANK" in os.environ:
    # Initialize distributed, default device and dtype
    dist.init_process_group(backend="nccl")

    RANK = dist.get_rank()
    WORLD_SIZE = dist.get_world_size()

    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Models
train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)

# Try unwrap torch.compile
try:
    train_state.model.load_state_dict(torch.load(eval_cfg.checkpoint, map_location="cuda"), assign=True)
except:
    train_state.model.load_state_dict({k.removeprefix("_orig_mod."): v for k, v in torch.load(eval_cfg.checkpoint, map_location="cuda").items()}, assign=True)
```
**Required Changes for ROCm on MI300X:**

As with `pretrain.py`, this code runs on MI300X essentially unchanged:

* **Distributed Backend:** `backend="nccl"` stays as-is; on ROCm it is served by RCCL, giving optimal multi-MI300X collective performance.
* **Set Device:** `torch.cuda.set_device` works unchanged on the ROCm build.
* **`map_location`:** No changes are required for `map_location="cuda"`.
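If per-rank checkpoint placement ever needs to be explicit during evaluation, a small illustrative variant (not required by this guide; the path is hypothetical) maps the checkpoint to the local rank's device; the `"cuda:N"` strings still refer to MI300X GPUs on ROCm:

```python
import os

import torch

local_rank = int(os.environ.get("LOCAL_RANK", 0))
device = torch.device(f"cuda:{local_rank}")  # an MI300X under the ROCm build

# map_location accepts a device object, so "cuda" and "cuda:N" both work unchanged.
state_dict = torch.load("checkpoint.pt", map_location=device)  # hypothetical path
```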
## 5. `models/layers.py`

**Current CUDA Dependencies:**

This file has a direct dependency on `flash_attn_interface` or `flash_attn`, which are CUDA-optimized libraries for attention mechanisms.

```python
try:
    from flash_attn_interface import flash_attn_func  # type: ignore[import]
except ImportError:
    # Fallback to FlashAttention 2
    from flash_attn import flash_attn_func  # type: ignore[import]
```
**Required Changes for ROCm on MI300X:**

* **FlashAttention Replacement:** For MI300X, the `flash_attn_func` calls must be replaced with `torch.nn.functional.scaled_dot_product_attention` (SDPA). This function is integrated into PyTorch and leverages optimized kernels on MI300X, providing FlashAttention-like performance without external dependencies.

**Detailed Steps for Integration within `models/layers.py`:**

1. **Remove `flash_attn` imports:** Delete the `try`/`except` block that imports `flash_attn_func`.

2. **Replace `flash_attn_func` calls with `F.scaled_dot_product_attention`:** The `Attention` class's `forward` method needs to be updated (see the layout note and sketch after this list).

   Original code in `Attention.forward`:

   ```python
   # flash attn
   attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal)
   if isinstance(attn_output, tuple):  # fa2 and fa3 compatibility
       attn_output = attn_output[0]
   ```

   Modified code in `Attention.forward`:

   ```python
   # Use PyTorch's native scaled_dot_product_attention for MI300X.
   # This leverages optimized ROCm kernels for FlashAttention-like performance.
   attn_output = F.scaled_dot_product_attention(
       query, key, value,
       attn_mask=None,        # no explicit mask; causal masking is requested via is_causal
       dropout_p=0.0,         # dropout, if needed, is handled by the model rather than SDPA
       is_causal=self.causal,
   )
   ```

3. **Head Dimension Compatibility:** SDPA dispatches to an appropriate backend for the head dimension in use, for both forward and backward passes, so the `head_dim` limitations of the external FlashAttention builds (e.g., 128 for the backward pass) are not a blocker when using PyTorch's native SDPA on MI300X.

4. **Sliding Window Attention:** If sliding window attention is ever required in the future, `torch.nn.functional.scaled_dot_product_attention` does not support it directly; a custom mask or a library designed for sliding-window attention on ROCm would be necessary. For the current HRM implementation, this is not a concern.
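One practical caveat when making this swap: `flash_attn_func` takes tensors shaped `(batch, seq_len, num_heads, head_dim)`, whereas SDPA expects the head dimension before the sequence dimension, i.e. `(batch, num_heads, seq_len, head_dim)`. The following is a minimal sketch (the helper name is invented) of a layout-aware replacement, assuming the repository's `query`/`key`/`value` are in the flash-attn layout:

```python
import torch
import torch.nn.functional as F


def sdpa_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, causal: bool) -> torch.Tensor:
    """Stand-in for flash_attn_func built on PyTorch SDPA.

    Expects inputs in flash-attn layout (batch, seq_len, num_heads, head_dim)
    and returns the output in that same layout.
    """
    # SDPA wants (batch, num_heads, seq_len, head_dim).
    q = query.transpose(1, 2)
    k = key.transpose(1, 2)
    v = value.transpose(1, 2)

    out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)

    # Transpose back so downstream reshapes in the model are unaffected.
    return out.transpose(1, 2)
```

If the attention ever uses fewer key/value heads than query heads (grouped-query attention), the key and value tensors would additionally need to be expanded to the query head count before this call.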
## 6. HRM Model Architecture

**Current CUDA Dependencies:**

This file defines the core Hierarchical Reasoning Model architecture. While it doesn't directly import `torch.cuda`, it implicitly relies on the underlying PyTorch backend for device operations. The `forward_dtype` is set to `bfloat16`, which is typically used with modern GPUs (including NVIDIA and AMD).
**Required Changes for ROCm on MI300X:**

No direct code changes are required in this file. All `torch.Tensor` operations will automatically leverage the ROCm backend on MI300X when PyTorch is correctly configured.
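If there is any doubt about `bfloat16` support on a given node, a short illustrative check is:

```python
import torch

# MI300X supports bfloat16; this reports what the active build and device offer.
print("bf16 supported:", torch.cuda.is_bf16_supported())

x = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
print((x @ x).dtype)  # torch.bfloat16
```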
## 7. Other Files

Based on the file listing, other files like `models/common.py`, `models/losses.py`, `models/sparse_embedding.py`, `utils/functions.py`, and `dataset/*.py` are primarily Python logic and data handling. They are unlikely to have direct CUDA dependencies and will function correctly with PyTorch configured for ROCm on MI300X.
## Summary

The migration of HRM to ROCm on AMD MI300X GPUs is straightforward and primarily involves updating the PyTorch installation and replacing the explicit FlashAttention calls with PyTorch's native `scaled_dot_product_attention`. The key steps are:
1. **Update `README.md`:**
   - Remove all CUDA installation instructions.
   - Update the PyTorch installation to use ROCm-compatible wheels (e.g., `PYTORCH_INDEX_URL=https://download.pytorch.org/whl/rocm6.2`).
   - Remove the FlashAttention installation instructions.
   - Change `export CUDA_HOME` to `export ROCM_PATH=/opt/rocm-x.y.z`.

2. **Verify `pretrain.py` and `evaluate.py`:**
   - `dist.init_process_group(backend="nccl")` can stay as-is; on ROCm the NCCL backend is backed by RCCL.
   - `torch.cuda.set_device` and `map_location="cuda"` also work unchanged on the ROCm build.

3. **Modify `models/layers.py`:**
   - Remove the `flash_attn` import statements.
   - Replace `flash_attn_func(q=query, k=key, v=value, causal=self.causal)` with `F.scaled_dot_product_attention(query, key, value, is_causal=self.causal)`, minding the tensor-layout note in the `models/layers.py` section.

4. **Testing and Validation:**
   - Thoroughly test the HRM model on MI300X GPUs to ensure correct functionality and performance, including training stability, convergence, and inference speed.
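As part of that validation, a minimal numerical smoke test for the attention swap (illustrative only; the shapes and tolerance are arbitrary) is to compare SDPA output on the MI300X against a plain float32 softmax-attention reference:

```python
import torch
import torch.nn.functional as F


def reference_attention(q, k, v, causal: bool):
    # Naive softmax attention in float32, used purely as a correctness reference.
    scale = q.shape[-1] ** -0.5
    scores = (q.float() @ k.float().transpose(-2, -1)) * scale
    if causal:
        seq = q.shape[-2]
        mask = torch.ones(seq, seq, dtype=torch.bool, device=q.device).triu(1)
        scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v.float()


torch.manual_seed(0)
q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3))

out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_ref = reference_attention(q, k, v, causal=True)

# A large discrepancy here (beyond bfloat16 rounding) points at a backend or layout problem.
print("max abs diff:", (out_sdpa.float() - out_ref).abs().max().item())
```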