mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
fix trubrics lint issue (#11202)
This commit is contained in:
parent
b738ccd91e
commit
8cd18a48e4
@ -2,10 +2,44 @@ import os
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langchain.adapters.openai import convert_message_to_dict
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.schema import LLMResult
|
from langchain.schema import LLMResult
|
||||||
from langchain.schema.messages import BaseMessage
|
from langchain.schema.messages import (
|
||||||
|
AIMessage,
|
||||||
|
BaseMessage,
|
||||||
|
ChatMessage,
|
||||||
|
FunctionMessage,
|
||||||
|
HumanMessage,
|
||||||
|
SystemMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||||
|
message_dict: Dict[str, Any]
|
||||||
|
if isinstance(message, ChatMessage):
|
||||||
|
message_dict = {"role": message.role, "content": message.content}
|
||||||
|
elif isinstance(message, HumanMessage):
|
||||||
|
message_dict = {"role": "user", "content": message.content}
|
||||||
|
elif isinstance(message, AIMessage):
|
||||||
|
message_dict = {"role": "assistant", "content": message.content}
|
||||||
|
if "function_call" in message.additional_kwargs:
|
||||||
|
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||||
|
# If function call only, content is None not empty string
|
||||||
|
if message_dict["content"] == "":
|
||||||
|
message_dict["content"] = None
|
||||||
|
elif isinstance(message, SystemMessage):
|
||||||
|
message_dict = {"role": "system", "content": message.content}
|
||||||
|
elif isinstance(message, FunctionMessage):
|
||||||
|
message_dict = {
|
||||||
|
"role": "function",
|
||||||
|
"content": message.content,
|
||||||
|
"name": message.name,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Got unknown type {message}")
|
||||||
|
if "name" in message.additional_kwargs:
|
||||||
|
message_dict["name"] = message.additional_kwargs["name"]
|
||||||
|
return message_dict
|
||||||
|
|
||||||
|
|
||||||
class TrubricsCallbackHandler(BaseCallbackHandler):
|
class TrubricsCallbackHandler(BaseCallbackHandler):
|
||||||
@ -25,7 +59,7 @@ class TrubricsCallbackHandler(BaseCallbackHandler):
|
|||||||
project: str = "default",
|
project: str = "default",
|
||||||
email: Optional[str] = None,
|
email: Optional[str] = None,
|
||||||
password: Optional[str] = None,
|
password: Optional[str] = None,
|
||||||
**kwargs: Any
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
try:
|
try:
|
||||||
@ -56,9 +90,9 @@ class TrubricsCallbackHandler(BaseCallbackHandler):
|
|||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: Dict[str, Any],
|
||||||
messages: List[List[BaseMessage]],
|
messages: List[List[BaseMessage]],
|
||||||
**kwargs: Any
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.messages = [convert_message_to_dict(message) for message in messages[0]]
|
self.messages = [_convert_message_to_dict(message) for message in messages[0]]
|
||||||
self.prompt = self.messages[-1]["content"]
|
self.prompt = self.messages[-1]["content"]
|
||||||
|
|
||||||
def on_llm_end(self, response: LLMResult, run_id: UUID, **kwargs: Any) -> None:
|
def on_llm_end(self, response: LLMResult, run_id: UUID, **kwargs: Any) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user