Created
January 2, 2025 01:47
-
-
Save dbreunig/d9054e77a84cb151904052d4d7d8a13c to your computer and use it in GitHub Desktop.
Generate synthetic question-and-answer data from a podcast episode using DSPy, Ollama, and Llama 3.3.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from faster_whisper import WhisperModel | |
import dspy | |
from typing import List | |
import json | |
# Replace as you see fit...
whisper_model = "medium.en"
podcast_path = "audio/the_invention_of_photography.mp3"
# We're using the In Our Time episode about the invention of photography.
# https://www.bbc.co.uk/programmes/b07j699g

#
# Transcribe the podcast using Whisper
#
# Run Whisper on the CPU with int8 compute; change device/compute_type if a
# GPU is available.
model = WhisperModel(whisper_model, device="cpu", compute_type="int8")
# In faster-whisper, `segments` is lazy: the actual transcription runs while
# it is iterated below. `info` (transcription metadata) is not used here.
segments, info = model.transcribe(podcast_path, beam_size=5)
# One segment per line, each terminated by "\n". Build the string with a single
# join instead of repeated `+=`, which is quadratic in the number of segments.
transcript = "".join(segment.text + "\n" for segment in segments)
#
# Set up DSPy & the extraction signature
#
# Llama 3.3 served by a local Ollama instance, registered below as DSPy's
# global default language model.
lm = dspy.LM(
    'ollama/llama3.3',
    api_base='http://localhost:11434',
)
dspy.configure(lm=lm)
# Our sig has some descriptions to focus the question extraction
class IdentifyQuestions(dspy.Signature):
    """Identify questions from a transcript"""
    # NOTE: DSPy uses the signature docstring and the field descriptions as
    # prompt text, so editing them changes what the model is asked to do.
    transcript: str = dspy.InputField()
    questions: List[str] = dspy.OutputField(desc="List of all questions in the transcript about the main topic of the podcast, excluding questions about the podcast itself. Do not miss any questions.")
    # Questions and answers are paired downstream purely by position, so the
    # description explicitly asks for answers in the same order as questions.
    answers: List[str] = dspy.OutputField(desc="List of answers, one for each identified question, in the same order as the questions")
# ChainOfThought has the model produce reasoning before the output fields.
identifyQuestions = dspy.ChainOfThought(IdentifyQuestions)
#
# Extract & format the questions
#
output = identifyQuestions(transcript=transcript)
# The model returns two parallel lists that are paired by position. If their
# lengths differ, zip() silently drops the extras and the remaining pairs are
# likely misaligned, so surface that instead of writing bad data quietly.
if len(output.questions) != len(output.answers):
    print(
        f"Warning: got {len(output.questions)} questions but "
        f"{len(output.answers)} answers; pairs may be misaligned and "
        "unmatched items are dropped."
    )
qa_pairs = [{'question': q, 'answer': a} for q, a in zip(output.questions, output.answers)]
# Write to JSONL: one JSON object per line.
with open('qa_pairs.jsonl', 'w', encoding='utf-8') as f:
    for qa_pair in qa_pairs:
        f.write(json.dumps(qa_pair) + '\n')
#
# Optional: Evaluate the output with another signature
#
class ValidateQuestionAndAnswer(dspy.Signature):
    """Validate a question and answer pair as accurate using a podcast transcript"""

    # Inputs: the full transcript plus one candidate question/answer pair.
    transcript: str = dspy.InputField()
    question: str = dspy.InputField()
    answer: str = dspy.InputField()
    # Output: a boolean verdict. The description is prompt text and is kept
    # verbatim; it is only split across lines here for readability.
    valid: bool = dspy.OutputField(
        desc=(
            "Whether the answer correctly answers the question "
            "using information from the transcript"
        )
    )


validator = dspy.ChainOfThought(ValidateQuestionAndAnswer)
# The metric function
def validate_question_and_answer(transcript: str, question: str, answer: str, *, lm=None) -> bool:
    """Ask an LM whether ``answer`` correctly answers ``question`` per ``transcript``.

    Args:
        transcript: Full podcast transcript used as the source of truth.
        question: The question to check.
        answer: The candidate answer to check.
        lm: Optional ``dspy.LM`` used as the judge. When omitted, an uncached
            ``ollama/llama3.3`` on localhost is used, as before. Pass a bigger
            model here if you're optimizing.

    Returns:
        The ``valid`` output field produced by the ``validator`` module.
    """
    if lm is None:
        # cache=False so repeated validations query the model rather than
        # replaying a cached response.
        lm = dspy.LM('ollama/llama3.3', api_base='http://localhost:11434', cache=False)
    # Override the globally configured LM only for this call.
    with dspy.context(lm=lm):
        validation = validator(transcript=transcript, question=question, answer=answer)
    return validation.valid
# Check that it works: for each pair, print the verdict, then the question,
# then the answer that was judged.
for qa_pair in qa_pairs:
    question, answer = qa_pair['question'], qa_pair['answer']
    is_valid = validate_question_and_answer(transcript, question, answer)
    print(is_valid)
    print(question)
    print(answer)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.