Skip to content

Instantly share code, notes, and snippets.

@yejingxin
Last active May 31, 2023 23:26
Show Gist options
  • Save yejingxin/7896608a1c23632bd463f4d0d1d818e6 to your computer and use it in GitHub Desktop.
Save yejingxin/7896608a1c23632bd463f4d0d1d818e6 to your computer and use it in GitHub Desktop.
Soft Slicing
  • Create the TPU VM (v4-32)
# Provision the TPU VM "yejingxin-tpu-v2": v4-32 accelerator, TPU software
# version tpu-vm-v4-base, in zone us-central2-b of project tpu-prod-env-one-vm.
gcloud alpha compute tpus tpu-vm create yejingxin-tpu-v2 \
  --project=tpu-prod-env-one-vm \
  --zone=us-central2-b \
  --accelerator-type=v4-32 \
  --version=tpu-vm-v4-base
  • Install JAX with TPU support on all workers
# Install/upgrade jax[tpu] on every worker (--worker=all). pip's -f
# (--find-links) adds the libtpu releases page as an extra package source.
gcloud compute tpus tpu-vm ssh yejingxin-tpu-v2 \
  --worker=all \
  --zone=us-central2-b \
  --command="pip install --upgrade 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
  • Find which worker IP corresponds to which position in the TPU topology
 gcloud compute tpus tpu-vm ssh yejingxin-tpu-v2 \
 --zone=us-central2-b --worker=all --command="python3 -c 'import jax; import socket; hostname = socket.gethostname(); ip_addr = socket.gethostbyname(hostname); print(hostname, ip_addr, jax.local_devices())' "
t1v-n-3938e5ad-w-2 10.130.0.2 [TpuDevice(id=12, process_index=3, coords=(0,0,3), core_on_chip=0), TpuDevice(id=13, process_index=3, coords=(1,0,3), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(0,1,3), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(1,1,3), core_on_chip=0)]
t1v-n-3938e5ad-w-0 10.130.0.22 [TpuDevice(id=4, process_index=1, coords=(0,0,1), core_on_chip=0), TpuDevice(id=5, process_index=1, coords=(1,0,1), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(0,1,1), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(1,1,1), core_on_chip=0)]
t1v-n-3938e5ad-w-1 10.130.0.28 [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
t1v-n-3938e5ad-w-3 10.130.0.10 [TpuDevice(id=8, process_index=2, coords=(0,0,2), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,0,2), core_on_chip=0), TpuDevice(id=10, process_index=2, coords=(0,1,2), core_on_chip=0), TpuDevice(id=11, process_index=2, coords=(1,1,2), core_on_chip=0)]
  • Test: split into four v4-8 slices (one per worker)
# Worker 0 (10.130.0.22) run on its own: TPU_HOST_BOUNDS=1,1,1 and only this
# host's IP in TPU_WORKER_HOSTNAMES; prints local devices and device_count().
gcloud compute tpus tpu-vm ssh yejingxin-tpu-v2 \
  --worker=0 \
  --zone=us-central2-b \
  --command="TPU_HOST_BOUNDS=1,1,1 TPU_WORKER_HOSTNAMES=10.130.0.22 TPU_CHIPS_PER_HOST_BOUNDS=2,2,1 python3 -c 'import jax; import socket; hostname = socket.gethostname(); ip_addr = socket.gethostbyname(hostname); print(hostname, ip_addr, jax.local_devices(), jax.device_count())' "

t1v-n-3938e5ad-w-0 10.130.0.22 [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)] 4
# Worker 1 (10.130.0.28) run on its own: TPU_HOST_BOUNDS=1,1,1 and only this
# host's IP in TPU_WORKER_HOSTNAMES; prints local devices and device_count().
gcloud compute tpus tpu-vm ssh yejingxin-tpu-v2 \
  --worker=1 \
  --zone=us-central2-b \
  --command="TPU_HOST_BOUNDS=1,1,1 TPU_WORKER_HOSTNAMES=10.130.0.28 TPU_CHIPS_PER_HOST_BOUNDS=2,2,1 python3 -c 'import jax; import socket; hostname = socket.gethostname(); ip_addr = socket.gethostbyname(hostname); print(hostname, ip_addr, jax.local_devices(), jax.device_count())' "

t1v-n-3938e5ad-w-1 10.130.0.28 [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)] 4
# Worker 2 (10.130.0.2) run on its own: TPU_HOST_BOUNDS=1,1,1 and only this
# host's IP in TPU_WORKER_HOSTNAMES; prints local devices and device_count().
gcloud compute tpus tpu-vm ssh yejingxin-tpu-v2 \
  --worker=2 \
  --zone=us-central2-b \
  --command="TPU_HOST_BOUNDS=1,1,1 TPU_WORKER_HOSTNAMES=10.130.0.2 TPU_CHIPS_PER_HOST_BOUNDS=2,2,1 python3 -c 'import jax; import socket; hostname = socket.gethostname(); ip_addr = socket.gethostbyname(hostname); print(hostname, ip_addr, jax.local_devices(), jax.device_count())' "

t1v-n-3938e5ad-w-2 10.130.0.2 [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)] 4
# Worker 3 (10.130.0.10) run on its own: TPU_HOST_BOUNDS=1,1,1 and only this
# host's IP in TPU_WORKER_HOSTNAMES; prints local devices and device_count().
gcloud compute tpus tpu-vm ssh yejingxin-tpu-v2 \
  --worker=3 \
  --zone=us-central2-b \
  --command="TPU_HOST_BOUNDS=1,1,1 TPU_WORKER_HOSTNAMES=10.130.0.10 TPU_CHIPS_PER_HOST_BOUNDS=2,2,1 python3 -c 'import jax; import socket; hostname = socket.gethostname(); ip_addr = socket.gethostbyname(hostname); print(hostname, ip_addr, jax.local_devices(), jax.device_count())' "

t1v-n-3938e5ad-w-3 10.130.0.10 [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)] 4
  • Test: split into two v4-16 slices (the half formed by workers 0 and 1 is shown)
# Form one v4-16 slice from workers 0 and 1. Both hosts get the same
# TPU_HOST_BOUNDS=1,1,2 and the same two-entry TPU_WORKER_HOSTNAMES list.
# Both ssh sessions are started in the background so they run concurrently.
# In the sample output, each process reports jax.device_count() == 8.
# NOTE(review): concurrent start is presumably required so the two processes
# can rendezvous; confirm.
gcloud compute tpus tpu-vm ssh yejingxin-tpu-v2 \
  --zone=us-central2-b --worker=0 \
  --command="TPU_HOST_BOUNDS=1,1,2 TPU_WORKER_HOSTNAMES=10.130.0.22,10.130.0.28 TPU_CHIPS_PER_HOST_BOUNDS=2,2,1 python3 -c 'import jax; import socket; hostname = socket.gethostname(); ip_addr = socket.gethostbyname(hostname); print(hostname, ip_addr, jax.local_devices(), jax.device_count())' " &
pid_w0=$!
gcloud compute tpus tpu-vm ssh yejingxin-tpu-v2 \
  --zone=us-central2-b --worker=1 \
  --command="TPU_HOST_BOUNDS=1,1,2 TPU_WORKER_HOSTNAMES=10.130.0.22,10.130.0.28 TPU_CHIPS_PER_HOST_BOUNDS=2,2,1 python3 -c 'import jax; import socket; hostname = socket.gethostname(); ip_addr = socket.gethostbyname(hostname); print(hostname, ip_addr, jax.local_devices(), jax.device_count())' " &
pid_w1=$!
# Wait on each PID individually: a bare `wait` always returns 0, while
# `wait PID` returns that job's exit status, so a failure on either worker
# is reported instead of silently dropped.
wait "$pid_w0" || printf 'worker 0 ssh session failed\n' >&2
wait "$pid_w1" || printf 'worker 1 ssh session failed\n' >&2

t1v-n-3938e5ad-w-1 10.130.0.28 [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)] 8
t1v-n-3938e5ad-w-0 10.130.0.22 [TpuDevice(id=4, process_index=1, coords=(0,0,1), core_on_chip=0), TpuDevice(id=5, process_index=1, coords=(1,0,1), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(0,1,1), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(1,1,1), core_on_chip=0)] 8

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment