Custom LLM with Pipecat
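A custom Pipecat LLM service that delegates completion requests to a pluggable backend instead of calling a hosted model directly: it converts the conversation context into the backend's message type, forwards it, folds the reply back into the context, and streams the response text downstream. BackendBase and LlmMessage are the author's own abstractions and are not included in this gist; the sketch below is a hypothetical reconstruction of their shape, inferred only from how CustomLLMService uses them, so the service code can run as written.

# Hypothetical stand-ins for the author's backend types (not part of the
# original gist), inferred from how CustomLLMService calls them.
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class LlmMessage:
    role: str
    content: str


@dataclass
class LlmResponse:
    # Final text pushed downstream as an LLMTextFrame.
    content: str
    # Messages appended back onto the LLM context.
    msgs: List[Dict[str, Any]] = field(default_factory=list)


class BackendBase:
    async def get_resp(self, msgs: List[LlmMessage], metadata: Dict[str, str]) -> LlmResponse:
        raise NotImplementedError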
import httpx

# Import paths follow recent pipecat-ai releases and may differ slightly
# between versions.
from pipecat.frames.frames import (
    Frame,
    LLMFullResponseEndFrame,
    LLMFullResponseStartFrame,
    LLMMessagesFrame,
    LLMTextFrame,
    LLMUpdateSettingsFrame,
)
from pipecat.processors.aggregators.llm_response import (
    LLMAssistantAggregatorParams,
    LLMUserAggregatorParams,
)
from pipecat.processors.aggregators.openai_llm_context import (
    OpenAILLMContext,
    OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.llm_service import LLMService
from pipecat.services.openai.llm import (
    OpenAIAssistantContextAggregator,
    OpenAIContextAggregatorPair,
    OpenAIUserContextAggregator,
)


class CustomLLMService(LLMService):
    # Subclassing LLMService (rather than the bare AIService) provides the
    # get_llm_adapter() and _update_settings() helpers used below.

    def __init__(self, backend: BackendBase):
        super().__init__()
        self.backend = backend
        # Register the custom event so _call_event_handler() can fire it.
        self._register_event_handler("on_completion_timeout")

    def create_context_aggregator(
        self,
        context: OpenAILLMContext,
        *,
        user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
        assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
    ) -> OpenAIContextAggregatorPair:
        """Create an OpenAIContextAggregatorPair from an OpenAILLMContext.

        Constructor keyword arguments for both the user and assistant
        aggregators can be provided.

        Args:
            context (OpenAILLMContext): The LLM context.
            user_params (LLMUserAggregatorParams, optional): User aggregator parameters.
            assistant_params (LLMAssistantAggregatorParams, optional): Assistant aggregator parameters.

        Returns:
            OpenAIContextAggregatorPair: A pair of context aggregators, one
            for the user and one for the assistant.
        """
        context.set_llm_adapter(self.get_llm_adapter())
        user = OpenAIUserContextAggregator(context, params=user_params)
        assistant = OpenAIAssistantContextAggregator(context, params=assistant_params)
        return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        # Pull an LLM context out of the frames this service handles;
        # everything else is passed through unchanged.
        context = None
        if isinstance(frame, OpenAILLMContextFrame):
            context: OpenAILLMContext = frame.context
        elif isinstance(frame, LLMMessagesFrame):
            context = OpenAILLMContext.from_messages(frame.messages)
        elif isinstance(frame, LLMUpdateSettingsFrame):
            await self._update_settings(frame.settings)
        else:
            await self.push_frame(frame, direction)

        if context:
            try:
                await self.push_frame(LLMFullResponseStartFrame())
                await self.start_processing_metrics()
                # await self._process_context(context)

                # Convert the context messages into the backend's message type.
                msgs = []
                for contmsg in context.messages:
                    msgs.append(
                        LlmMessage(
                            role=contmsg["role"],
                            content=contmsg["content"],
                        )
                    )
                resp = await self.backend.get_resp(
                    msgs,
                    {
                        "conversation_id": "fake_conversation_id",
                        "user_id": "fake_user_id",
                    },
                )
                # Fold the backend's reply into the context and stream the
                # response text downstream (e.g. to TTS).
                context.add_messages(resp.msgs)
                await self.push_frame(LLMTextFrame(resp.content))
            except httpx.TimeoutException:
                await self._call_event_handler("on_completion_timeout")
            finally:
                await self.stop_processing_metrics()
                await self.push_frame(LLMFullResponseEndFrame())
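A minimal usage sketch under the assumptions above: EchoBackend is a toy backend, the event_handler decorator is pipecat's standard mechanism for the event registered in __init__, and the pipeline comment only indicates where the service and its aggregators typically sit in a Pipecat voice pipeline.

# Toy backend (hypothetical) that echoes the last user message.
class EchoBackend(BackendBase):
    async def get_resp(self, msgs, metadata):
        reply = f"You said: {msgs[-1].content}"
        return LlmResponse(
            content=reply,
            msgs=[{"role": "assistant", "content": reply}],
        )


llm = CustomLLMService(backend=EchoBackend())

context = OpenAILLMContext(messages=[{"role": "system", "content": "Be brief."}])
aggregators = llm.create_context_aggregator(context)


@llm.event_handler("on_completion_timeout")
async def on_timeout(service):
    print("backend timed out")


# Typical placement in a voice pipeline (other processors elided):
# Pipeline([transport.input(), stt, aggregators.user(), llm,
#           tts, transport.output(), aggregators.assistant()])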