import json
from src.diart.pipelines import VoiceActivityDetection, SpeakerDiarization, Transcription, TranscriptionConfig
from src.diart.sources import WebSocketAudioSource
from src.diart.inference import StreamingInference
from src.diart import utils
# 1) Create the pipelines. A standalone VoiceActivityDetection pipeline is also available,
#    and Transcription can handle VAD internally if a segmentation model is provided.
#    Here, we use SpeakerDiarization to identify speakers and Transcription to get text.
# Diarization pipeline (handles segmentation and speaker identification)
diarization_pipeline = SpeakerDiarization()
# Ensure that the transcription pipeline config has a segmentation model set.
config = TranscriptionConfig(
    duration=0.5,
    step=0.5,
    latency=0.5,
)
transcription_pipeline = Transcription(config)
# 2) Create an audio source from a WebSocket endpoint
# The WebSocketAudioSource will receive audio chunks from a websocket server listening on host:port.
source = WebSocketAudioSource(diarization_pipeline.config.sample_rate, host="0.0.0.0", port=7007)
# 3) Set timestamp shift for diarization if needed. For streaming, typically no shift is required.
diarization_pipeline.set_timestamp_shift(0)
# 4) Define a callback hook that will receive the diarization predictions and corresponding audio chunks.
def on_diarization(prediction_with_audio):
    """
    Called each time the diarization pipeline produces an output.

    prediction_with_audio is a tuple: (Annotation, SlidingWindowFeature)
        Annotation: speaker-labeled segments
        SlidingWindowFeature: corresponding audio chunk
    """
    annotation, wav = prediction_with_audio
    # Apply transcription to the same audio chunk to get text.
    # The transcription pipeline returns a list of (text, waveform) tuples, one per chunk.
    try:
        transcription_results = transcription_pipeline([wav])  # returns [(text, wav)]
        text = transcription_results[0][0]
    except Exception as e:
        text = f"error: {e}"
    # We now have speaker segments (annotation) and text for the same chunk.
    # To combine speaker labels and text you can align them; for simplicity, print both
    # (a rough alignment sketch follows below):
    print("Diarization result:", annotation)
    print("Transcription result:", text)
    # Send the result back through the websocket if desired.
    return source.send(json.dumps({"annotation": str(annotation), "text": text}))
# 5) Create a StreamingInference that runs the diarization pipeline continuously
inference = StreamingInference(diarization_pipeline, source)
inference.attach_hooks(on_diarization)
# 6) Start streaming inference
prediction = inference()
# The script will continuously process incoming audio, run diarization, then transcription,
# and print/send results until the audio source is closed or interrupted.
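
# --- Optional: a minimal client sketch (not part of the original gist) ---
# The server above listens on ws://0.0.0.0:7007 and expects audio chunks over the socket.
# This sketch assumes the wire format used by upstream diart's websocket tooling
# (base64-encoded float32 samples at the pipeline sample rate); verify against
# WebSocketAudioSource before relying on it. It uses the third-party `websocket-client`
# package and numpy, streams a synthetic 440 Hz tone in 0.5 s chunks, and prints
# whatever JSON on_diarization() sends back. Run it as a separate script/process,
# not appended to this file (inference() above blocks until the source closes).
#
# import base64
# import numpy as np
# import websocket  # pip install websocket-client
#
# SAMPLE_RATE = 16000                     # must match the pipeline's sample rate
# CHUNK = int(0.5 * SAMPLE_RATE)          # 0.5 s chunks, matching step/duration above
#
# t = np.arange(10 * SAMPLE_RATE) / SAMPLE_RATE
# audio = (0.1 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)
#
# ws = websocket.create_connection("ws://127.0.0.1:7007")
# for start in range(0, len(audio), CHUNK):
#     ws.send(base64.b64encode(audio[start:start + CHUNK].tobytes()).decode("utf-8"))
#     print(ws.recv())  # note: replies are not guaranteed to pair 1:1 with sent chunks
# ws.close()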