TorchScript MT model
import argparse
import logging

import torch
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.sequence_generator import SequenceGenerator


def get_args():
    parser = argparse.ArgumentParser(
        description="Script to convert a PyTorch model to a TorchScript model."
    )
    parser.add_argument(
        "--input-model", type=str, required=True, help="Path to the PyTorch model."
    )
    parser.add_argument(
        "--output-model",
        type=str,
        required=True,
        help="Path to the output TorchScript model.",
    )
    parser.add_argument(
        "--beam-size",
        type=int,
        required=True,
        help="Beam size for the sequence generator.",
    )
    parser.add_argument("--quantize", action="store_true", help="Apply quantization.")
    return parser.parse_args()


def main():
    args = get_args()
    # Make the progress messages below visible (the default logging level is WARNING).
    logging.basicConfig(level=logging.INFO)
    logging.info("Loading model...")
    # load_model_ensemble_and_task returns (models, args/cfg, task); take the first model.
    model = load_model_ensemble_and_task([args.input_model])[0][0]
    model.eval()
    logging.info("Model loaded.")
    model_dict = model.decoder.dictionary
    # Wrap the model in a SequenceGenerator so beam search is part of the scripted module.
    generator = SequenceGenerator([model], model_dict, beam_size=args.beam_size)
    if args.quantize:
        # Dynamically quantize the linear layers to int8 before scripting.
        generator = torch.quantization.quantize_dynamic(
            generator, {torch.nn.Linear}, dtype=torch.qint8, inplace=True
        )
    logging.info("TorchScripting...")
    scripted_generator = torch.jit.script(generator)
    logging.info("Saving TorchScript model...")
    scripted_generator.save(args.output_model)
    logging.info("Done!")


if __name__ == "__main__":
    main()
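
For reference, a minimal usage sketch: run the script above on a fairseq checkpoint, then reload the exported generator with torch.jit.load, which does not require the original fairseq model class at load time. The script name, file paths, and beam size below are hypothetical example values, not taken from the gist.

# Hypothetical invocation (substitute your own script name and paths):
#   python convert_to_torchscript.py \
#       --input-model checkpoints/model.pt \
#       --output-model scripted_generator.pt \
#       --beam-size 5 \
#       --quantize

import torch

# Reload the serialized TorchScript generator produced by the script above.
generator = torch.jit.load("scripted_generator.pt")
generator.eval()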