Skip to content

Instantly share code, notes, and snippets.

@yejingxin
Created October 25, 2024 22:59
Show Gist options
  • Save yejingxin/890124ede1917b2be62792e640d2e8b9 to your computer and use it in GitHub Desktop.
Save yejingxin/890124ede1917b2be62792e640d2e8b9 to your computer and use it in GitHub Desktop.
PyTorch Autocheckpoint with SIGTERM
"""
$ python3 autocheckpoint.py
train step=1
train step=2
train step=3
train step=4
train step=5
train step=6
train step=7
train step=8
Received SIGTERM. Termination requested.
Termination requested. Saving checkpoint...
emergency Checkpoint saved. Exiting.
checked saved ckpt:
ckpt load at recover_step=8
model state: expected to be 36
state_dict.get("model")=tensor([[36., 36.],
[36., 36.]])
Exiting.
"""
import signal
import torch
import threading
import time
# Flag shared between the SIGTERM handler (which sets it) and the training
# loop (which polls it after each step to decide whether to checkpoint and exit).
terminate_requested = threading.Event()
def save_checkpoint(model, step, ckpt_path='emergency.ckpt'):
    """Save an emergency checkpoint, then reload it and print a sanity check.

    The checkpoint is a dict with two entries: ``'model'`` (the result of
    ``model.cpu()``) and ``'step'``. After saving, the file is loaded back
    and its contents are printed next to the value the demo expects,
    ``step * (step + 1) // 2``.

    Args:
        model: Object to checkpoint; must provide ``.cpu()``. The script
            passes a torch tensor.
        step: Training step the checkpoint corresponds to.
        ckpt_path: Destination file. Defaults to ``'emergency.ckpt'``.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    import os

    # Move tensors to CPU and save. The data goes to a sibling temp file
    # first and is renamed into place afterwards: if the process is killed
    # while writing (e.g. the SIGTERM grace period runs out), an existing
    # checkpoint at ckpt_path is left untouched instead of being truncated.
    tmp_path = f'{ckpt_path}.tmp'
    torch.save({'model': model.cpu(), 'step': step}, tmp_path)
    os.replace(tmp_path, ckpt_path)
    print("emergency Checkpoint saved. Exiting.")

    # Reload the file that was just written to check it is readable.
    state_dict = torch.load(ckpt_path)
    recover_step = state_dict.get('step')
    print('checked saved ckpt:')
    print(f'ckpt load at {recover_step=}')
    print(f'model state: expected to be {recover_step * (recover_step + 1) // 2}')
    print(f'{state_dict.get("model")=}')
def handle_sigterm(signum, frame):
    """Signal handler that records a termination request.

    Sets ``terminate_requested`` so the training loop can checkpoint and
    exit at its next check, and prints a notice.

    Args:
        signum: Signal number passed by the ``signal`` module (unused).
        frame: Interrupted stack frame passed by the ``signal`` module
            (unused).
    """
    # Record the request before doing any I/O, so it cannot be lost if the
    # print below fails.
    terminate_requested.set()
    try:
        print("Received SIGTERM. Termination requested.")
    except RuntimeError:
        # Buffered stdout is not reentrant: if the signal interrupts another
        # print(), writing here raises RuntimeError. The message is purely
        # informational, so drop it rather than let the exception escape
        # into the interrupted training step.
        pass
# Route SIGTERM to handle_sigterm so a termination request only sets the
# flag; the training loop then decides when to checkpoint and exit.
signal.signal(signal.SIGTERM, handle_sigterm)
def model_train(step, model, delay=2):
    """Run one simulated training step.

    Prints the step number, adds ``step`` to ``model`` with ``+=``, then
    sleeps to stand in for real compute time.

    Args:
        step: Current step number; also the amount added to ``model``.
        model: Accumulator updated with ``+=``. The script passes a torch
            tensor, for which ``+=`` updates the tensor in place so the
            caller sees the change.
        delay: Seconds to sleep after the update. Defaults to 2, the
            duration that used to be hard-coded; pass 0 to skip the wait.
    """
    print(f'train {step=}')
    model += step
    time.sleep(delay)
# Use the GPU when one is available; otherwise fall back to the CPU so the
# demo still runs on hosts without CUDA instead of failing at startup.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.zeros((2, 2)).to(device)

# Training loop: run one step, then check whether SIGTERM arrived meanwhile.
for step in range(1, 100):
    model_train(step, model)
    if terminate_requested.is_set():
        print("Termination requested. Saving checkpoint...")
        save_checkpoint(model, step)
        print("Exiting.")
        # exit() is a helper injected by the site module for interactive
        # use and may be absent; raising SystemExit is the equivalent that
        # always works.
        raise SystemExit(0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment