Created November 7, 2023 17:19
-
-
Save yejingxin/ee071b39b04e0ef5fdc4131034f65c75 to your computer and use it in GitHub Desktop.
Debug checkpoint (ckpt) loading time
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
from absl import logging | |
import time | |
import checkpoint_utils | |
import time | |
import orbax.checkpoint as ocp | |
class TestStringMethods(unittest.TestCase):
    """Time each phase of restoring the newest checkpoint in a directory.

    Despite the class/method names (kept so existing test selectors keep
    working), this is a manual timing probe for checkpoint loading: it
    creates a checkpoint manager, lists the saved steps, restores the newest
    one, and prints how long each phase took.
    """

    # Default checkpoint location. Override it without editing the file by
    # setting the CKPT_DIR environment variable, e.g.
    #   CKPT_DIR=gs://yejingxin-us-central2/rl/102201/ckpt python <this file>
    DEFAULT_CKPT_DIR = "gs://yejingxin-us-central2-public/wendy/ckpt2"

    def test_upper(self):
        """Create a manager, list steps, and restore the latest step, timing each.

        Fails with an explicit message when the directory holds no
        checkpoints, rather than trying to restore step 0, which the step
        listing has just shown does not exist.
        """
        import os

        ckpt_dir = os.environ.get("CKPT_DIR", self.DEFAULT_CKPT_DIR)
        print(f'ckpt dir {ckpt_dir}')

        # perf_counter is monotonic and high-resolution, so the measured
        # durations cannot be skewed by wall-clock adjustments the way
        # time.time() differences can.
        start = time.perf_counter()
        _ckpt_manager = checkpoint_utils.get_ckpt_manager(
            ckpt_dir, save_interval_steps=100, create=False, use_async=False)
        print(f'manager creation spent {time.perf_counter() - start}s')

        start = time.perf_counter()
        all_steps = _ckpt_manager.all_steps(read=True)
        print(f'find all ckpts spent {time.perf_counter() - start}s')
        if not all_steps:
            self.fail(f'no checkpoints found in {ckpt_dir}')

        start = time.perf_counter()
        restored = _ckpt_manager.restore(max(all_steps))
        # Indexing checks that the restored object contains
        # ['save_state']['state']; a missing key raises and fails the test.
        _ = restored['save_state']['state']
        print(f'restore ckpt spent {time.perf_counter() - start}s')
        print('loading successfully')
if __name__ == '__main__':
    # Start unittest's command-line runner only when this file is executed
    # directly, so merely importing the module does not launch the probe.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.