Last active
February 19, 2021 09:58
-
-
Save jbcdnr/50d416dec405d88f576ba497c418b04e to your computer and use it in GitHub Desktop.
Record input and output of any PyTorch layer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
# Quick start | |
wget -O recorder.py https://gist.github.com/jbcdnr/50d416dec405d88f576ba497c418b04e/raw/ | |
# Small example | |
from recorder import trace_layers | |
import torchvision | |
resnet = torchvision.models.resnet18() | |
with trace_layers(resnet, "layer1[0].conv1", "layer1[0].conv2") as recorded_values: | |
resnet(torch.randn(1, 3, 224, 224)) | |
recorded_values["layer1[0].conv1"].output | |
""" | |
from contextlib import contextmanager | |
from collections import namedtuple, OrderedDict | |
import torch.nn as nn | |
# One observed call of a wrapped layer: what it was called with and what it returned.
Record = namedtuple("Record", ["input", "output"])


class RecorderLayer(nn.Module):
    """``nn.Module`` wrapper that reports every call made to another layer.

    Args:
        original_layer: the callable (normally an ``nn.Module``) being observed.
        record_callback: invoked once per forward pass with a ``Record`` whose
            ``input`` is the ``(args, kwargs)`` pair the layer was called with
            and whose ``output`` is the wrapped layer's return value.
    """

    def __init__(self, original_layer, record_callback):
        super().__init__()
        self.original_layer = original_layer
        self.record_callback = record_callback

    def forward(self, *args, **kwargs):
        """Call the wrapped layer unchanged, report the call, and return its result."""
        result = self.original_layer(*args, **kwargs)
        # Report only after the call succeeds, so failed calls are never recorded.
        self.record_callback(Record(input=(args, kwargs), output=result))
        return result
def get_layer(model, name):
    """Return the object reached from *model* by the attribute/index path *name*.

    The path is appended to ``model.`` and evaluated, so ``"layer1[0].conv1"``
    resolves to ``model.layer1[0].conv1``.

    NOTE(review): *name* is evaluated as Python code; never pass a layer name
    that comes from an untrusted source (config file, CLI, network).
    """
    expression = "model.{}".format(name)
    return eval(expression)
def set_layer(model, name, new_layer):
    """Assign *new_layer* to the attribute/index path *name* under *model*.

    Executes ``model.<name> = new_layer``, so ``"layer1[0].conv1"`` assigns to
    ``model.layer1[0].conv1``. Returns ``None``.

    NOTE(review): *name* is executed as Python code; never pass a layer name
    that comes from an untrusted source (config file, CLI, network).
    """
    statement = "model.{} = new_layer".format(name)
    exec(statement)
@contextmanager
def trace_layers(model, *layer_names):
    """Record the call of each named sub-layer of *model* inside a ``with`` block.

    On entry, every named layer is wrapped in a ``RecorderLayer``. On exit,
    whether normal or via an exception, exactly the layers that were wrapped are
    put back, in reverse order of wrapping.

    Args:
        model: object whose sub-layers are traced; each name is resolved
            relative to it with ``get_layer`` / ``set_layer``.
        *layer_names: attribute/index paths such as ``"layer1[0].conv1"``.
            If one path is nested inside another (e.g. ``"layer1[0].conv1"``
            and ``"layer1[0]"``), list the inner one first; once the outer
            layer is wrapped, the inner path would have to resolve through the
            wrapper.

    Yields:
        An ``OrderedDict`` keyed by layer name, in the order given. Each value
        is ``None`` until that layer runs, then the ``Record`` reported by its
        ``RecorderLayer``.

    Raises:
        AssertionError: if a traced layer is called more than once inside the
            ``with`` block.
    """
    recorded_values = OrderedDict((name, None) for name in layer_names)

    def record_callback(name):
        def callback(value):
            # Raise explicitly rather than using ``assert`` so the check also
            # holds under ``python -O``.
            if recorded_values[name] is not None:
                raise AssertionError(f"'{name}' already recorded.")
            recorded_values[name] = value

        return callback

    # (name, original layer) for every layer actually replaced so far. Cleanup
    # uses this list, so a failure partway through replacement restores only
    # what was changed and does not mask the original exception.
    replaced = []
    try:
        for layer_name in layer_names:
            layer = get_layer(model, layer_name)
            set_layer(model, layer_name, RecorderLayer(layer, record_callback(layer_name)))
            replaced.append((layer_name, layer))
        yield recorded_values
    finally:
        # Undo in LIFO order. Duplicate names and nested paths then unwind
        # correctly, because each wrapper is removed before the one it wraps.
        for layer_name, original_layer in reversed(replaced):
            set_layer(model, layer_name, original_layer)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment