Logging other numbers in TensorBoard for MMF.
```python
# in mmf/mmf/trainers/mmf_trainer.py

# add the line below
from user import all_user_callbacks


@registry.register_trainer("mmf")
class MMFTrainer(
    TrainerCallbackHookMixin,
    TrainerTrainingLoopMixin,
    TrainerDeviceMixin,
    TrainerEvaluationLoopMixin,
    TrainerProfilingMixin,
    BaseTrainer,
):
    def __init__(self, config: DictConfig):
        super().__init__(config)

    def configure_callbacks(self):
        ...
        # user callbacks: add the two lines below  # added by @Jie
        for callback_cls in all_user_callbacks:
            self.callbacks.append(callback_cls(self.config, self))
```
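The registration loop above assumes only two things about each user callback: its constructor takes `(config, trainer)`, and it overrides whichever hooks it needs. A minimal sketch of that contract, using the two hook names that the `user.py` file below actually overrides (`on_update_start` / `on_update_end`; MMF's `Callback` base provides no-op defaults); the class name here is hypothetical:

```python
from mmf.trainers.callbacks.base import Callback


class MyCallback(Callback):
    """Minimal sketch of a user callback."""

    def __init__(self, config, trainer):
        # The registration loop calls callback_cls(self.config, self),
        # so this exact signature is required.
        super().__init__(config, trainer)

    def on_update_start(self, **kwargs):
        pass  # runs before each parameter update

    def on_update_end(self, **kwargs):
        pass  # runs after each parameter update
```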
```python
# user.py
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import math

import torch

from mmf.trainers.callbacks.base import Callback
from mmf.utils.configuration import get_mmf_env
from mmf.utils.general import print_model_parameters
from mmf.utils.logger import (
    TensorboardLogger,
    calculate_time_left,
    setup_output_folder,
    summarize_report,
)
from mmf.utils.timer import Timer

logger = logging.getLogger(__name__)


class VseInfCallback(Callback):
    """A user-customized callback class for VSE Inf training."""

    def __init__(self, config, trainer):
        """
        Attr:
            config (mmf_typings.DictConfig): Config for the callback
            trainer (Type[BaseTrainer]): Trainer object
        """
        super().__init__(config, trainer)
        if hasattr(config.training, "vse_inf_training"):
            vse_cfg = config.training.vse_inf_training
            self.enable = vse_cfg.enable
            self.warmup_steps = vse_cfg.warmup_steps
            # freeze_img_backbone and max_violation only take effect when enable is True
            self.freeze_img_backbone = vse_cfg.freeze_img_backbone  # default True
            self.freeze_blocks_after_warmup = vse_cfg.freeze_blocks_after_warmup  # default 0
            self.max_violation = vse_cfg.max_violation  # default False
            logger.info(f"VseInfCallback: {repr(vse_cfg)}")
        else:
            self.enable = False

    def on_update_start(self, **kwargs):
        if self.enable:
            self.apply_vse_training_scheme()

    def apply_vse_training_scheme(self):
        model = self.trainer.model
        if hasattr(model, "module"):  # unwrap DDP
            model = model.module
        current_iteration = self.trainer.current_iteration
        verbose = current_iteration in [0, self.warmup_steps]
        # loss
        model.set_max_violation(True, verbose=verbose)
        if current_iteration < self.warmup_steps:
            # reset for the warmup phase; self.max_violation is False by default
            model.set_max_violation(self.max_violation, verbose=verbose)
        # freeze/unfreeze parameters
        if current_iteration < self.warmup_steps and self.freeze_img_backbone:
            model.freeze_img_backbone(verbose)  # freeze all
        else:
            # 0 == unfreeze all
            model.unfreeze_img_backbone(self.freeze_blocks_after_warmup, verbose)
        if verbose:
            print_model_parameters(model)


class TensorBoardLoggingCallback(Callback):
    """A customized callback class for logging extra numbers."""

    def __init__(self, config, trainer):
        """
        Attr:
            config (mmf_typings.DictConfig): Config for the callback
            trainer (Type[BaseTrainer]): Trainer object
        """
        super().__init__(config, trainer)
        logger.info(f"{str(self)}: {repr(config.training.vse_inf_training)}")

    def on_update_end(self, **kwargs):
        trainer = self.trainer
        num_updates = trainer.num_updates
        tb_writer = trainer.logistics_callback.tb_writer
        # log any other numbers here
        if num_updates % trainer.logistics_callback.log_interval == 0:
            # learning rate
            lr = trainer.optimizer.param_groups[0]["lr"]
            tb_writer.add_scalar("train/lr", lr, num_updates)
            # trainable temperature
            _model = trainer.model
            _model = _model.module if hasattr(_model, "module") else _model
            if hasattr(_model, "learned_temperature"):
                t = float(_model.learned_temperature.cpu())
                t = 1.0 / math.exp(-t)  # i.e. exp(t), as in CLIP's implementation
                tb_writer.add_scalar("train/temperature", t, num_updates)


all_user_callbacks = [VseInfCallback, TensorBoardLoggingCallback]
```
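`VseInfCallback` only activates when `config.training.vse_inf_training` exists. As a quick illustration, here is what such a config block could look like as a `DictConfig` built with OmegaConf (matching the `DictConfig` type hint above); the field names are the ones the callback reads, while the values are just the defaults mentioned in its comments, with a hypothetical `warmup_steps`:

```python
from omegaconf import OmegaConf

# Illustrative sketch only: field names match what VseInfCallback reads;
# values mirror the defaults noted in the comments above.
config = OmegaConf.create(
    {
        "training": {
            "vse_inf_training": {
                "enable": True,
                "warmup_steps": 1000,             # hypothetical value
                "freeze_img_backbone": True,      # default True
                "freeze_blocks_after_warmup": 0,  # default 0 (unfreeze all)
                "max_violation": False,           # default False
            }
        }
    }
)
assert config.training.vse_inf_training.warmup_steps == 1000
```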
Step 1: Create a new file `user.py` and write a new callback class `TensorBoardLoggingCallback`. This class will have access to `trainer` (via Step 4), and thus to TensorBoard via `self.trainer.logistics_callback.tb_writer`. With it, we can log numbers easily. For example, in my use case, I logged the learning rate and one of my model parameters, `learned_temperature`. (A sketch of one more callback following the same pattern is shown after these steps.)

Step 2: Add the callback to the `all_user_callbacks` list.

Step 3: Import `all_user_callbacks` in the `mmf/trainers/mmf_trainer.py` file.

Step 4: Add two lines in the `configure_callbacks` method, as in https://gist.github.com/jayleicn/c19123454f9d310ed6029115b3dc1ac5#file-mmf_trainer-py-L20-L22. This will register the callbacks into the trainer. Done!
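As promised above, here is a hypothetical extra callback (not part of the gist) that logs one more number, the global gradient norm, at the same interval. Note that whether gradients are still populated when `on_update_end` fires depends on when the trainer zeroes them, so treat this as a sketch:

```python
import torch

from mmf.trainers.callbacks.base import Callback


class GradNormLoggingCallback(Callback):
    """Hypothetical example following the same pattern as
    TensorBoardLoggingCallback: logs the global gradient norm."""

    def on_update_end(self, **kwargs):
        trainer = self.trainer
        num_updates = trainer.num_updates
        if num_updates % trainer.logistics_callback.log_interval == 0:
            # collect per-parameter gradient norms (if grads are populated)
            grads = [
                p.grad.detach().norm()
                for p in trainer.model.parameters()
                if p.grad is not None
            ]
            if grads:
                total_norm = float(torch.norm(torch.stack(grads)))
                trainer.logistics_callback.tb_writer.add_scalar(
                    "train/grad_norm", total_norm, num_updates
                )


# then append it to the list in user.py:
# all_user_callbacks = [VseInfCallback, TensorBoardLoggingCallback, GradNormLoggingCallback]
```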
BTW, using this logic, you can also do other customizations, like controlling which parts of the model to freeze at different training steps, as in this callback: https://gist.github.com/jayleicn/c19123454f9d310ed6029115b3dc1ac5#file-user-py-L21. A sketch of what such freeze/unfreeze model methods might look like follows.
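The gist's model is expected to implement `freeze_img_backbone` and `unfreeze_img_backbone`, but their bodies are not shown. A minimal sketch of what that pair could look like, assuming a backbone made of sequential blocks (the class, its `__init__`, and the backbone layout here are all hypothetical; only the "0 == unfreeze all" semantics come from the gist's comments):

```python
import torch.nn as nn


class MyVseModel(nn.Module):
    """Hypothetical stand-in for the gist's model."""

    def __init__(self):
        super().__init__()
        # stand-in backbone: an iterable of blocks
        self.img_backbone = nn.Sequential(
            nn.Conv2d(3, 16, 3), nn.Conv2d(16, 32, 3)
        )

    def freeze_img_backbone(self, verbose=False):
        # freezing just means disabling gradients on the backbone
        for p in self.img_backbone.parameters():
            p.requires_grad = False

    def unfreeze_img_backbone(self, n_frozen_blocks=0, verbose=False):
        # n_frozen_blocks == 0 unfreezes everything; otherwise the first
        # n blocks stay frozen
        for i, block in enumerate(self.img_backbone.children()):
            for p in block.parameters():
                p.requires_grad = i >= n_frozen_blocks
```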