Created
January 18, 2024 02:33
-
-
Save Fhrozen/5cc0366fb40fc08f7358d5287ed69435 to your computer and use it in GitHub Desktop.
Script for tracing TDNN from speechbrain
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import numpy as np
import torch
import torchaudio
from matplotlib import pyplot as plt
from speechbrain.inference.classifiers import EncoderClassifier
from speechbrain.lobes.features import Fbank
from speechbrain.lobes.models.Xvector import Xvector
from speechbrain.processing.features import InputNormalization
from torch import nn
class Extractor(nn.Module):
    """Stand-alone x-vector embedding extractor built from SpeechBrain modules.

    The pipeline is: filterbank features -> sentence-level mean
    normalization -> ``Xvector`` embedding model -> optional global mean
    normalization of the embeddings.

    The ``Xvector`` hyper-parameters are hard-coded in ``__init__``
    (presumably to match the checkpoint stored in ``model_path`` -- verify
    against the hyperparams file shipped with the pretrained model).
    """

    # Attribute names of the sub-modules. Each name doubles as the stem of
    # the checkpoint file ("<name>.ckpt") looked up inside ``model_path``.
    model_dict = [
        "mean_var_norm",
        "compute_features",
        "embedding_model",
        "mean_var_norm_emb",
    ]

    def __init__(self, model_path, n_mels=24, device="cpu"):
        """Build the sub-modules and load whatever checkpoints are available.

        Args:
            model_path: Directory searched for ``<name>.ckpt`` files, one per
                entry of ``model_dict``.
            n_mels: Number of mel filters; also used as the input channel
                count of the embedding model.
            device: Device the sub-modules and the inputs of ``forward`` are
                moved to.
        """
        super().__init__()
        self.device = device
        self.compute_features = Fbank(n_mels=n_mels)
        self.mean_var_norm = InputNormalization(
            norm_type="sentence", std_norm=False
        )
        self.embedding_model = Xvector(
            in_channels=n_mels,
            activation=torch.nn.LeakyReLU,
            tdnn_blocks=5,
            tdnn_channels=[512, 512, 512, 512, 1500],
            tdnn_kernel_sizes=[5, 3, 3, 1, 1],
            tdnn_dilations=[1, 2, 3, 1, 1],
            lin_neurons=512,
        )
        self.mean_var_norm_emb = InputNormalization(
            norm_type="global", std_norm=False
        )
        self._load_checkpoints(model_path)

    def _load_checkpoints(self, model_path):
        """Load ``<name>.ckpt`` for every sub-module and move it to the device.

        A sub-module without a checkpoint file keeps its freshly initialised
        state; it is still moved to ``self.device``.
        """
        for mod_name in self.model_dict:
            filename = os.path.join(model_path, f"{mod_name}.ckpt")
            module = getattr(self, mod_name)
            if os.path.exists(filename):
                if hasattr(module, "_load"):
                    # Modules exposing their own ``_load`` hook restore
                    # themselves from the file.
                    print(f"Load: {filename}")
                    module._load(filename)
                else:
                    print(f"Load State Dict: {filename}")
                    # map_location keeps a checkpoint saved from CUDA tensors
                    # loadable on a host where only the CPU is available.
                    state = torch.load(filename, map_location=self.device)
                    module.load_state_dict(state)
            module.to(self.device)

    @torch.no_grad()
    def forward(self, wavs, wav_lens=None, normalize=False):
        """Compute embeddings for a waveform or a batch of waveforms.

        Args:
            wavs: Waveform tensor. A 1-D tensor is treated as a single
                waveform and gets a leading batch dimension.
            wav_lens: One length value per waveform, forwarded to the
                normalizer and the embedding model. Defaults to all ones,
                i.e. every waveform is considered full length.
            normalize: When true, the embeddings are additionally passed
                through ``mean_var_norm_emb``.

        Returns:
            The output of the embedding model (normalized if requested).
        """
        # Manage single waveforms in input
        if wavs.ndim == 1:
            wavs = wavs.unsqueeze(0)

        # Assign full length if wav_lens is not assigned
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # Storing waveform in the specified device
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()

        # Computing features and embeddings
        feats = self.compute_features(wavs)
        feats = self.mean_var_norm(feats, wav_lens)
        embeddings = self.embedding_model(feats, wav_lens)
        if normalize:
            embeddings = self.mean_var_norm_emb(
                embeddings,
                torch.ones(embeddings.shape[0], device=self.device),
            )
        return embeddings
# Directory where the pretrained SpeechBrain model files are stored and from
# which Extractor reads its "<name>.ckpt" checkpoints.
MODEL_PATH = "pretrained_models/spkrec-xvect-voxceleb"

# Reference embeddings from the stock SpeechBrain inference interface.
classifier = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-xvect-voxceleb",
    savedir=MODEL_PATH,
)
signal, fs = torchaudio.load(
    "/export/corpus01/LibriSpeech/dev-clean/1272/128104/1272-128104-0000.flac"
)
embeddings_class = classifier.encode_batch(signal).cpu().squeeze()

# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines without CUDA. The device name is also part of the saved file name.
device = "cuda" if torch.cuda.is_available() else "cpu"
extractor = Extractor(MODEL_PATH, device=device)
for param in extractor.parameters():
    param.requires_grad = False
extractor.eval()
embeddings_x = extractor(signal).cpu().squeeze()

# Similarity Evaluation: the re-assembled extractor against the reference.
cos = nn.CosineSimilarity(dim=0, eps=1e-6)
output = cos(embeddings_x, embeddings_class)
diff = embeddings_class - embeddings_x
print(output, diff.abs().sum())

# Tracing: export the extractor to TorchScript and compare its output with
# both the reference and the eager extractor.
traced_model = torch.jit.trace(extractor, signal)
torch.jit.save(traced_model, f"model_{device}.pt")
embeddings_t = traced_model(signal).squeeze()
output1 = cos(embeddings_class.to(device), embeddings_t)
output2 = cos(embeddings_x.to(device), embeddings_t)
print(embeddings_t.shape, output1, output2)

# Round trip: reload the saved TorchScript file and check it once more.
model = torch.jit.load(f"model_{device}.pt")
emb_m = model(signal).squeeze()
print(model.code)
print(cos(embeddings_x.to(device), emb_m))
print(emb_m)
# In addition to this script, one line of the speechbrain package itself was
# changed, in speechbrain/nnet/pooling.py (line 296):
#   commented out: actual_size = int(torch.round(lengths[snt_id] * x.shape[1]))
#   added:         actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment