diff --git a/libs/langchain/langchain/callbacks/trubrics_callback.py b/libs/langchain/langchain/callbacks/trubrics_callback.py index df9889b30d..168793aeed 100644 --- a/libs/langchain/langchain/callbacks/trubrics_callback.py +++ b/libs/langchain/langchain/callbacks/trubrics_callback.py @@ -2,10 +2,44 @@ import os from typing import Any, Dict, List, Optional from uuid import UUID -from langchain.adapters.openai import convert_message_to_dict from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import LLMResult -from langchain.schema.messages import BaseMessage +from langchain.schema.messages import ( + AIMessage, + BaseMessage, + ChatMessage, + FunctionMessage, + HumanMessage, + SystemMessage, +) + + +def _convert_message_to_dict(message: BaseMessage) -> dict: + message_dict: Dict[str, Any] + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + if "function_call" in message.additional_kwargs: + message_dict["function_call"] = message.additional_kwargs["function_call"] + # If function call only, content is None not empty string + if message_dict["content"] == "": + message_dict["content"] = None + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, FunctionMessage): + message_dict = { + "role": "function", + "content": message.content, + "name": message.name, + } + else: + raise TypeError(f"Got unknown type {message}") + if "name" in message.additional_kwargs: + message_dict["name"] = message.additional_kwargs["name"] + return message_dict class TrubricsCallbackHandler(BaseCallbackHandler): @@ -25,7 +59,7 @@ class TrubricsCallbackHandler(BaseCallbackHandler): project: str = "default", email: Optional[str] = None, password: Optional[str] = None, - **kwargs: Any + **kwargs: Any, ) -> None: super().__init__() try: @@ -56,9 +90,9 @@ class TrubricsCallbackHandler(BaseCallbackHandler): self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], - **kwargs: Any + **kwargs: Any, ) -> None: - self.messages = [convert_message_to_dict(message) for message in messages[0]] + self.messages = [_convert_message_to_dict(message) for message in messages[0]] self.prompt = self.messages[-1]["content"] def on_llm_end(self, response: LLMResult, run_id: UUID, **kwargs: Any) -> None: