Created
February 10, 2024 16:53
-
-
Save Dref360/89d101fb296ee12ef24d1b923fa02d0b to your computer and use it in GitHub Desktop.
Example of uncertainty estimation using Baal on Speech Recognition
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Wav2Vec2 in Baal: uncertainty estimation on speech recognition
from datasets import load_dataset | |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, TrainingArguments | |
from baal.active.heuristics import BALD | |
from baal.bayesian.dropout import patch_module | |
from baal.transformers_trainer_wrapper import BaalTransformersTrainer | |
# Load the pretrained processor and CTC model from the same checkpoint.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Load the dummy dataset and read the sound files.
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")


def _to_model_inputs(example):
    """Turn one dataset row into the model inputs produced by ``processor``.

    The sampling rate stored with the audio is forwarded so the processor can
    check it against the rate the checkpoint expects instead of assuming it.
    """
    audio = example["audio"]
    encoded = processor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
        return_tensors="pt",
        padding="longest",
    )
    # The processor is called on a single example; keep element 0 of each
    # returned value so the stored row holds that example only.
    return {key: value[0] for key, value in encoded.items()}


# Preprocess the audio, drop the original columns and set the format to torch.
ds_processed = (
    ds.map(_to_model_inputs)
    .remove_columns(ds.column_names)
    .with_format("torch")
)
def uncertainty_estimation(ds_processed):
    """Score every item of ``ds_processed`` with MC-Dropout and BALD.

    Args:
        ds_processed: Preprocessed dataset handed to
            ``BaalTransformersTrainer.predict_on_dataset_generator``.

    Returns:
        Whatever ``BALD(reduction='mean').get_uncertainties_generator``
        returns for the full stream of predictions (presumably one score per
        input item -- confirm against the Baal documentation).
    """
    # Replace dropout layers. Note that this uses the module-level ``model``.
    patched_model = patch_module(model)
    wrapper = BaalTransformersTrainer(
        model=patched_model,
        args=TrainingArguments('/tmp', per_device_eval_batch_size=1),
    )
    # 20 MC-Dropout iterations.
    # WARNING: each yielded prediction has shape
    # [Batch Size, Num Classes, Num Tokens, Num Iteration].
    predictions_generator = wrapper.predict_on_dataset_generator(
        ds_processed,
        iterations=20,
    )
    # Hand the untouched generator to the heuristic: pulling an item out with
    # ``next()`` first would consume the first batch, and it would never be
    # scored.
    uncertainty = BALD(reduction='mean').get_uncertainties_generator(predictions_generator)
    return uncertainty
# Run the estimation on a small subset: rows 1 through 5 of the processed dataset.
sample_indices = [1, 2, 3, 4, 5]
uncertainty_estimation(ds_processed.select(sample_indices))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment