Skip to content

Instantly share code, notes, and snippets.

@yejingxin
Last active April 18, 2024 07:23
Show Gist options
  • Save yejingxin/4b7cb1199a7c3c33dbe96eb52a939814 to your computer and use it in GitHub Desktop.
Save yejingxin/4b7cb1199a7c3c33dbe96eb52a939814 to your computer and use it in GitHub Desktop.
Multislice TPU device count example
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple e2e example running device count on multislice."""
"""
# Dev box setup
ray start --head
DEV_BOX_IP=10.164.0.35
# Set up each slice to join the Ray cluster
# slice 0
gcloud compute tpus tpu-vm ssh yejingxin-tpu-slice0 \
--zone europe-west4-a --worker=all \
--command "pip install 'ray[default]' && .local/bin/ray start --address='$DEV_BOX_IP:6379' --resources='{\"yejingxin-tpu-slice_0\": 1}'" \
# slice 1
gcloud compute tpus tpu-vm ssh yejingxin-tpu-slice1 \
--zone europe-west4-a --worker=all \
--command "pip install 'ray[default]' && .local/bin/ray start --address='$DEV_BOX_IP:6379' --resources='{\"yejingxin-tpu-slice_1\": 1}'" \
# check ray status
# python run_device_count.py
"""
import time
import ray
from ray.job_submission import JobStatus # , JobStatusInfo
from ray.job_submission import JobSubmissionClient
def get_head_worker_ip():
    """Get head worker IP address.

    Placeholder: the lookup is not implemented yet, so this currently
    returns None. The script formats the return value into the coordinator
    address, so a real implementation is expected to return the IP as a
    string.
    """
    # TODO: replace with the actual lookup of the head worker's IP.
    return None
# Submit one "print jax.device_count()" job per (slice, job slot) through the
# Ray job-submission API, then print the job ids and a log-tailing command.
client = JobSubmissionClient("auto")
run_name = "yejingxin-maxtext-llama-032605"  # not referenced below
num_slices = 2  # 2 x v5p-128
jobs_per_slice = 16  # presumably one job per TPU host in a slice; confirm

# NOTE: while get_head_worker_ip() remains a stub returning None, the
# coordinator address below evaluates to "None:8080".
worker0_ip_address = get_head_worker_ip()  # to be defined
mxla_coordinator_address = f"{worker0_ip_address}:8080"


def _runtime_env_for_slice(slice_id):
    """Return a freshly built runtime_env dict for one job on slice `slice_id`.

    A new dict is created on every call so that each submission receives its
    own object, exactly as when the dict literal is written inline.
    """
    megascale_env = {
        "MEGASCALE_COORDINATOR_ADDRESS": mxla_coordinator_address,
        "MEGASCALE_SLICE_ID": str(slice_id),
        "MEGASCALE_NUM_SLICES": str(num_slices),
    }
    verbose_logging_env = {
        "TF_CPP_MIN_LOG_LEVEL": "0",
        "TPU_STDERR_LOG_LEVEL": "0",
        "TPU_MIN_LOG_LEVEL": "0",
    }
    return {
        "working_dir": ".",
        "env_vars": {**megascale_env, **verbose_logging_env},
        # Only "-f" (find-links) entries are listed; no package is named.
        "pip": [
            "-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html",
            "-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html",
            "-f https://storage.googleapis.com/jax-releases/libtpu_releases.html",
        ],
    }


job_ids = []
for slice_index in range(num_slices):
    for _ in range(jobs_per_slice):
        job_ids.append(
            client.submit_job(
                entrypoint="python3 -c 'import jax; print(jax.device_count())'",
                runtime_env=_runtime_env_for_slice(slice_index),
                # Custom resource name registered by `ray start --resources`
                # on the workers of the matching slice.
                entrypoint_resources={
                    f"yejingxin-tpu-slice_{slice_index}": 1
                },
            )
        )

print(job_ids)
print(f"ray job logs -f {job_ids[0]}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment