import json

import numpy as np
from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature

from src.diart.inference import StreamingInference
from src.diart.pipelines import (
    SpeakerDiarization,
    Transcription,
    TranscriptionConfig,
    VoiceActivityDetection,
    VoiceActivityDetectionConfig,
)
from src.diart.sources import WebSocketAudioSource

# -------------------------------------
# Configuration
# -------------------------------------
# Use a 2.5 s transcription window that advances in overlapping 0.5 s steps.
config = TranscriptionConfig(duration=2.5, step=0.5, latency=0.5)
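# In diart, `duration` is the length of the rolling analysis window, `step` is
# the hop between consecutive windows, and `latency` controls how far behind
# real time outputs are emitted (diart constrains step <= latency <= duration;
# a larger latency gives the models more right-context per output).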
transcription_pipeline = Transcription(config)

# Speaker diarization pipeline
diarization_pipeline = SpeakerDiarization()

# Voice activity detection pipeline
c2 = VoiceActivityDetectionConfig(duration=2.5, step=0.5, latency=0.5)
vad_pipeline = VoiceActivityDetection(c2)

# Audio source (from websocket)
source = WebSocketAudioSource(diarization_pipeline.config.sample_rate, host="0.0.0.0", port=7007)
# No timestamp shift: keep output timestamps relative to the start of the stream
diarization_pipeline.set_timestamp_shift(0)

# -------------------------------------
# Global buffers and state
# -------------------------------------
accumulated_samples = np.empty((0, 1), dtype=np.float32)  # Accumulated audio, shape (frames, 1)
sample_rate = diarization_pipeline.config.sample_rate
required_samples = int(round(config.duration * sample_rate))  # Samples needed to fill one transcription window
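# With the 16 kHz sample rate of diart's default models, this is
# 2.5 s * 16000 Hz = 40000 samples per transcription window.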

# Speaker annotations are also accumulated, one per incoming chunk. When a
# window is transcribed, the text is associated with the speakers found in
# these annotations.
accumulated_annotations = []

# -------------------------------------
# Helper functions
# -------------------------------------
def append_chunk(wav):
    """Append the new chunk of audio samples to the accumulated buffer."""
    global accumulated_samples
    # wav.data shape: (frames, channels)
    accumulated_samples = np.vstack([accumulated_samples, wav.data])

def is_voice_present():
    """Check whether voice is present in the accumulated audio using VAD."""
    global accumulated_samples
    # Only run VAD once a full transcription window has been accumulated
    if len(accumulated_samples) < required_samples:
        return False
    # The VAD and transcription pipelines expect SlidingWindowFeature inputs.
    # Build one over the accumulated samples, one frame per sample
    # (frame duration = 1 / sample_rate).
    resolution = 1 / sample_rate
    window = SlidingWindow(start=0.0, duration=resolution, step=resolution)
    wav_feature = SlidingWindowFeature(accumulated_samples, window)
    # Run the VAD pipeline on the entire accumulated waveform.
    # The pipeline expects a list of waveforms: [wav_feature]
    vad_annotation, _ = vad_pipeline([wav_feature])[0]
    # A non-empty VAD timeline means speech was detected.
    return not vad_annotation.get_timeline().empty()

def run_transcription():
    """Run transcription on the accumulated samples and return the text."""
    global accumulated_samples
    resolution = 1 / sample_rate
    window = SlidingWindow(start=0.0, duration=resolution, step=resolution)
    wav_feature = SlidingWindowFeature(accumulated_samples, window)
    # Transcribe; the pipeline returns a list of (text, waveform) pairs
    results = transcription_pipeline([wav_feature])
    text = results[0][0]
    return text

def build_response(text):
    """Build a JSON response with speaker segments and text from accumulated annotations."""
    # Merge the per-chunk annotations into one session-wide annotation
    merged_annotation = Annotation(uri="session")
    for ann in accumulated_annotations:
        merged_annotation.update(ann)
    # List each labeled segment as {start, end, speaker}
    segments = []
    for segment, track, label in merged_annotation.itertracks(yield_label=True):
        segments.append({
            "start": float(segment.start),
            "end": float(segment.end),
            "speaker": label,
        })
    response = {
        "speakers": segments,
        "text": text,
    }
    return json.dumps(response)
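# Example response (values illustrative; speaker labels are whatever the
# diarization pipeline assigns):
# {"speakers": [{"start": 0.0, "end": 1.4, "speaker": "speaker0"}],
#  "text": "hello there"}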

def reset_accumulators():
    """Clear the accumulated samples and annotations after processing."""
    global accumulated_samples, accumulated_annotations
    accumulated_samples = np.empty((0, 1), dtype=np.float32)
    accumulated_annotations.clear()

# -------------------------------------
# Callback function
# -------------------------------------
def on_diarization(prediction_with_audio):
    """
    Called each time the diarization pipeline produces an output.

    prediction_with_audio: (Annotation, SlidingWindowFeature)
    """
    annotation, wav = prediction_with_audio
    # Append this chunk to the accumulators
    append_chunk(wav)
    accumulated_annotations.append(annotation)
    # Wait until a full transcription window has been accumulated
    if len(accumulated_samples) >= required_samples:
        if is_voice_present():
            # Run transcription on the entire accumulated buffer
            try:
                text = run_transcription()
            except Exception as e:
                text = f"error: {e}"
            # Build the response and send it back over the websocket
            response = build_response(text)
            source.send(response)
            # Reset accumulators for the next window
            reset_accumulators()
        else:
            # No voice detected: clear accumulators and wait for more input
            reset_accumulators()

# -------------------------------------
# Run streaming inference
# -------------------------------------
inference = StreamingInference(diarization_pipeline, source)
inference.attach_hooks(on_diarization)
prediction = inference()
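
# -------------------------------------
# Example client (separate process)
# -------------------------------------
# A minimal sketch of a client for this server, kept commented out because
# inference() above blocks. It assumes the `websocket-client` package and
# that WebSocketAudioSource expects base64-encoded float32 samples as
# produced by diart's encode_audio helper; verify both against your diart
# version before relying on this.
#
#   import numpy as np
#   from websocket import create_connection  # pip install websocket-client
#   from src.diart.utils import encode_audio
#
#   ws = create_connection("ws://localhost:7007")
#   step_samples = int(0.5 * 16000)                   # one 0.5 s step at 16 kHz
#   chunk = np.zeros(step_samples, dtype=np.float32)  # stand-in for real audio
#   ws.send(encode_audio(chunk))
#   print(ws.recv())  # JSON built by build_response(), sent once a window fills
#   ws.close()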