Created
October 25, 2024 22:59
-
-
Save yejingxin/890124ede1917b2be62792e640d2e8b9 to your computer and use it in GitHub Desktop.
PyTorch Autocheckpoint with SIGTERM
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
$ python3 autocheckpoint.py
train step=1
train step=2
train step=3
train step=4
train step=5
train step=6
train step=7
train step=8
Received SIGTERM. Termination requested.
Termination requested. Saving checkpoint...
emergency Checkpoint saved. Exiting.
checked saved ckpt:
ckpt load at recover_step=8
model state: expected to be 36
state_dict.get("model")=tensor([[36., 36.],
        [36., 36.]])
Exiting.
""" | |
import os
import signal
import sys
import threading
import time

import torch
# Flag set by the SIGTERM handler and polled by the training loop between
# steps; an Event lets the handler do nothing more than flip a flag.
terminate_requested = threading.Event()
def save_checkpoint(model, step, ckpt_path='emergency.ckpt'):
    """Save an emergency checkpoint, then reload it and print a sanity check.

    Args:
        model: Object exposing ``.cpu()`` (a ``torch.Tensor`` in this script);
            its CPU copy is stored under the ``'model'`` key.
        step: Integer training step, stored under the ``'step'`` key.
        ckpt_path: Destination file for the checkpoint. Defaults to
            ``'emergency.ckpt'`` in the current working directory.
    """
    # Write to a sibling temp file and rename it into place, so a kill that
    # arrives mid-write cannot leave a truncated file at ckpt_path.
    tmp_path = f'{ckpt_path}.tmp'
    torch.save({'model': model.cpu(), 'step': step}, tmp_path)
    os.replace(tmp_path, ckpt_path)
    print("emergency Checkpoint saved. Exiting.")

    # Reload what was just written to confirm the file is readable.
    state_dict = torch.load(ckpt_path)
    recover_step = state_dict.get('step')
    print('checked saved ckpt:')
    print(f'ckpt load at {recover_step=}')
    # The script's training loop adds `step` to a zero tensor for steps
    # 1..n, so every element should equal n * (n + 1) / 2.
    print(f'model state: expected to be {recover_step * (recover_step + 1) // 2}')
    print(f'{state_dict.get("model")=}')
def handle_sigterm(signum, frame):
    """Signal handler that records a termination request.

    Only prints a notice and sets ``terminate_requested``; the checkpoint
    itself is written by the training loop once it sees the flag.

    Args:
        signum: Signal number passed by the ``signal`` module (unused).
        frame: Interrupted stack frame passed by the ``signal`` module
            (unused).
    """
    print("Received SIGTERM. Termination requested.")
    terminate_requested.set()
# Route SIGTERM to handle_sigterm so the training loop gets a chance to
# checkpoint before the process exits.
signal.signal(signal.SIGTERM, handle_sigterm)
def model_train(step, model, delay=2):
    """Run one simulated training step.

    Prints the step, adds ``step`` to ``model`` in place, then sleeps to
    stand in for real compute time.

    Args:
        step: Current step number; also the value added to ``model``.
        model: Object updated via ``+=``. For a ``torch.Tensor`` this
            mutates the caller's tensor in place.
        delay: Seconds to sleep after the update. Defaults to 2.
    """
    print(f'train {step=}')
    # In-place add: the caller's tensor is modified, nothing is returned.
    model += step
    time.sleep(delay)
# Use the GPU when one is present; fall back to CPU so the demo still runs
# on hosts without CUDA instead of raising at tensor creation.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.zeros((2, 2), device=device)

# Training loop
for step in range(1, 100):
    model_train(step, model)
    # The flag is checked only after a step has finished, so the saved
    # state always corresponds to a completed step.
    if terminate_requested.is_set():
        print("Termination requested. Saving checkpoint...")
        save_checkpoint(model, step)
        print("Exiting.")
        # sys.exit rather than the site-provided exit(), which is meant for
        # the interactive shell and may be absent (e.g. under python -S).
        sys.exit(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment