@arbuckle
Created December 13, 2024 17:43
import json
import numpy as np
from src.diart.pipelines import VoiceActivityDetection, SpeakerDiarization, Transcription, TranscriptionConfig, VoiceActivityDetectionConfig
from src.diart.sources import WebSocketAudioSource
from src.diart.inference import StreamingInference
from pyannote.core import SlidingWindow, SlidingWindowFeature
from pyannote.core import Annotation
# -------------------------------------
# Configuration
# -------------------------------------
# For example, we want a 2.5s transcription window with overlapping 0.5s steps.
config = TranscriptionConfig(duration=2.5, step=0.5, latency=0.5)
transcription_pipeline = Transcription(config)
# Speaker diarization pipeline
diarization_pipeline = SpeakerDiarization()
# Voice Activity Detection pipeline
c2 = VoiceActivityDetectionConfig(duration=2.5, step=0.5, latency=0.5)
vad_pipeline = VoiceActivityDetection(c2)
# Audio source (from websocket)
source = WebSocketAudioSource(diarization_pipeline.config.sample_rate, host="0.0.0.0", port=7007)
diarization_pipeline.set_timestamp_shift(0)
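# Note (assumption about the transport): clients connect to ws://0.0.0.0:7007 and each
# websocket message is decoded by diart's WebSocketAudioSource into audio for the
# pipeline; in recent diart releases this is base64-encoded 32-bit float mono PCM at the
# configured sample rate. A sketch of a matching client is at the bottom of this file.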
# -------------------------------------
# Global buffers and state
# -------------------------------------
accumulated_samples = np.empty((0, 1), dtype=np.float32) # Accumulate audio samples here
sample_rate = diarization_pipeline.config.sample_rate
required_samples = int(round(config.duration * sample_rate)) # Samples required for transcription window
# We will also accumulate speaker annotations. For simplicity, let's store them in a list.
# Each time we get a new chunk, we append its annotation. When we transcribe, we'll associate the text with the speakers.
accumulated_annotations = []
# -------------------------------------
# Helper functions
# -------------------------------------
def append_chunk(wav):
    """Append the new chunk of audio samples to the accumulated buffer."""
    global accumulated_samples
    # wav.data shape: (frames, channels)
    accumulated_samples = np.vstack([accumulated_samples, wav.data])
def is_voice_present():
    """Check if voice is present in the accumulated audio using VAD."""
    global accumulated_samples
    # Guard against partial windows (required_samples = duration * sample_rate,
    # e.g. 40000 samples for a 2.5 s window at 16 kHz).
    if len(accumulated_samples) < required_samples:
        return False
    print(len(accumulated_samples))
    # VAD and transcription pipelines expect a SlidingWindowFeature-like input.
    # Create a SlidingWindowFeature for the accumulated samples.
    # Duration per frame = 1/sample_rate
    resolution = 1 / sample_rate
    window = SlidingWindow(start=0.0, duration=resolution, step=resolution)
    wav_feature = SlidingWindowFeature(accumulated_samples, window)
    # Run VAD pipeline on the entire accumulated waveform.
    # The pipeline expects a list of waveforms: [wav_feature]
    (vad_annotation, _) = vad_pipeline([wav_feature])[0]
    # If the VAD annotation is not empty, we have speech.
    return not vad_annotation.get_timeline().empty()
def run_transcription():
    """Run transcription on the accumulated samples and return text."""
    global accumulated_samples
    resolution = 1 / sample_rate
    window = SlidingWindow(start=0.0, duration=resolution, step=resolution)
    wav_feature = SlidingWindowFeature(accumulated_samples, window)
    # Transcribe
    results = transcription_pipeline([wav_feature])  # [(text, waveform)]
    text = results[0][0]
    return text
def build_response(text):
    """Build a JSON response with speakers and text segments from accumulated annotations."""
    # Extract speaker info: we have accumulated one annotation per incoming chunk.
    # Merge them into one bigger annotation:
    merged_annotation = Annotation(uri="session")
    for ann in accumulated_annotations:
        merged_annotation.update(ann)
    # Convert the annotation and text into a simple dict.
    # For simplicity, list each labeled segment as {start, end, speaker}.
    segments = []
    for segment, track, label in merged_annotation.itertracks(yield_label=True):
        segments.append({
            "start": float(segment.start),
            "end": float(segment.end),
            "speaker": label,
        })
    response = {
        "speakers": segments,
        "text": text,
    }
    return json.dumps(response)
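# The JSON sent back over the websocket looks roughly like this (values illustrative):
# {"speakers": [{"start": 0.0, "end": 1.2, "speaker": "speaker0"}, ...],
#  "text": "transcribed text for the accumulated window"}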
def reset_accumulators():
    """Clear the accumulated samples and annotations after processing."""
    global accumulated_samples, accumulated_annotations
    accumulated_samples = np.empty((0, 1), dtype=np.float32)
    accumulated_annotations.clear()
# -------------------------------------
# Callback function
# -------------------------------------
def on_diarization(prediction_with_audio):
    """
    Called each time the diarization pipeline produces an output.
    prediction_with_audio: (Annotation, SlidingWindowFeature)
    """
    annotation, wav = prediction_with_audio
    # Append this chunk to the accumulators
    append_chunk(wav)
    accumulated_annotations.append(annotation)
    # Check if we have enough samples accumulated for a transcription window
    if len(accumulated_samples) >= required_samples:
        # Check if voice is present in these samples
        if is_voice_present():
            # Run transcription on the entire accumulated buffer
            try:
                text = run_transcription()
            except Exception as e:
                text = "error: %s" % e
            # Build response
            response = build_response(text)
            # Send response back
            source.send(response)
            # Reset accumulators for the next chunk
            reset_accumulators()
        else:
            # No voice, just clear accumulators and wait for next input
            reset_accumulators()
# -------------------------------------
# Run streaming inference
# -------------------------------------
inference = StreamingInference(diarization_pipeline, source)
inference.attach_hooks(on_diarization)
prediction = inference()
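# -------------------------------------
# Example client (sketch)
# -------------------------------------
# A minimal client sketch, kept commented out so this file stays a runnable server
# script. It assumes the server above accepts base64-encoded float32 mono audio per
# message and that your diart version ships encode_audio in diart.utils; if either
# assumption does not hold for your install, adapt the encoding. "example.wav" is a
# placeholder path. Requires the `websocket-client` and `librosa` packages. The client
# streams the file in 0.5 s steps and then reads one JSON response built by
# build_response().
#
# import time
# import librosa
# import numpy as np
# import websocket
# from diart.utils import encode_audio
#
# step = 0.5
# sample_rate = 16000
# audio, _ = librosa.load("example.wav", sr=sample_rate, mono=True)  # placeholder file
#
# ws = websocket.create_connection("ws://127.0.0.1:7007")
# samples_per_step = int(step * sample_rate)
# for i in range(0, len(audio), samples_per_step):
#     chunk = audio[i:i + samples_per_step].astype(np.float32)
#     ws.send(encode_audio(chunk))
#     time.sleep(step)  # roughly simulate real-time streaming
# # Responses only arrive once the server has accumulated a full window with speech,
# # so a real client would read them on a separate thread; this sketch just blocks
# # on the next message.
# print(ws.recv())
# ws.close()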