Skip to content

Instantly share code, notes, and snippets.

@yejingxin
Created November 7, 2023 17:19
Show Gist options
  • Save yejingxin/ee071b39b04e0ef5fdc4131034f65c75 to your computer and use it in GitHub Desktop.
Save yejingxin/ee071b39b04e0ef5fdc4131034f65c75 to your computer and use it in GitHub Desktop.
Debug ckpt loading time
import unittest
from absl import logging
import time
import checkpoint_utils
import time
import orbax.checkpoint as ocp
class TestStringMethods(unittest.TestCase):
    """Timing probe for checkpoint loading.

    Times three phases separately -- checkpoint-manager creation, listing the
    available steps, and restoring one step -- and prints the elapsed seconds
    for each so a slow load can be attributed to a specific phase.

    NOTE(review): the class and method names look like leftovers from the
    standard ``unittest`` example; they are kept so existing test selectors
    keep working.
    """

    def test_upper(self):
        """Restore the newest checkpoint under ``ckpt_dir`` and print timings.

        Intervals are measured with ``time.perf_counter()``, a monotonic
        clock, so the reported durations are not distorted by system clock
        adjustments while the test runs.

        The test passes when every call returns without raising and the
        restored object can be indexed as ``['save_state']['state']``.
        """
        ckpt_dir = "gs://yejingxin-us-central2-public/wendy/ckpt2"
        # Alternative checkpoint location used previously:
        # "gs://yejingxin-us-central2/rl/102201/ckpt"

        # Phase 1: build the checkpoint manager (synchronous, no directory
        # creation requested).
        starttime = time.perf_counter()
        print(f'ckpt dir {ckpt_dir}')
        _ckpt_manager = checkpoint_utils.get_ckpt_manager(
            ckpt_dir,
            save_interval_steps=100,
            create=False,
            use_async=False,
        )
        print(f'manager creation spent {time.perf_counter() - starttime}s')

        # Phase 2: enumerate the steps known to the manager.
        # NOTE(review): read=True presumably forces a fresh listing from
        # storage rather than a cached one -- confirm against the manager API.
        starttime = time.perf_counter()
        all_steps = _ckpt_manager.all_steps(read=True)
        print(f'find all ckpts spent {time.perf_counter() - starttime}s')

        # Phase 3: restore the newest step; fall back to step 0 when the
        # manager reports no steps at all.
        starttime = time.perf_counter()
        latest_step = max(all_steps) if all_steps else 0
        restored = _ckpt_manager.restore(latest_step)
        # Index into the restored tree so the test fails loudly when the
        # expected 'save_state' / 'state' entries are missing.
        _ = restored['save_state']['state']
        print(f'restore ckpt takes spent {time.perf_counter() - starttime}s')
        print('loading successfully')
# Script entry point: discover and run the test cases defined in this module
# when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment