fix trubrics lint issue (#11202)

This commit is contained in:
Bagatur 2023-09-28 18:07:50 -07:00 committed by GitHub
parent b738ccd91e
commit 8cd18a48e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,10 +2,44 @@ import os
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from uuid import UUID from uuid import UUID
from langchain.adapters.openai import convert_message_to_dict
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult from langchain.schema import LLMResult
from langchain.schema.messages import BaseMessage from langchain.schema.messages import (
AIMessage,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else:
raise TypeError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
class TrubricsCallbackHandler(BaseCallbackHandler): class TrubricsCallbackHandler(BaseCallbackHandler):
@ -25,7 +59,7 @@ class TrubricsCallbackHandler(BaseCallbackHandler):
project: str = "default", project: str = "default",
email: Optional[str] = None, email: Optional[str] = None,
password: Optional[str] = None, password: Optional[str] = None,
**kwargs: Any **kwargs: Any,
) -> None: ) -> None:
super().__init__() super().__init__()
try: try:
@ -56,9 +90,9 @@ class TrubricsCallbackHandler(BaseCallbackHandler):
self, self,
serialized: Dict[str, Any], serialized: Dict[str, Any],
messages: List[List[BaseMessage]], messages: List[List[BaseMessage]],
**kwargs: Any **kwargs: Any,
) -> None: ) -> None:
self.messages = [convert_message_to_dict(message) for message in messages[0]] self.messages = [_convert_message_to_dict(message) for message in messages[0]]
self.prompt = self.messages[-1]["content"] self.prompt = self.messages[-1]["content"]
def on_llm_end(self, response: LLMResult, run_id: UUID, **kwargs: Any) -> None: def on_llm_end(self, response: LLMResult, run_id: UUID, **kwargs: Any) -> None: