import json
import logging
from typing import (
    Any,
    Dict,
    List,
    Mapping,
    Optional,
    Tuple,
)

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
    ChatGeneration,
    ChatResult,
)
from langchain.schema.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)

logger = logging.getLogger(__name__)


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Convert a role-tagged message dict into the matching message object.

    The input mapping is never modified: when a ``function_call`` is present,
    a shallow copy of it is made and its ``arguments`` are JSON-encoded on the
    copy, so converting the same dict twice yields the same result.

    Args:
        _dict: Mapping with a ``"role"`` key, plus ``"content"`` and, depending
            on the role, ``"function_call"`` or ``"name"``.

    Returns:
        ``HumanMessage`` for ``"user"``, ``AIMessage`` for ``"assistant"``,
        ``SystemMessage`` for ``"system"``, ``FunctionMessage`` for
        ``"function"``, and a ``ChatMessage`` carrying the role otherwise.

    Raises:
        KeyError: If a key required for the given role is missing (including
            ``"arguments"`` inside a non-empty ``function_call``).
    """
    role = _dict["role"]
    if role == "user":
        return HumanMessage(content=_dict["content"])
    if role == "assistant":
        # Fix for azure
        # Also OpenAI returns None for tool invocations
        content = _dict.get("content") or ""
        additional_kwargs: Dict[str, Any] = {}
        function_call = _dict.get("function_call")
        if function_call:
            # Encode the arguments on a copy so the caller's payload is left
            # untouched (the original code overwrote it in place).
            call = dict(function_call)
            call["arguments"] = json.dumps(function_call["arguments"])
            additional_kwargs["function_call"] = call
        return AIMessage(content=content, additional_kwargs=additional_kwargs)
    if role == "system":
        return SystemMessage(content=_dict["content"])
    if role == "function":
        return FunctionMessage(content=_dict["content"], name=_dict["name"])
    return ChatMessage(content=_dict["content"], role=role)


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a message object into a role-tagged dict.

    Args:
        message: The message to serialize.

    Returns:
        A dict with ``"role"`` and ``"content"``; an ``AIMessage`` additionally
        carries ``"function_call"`` when present in its ``additional_kwargs``,
        and a ``FunctionMessage`` carries ``"name"``. A ``"name"`` entry in
        ``additional_kwargs`` overrides/sets ``"name"`` for any message type.

    Raises:
        ValueError: If ``message`` is not one of the handled message types.
    """
    message_dict: Dict[str, Any]
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
        if "function_call" in message.additional_kwargs:
            message_dict["function_call"] = message.additional_kwargs["function_call"]
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, FunctionMessage):
        message_dict = {
            "role": "function",
            "content": message.content,
            "name": message.name,
        }
    else:
        raise ValueError(f"Got unknown type {message}")
    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict


class ChatLlamaAPI(BaseChatModel):
    """Chat model that forwards messages to ``client.run`` and parses the reply.

    The client is expected to expose ``run(params)`` returning an object with a
    ``.json()`` method; the decoded payload must contain a ``"choices"`` list
    whose items each hold a ``"message"`` dict.
    """

    client: Any  #: :meta private:

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Send ``messages`` through the client and return the chat result.

        Args:
            messages: Conversation to send.
            stop: Optional stop sequences, forwarded as the ``"stop"`` param.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Extra request parameters; they override the default
                client params but never the ``"messages"`` entry.

        Returns:
            The ``ChatResult`` built from the client's JSON response.
        """
        message_dicts, params = self._create_message_dicts(messages, stop)
        # Later entries win: defaults < caller kwargs < the messages payload.
        final_params = {**params, **kwargs, "messages": message_dicts}
        response = self.client.run(final_params).json()
        return self._create_chat_result(response)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Serialize ``messages`` and assemble the request parameters.

        Args:
            messages: Messages to convert to dicts.
            stop: Optional stop sequences to add to the parameters.

        Returns:
            A ``(message_dicts, params)`` tuple, where ``params`` is a fresh
            copy of the client params (plus ``"stop"`` when given).

        Raises:
            ValueError: If ``stop`` is given and the default params already
                define ``"stop"``.
        """
        params = dict(self._client_params)
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        """Build a ``ChatResult`` from a decoded response payload.

        Args:
            response: Mapping with a ``"choices"`` iterable; each choice holds
                a ``"message"`` dict and optionally a ``"finish_reason"``.

        Returns:
            A ``ChatResult`` with one ``ChatGeneration`` per choice.
        """
        generations = [
            ChatGeneration(
                message=_convert_dict_to_message(choice["message"]),
                generation_info={"finish_reason": choice.get("finish_reason")},
            )
            for choice in response["choices"]
        ]
        return ChatResult(generations=generations)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async generation is not supported; always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @property
    def _client_params(self) -> Mapping[str, Any]:
        """Get the parameters used for the client."""
        return {}

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "llama-api"