import logging
from typing import Any, Dict, Iterator, List, Mapping, NoReturn, Optional, cast
from urllib.parse import urlparse

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessage,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
    Field,
    PrivateAttr,
)
from langchain_core.runnables import RunnableConfig

logger = logging.getLogger(__name__)


class ChatMlflow(BaseChatModel):
    """`MLflow` chat models API.

    To use, you should have the `mlflow[genai]` python package installed.
    For more information, see https://mlflow.org/docs/latest/llms/deployments.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatMlflow

            chat = ChatMlflow(
                target_uri="http://localhost:5000",
                endpoint="chat",
                temperature=0.1,
            )
    """

    endpoint: str
    """The endpoint to use."""
    target_uri: str
    """The target URI to use."""
    temperature: float = 0.0
    """The sampling temperature."""
    n: int = 1
    """The number of completion choices to generate."""
    stop: Optional[List[str]] = None
    """The stop sequence."""
    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""
    extra_params: dict = Field(default_factory=dict)
    """Any extra parameters to pass to the endpoint."""
    _client: Any = PrivateAttr()

    def __init__(self, **kwargs: Any):
        """Validate ``target_uri`` and create the MLflow deployments client.

        Raises:
            ValueError: If ``target_uri`` has an unsupported scheme.
            ImportError: If ``mlflow.deployments`` cannot be imported.
        """
        super().__init__(**kwargs)
        self._validate_uri()
        try:
            from mlflow.deployments import get_deploy_client

            self._client = get_deploy_client(self.target_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                f"Please run `pip install mlflow{self._mlflow_extras}` to install "
                "required dependencies."
            ) from e

    @property
    def _mlflow_extras(self) -> str:
        # pip extras suggested in the ImportError message raised by __init__.
        return "[genai]"

    def _validate_uri(self) -> None:
        """Raise ``ValueError`` unless ``target_uri`` is ``"databricks"`` or
        uses one of the http/https/databricks schemes."""
        if self.target_uri == "databricks":
            return
        allowed = ["http", "https", "databricks"]
        if urlparse(self.target_uri).scheme not in allowed:
            raise ValueError(
                f"Invalid target URI: {self.target_uri}. "
                f"The scheme must be one of {allowed}."
            )

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Model configuration reported to callbacks and used for identity."""
        params: Dict[str, Any] = {
            "target_uri": self.target_uri,
            "endpoint": self.endpoint,
            "temperature": self.temperature,
            "n": self.n,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "extra_params": self.extra_params,
        }
        return params

    def _prepare_inputs(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Build the request payload sent to the MLflow endpoint.

        Precedence, lowest to highest: the model's ``temperature``/``n``,
        ``extra_params``, the model's ``max_tokens``/``stop`` fields, and
        finally per-call arguments (``stop`` and ``**kwargs``).

        Args:
            messages: Conversation to send.
            stop: Per-call stop sequences; override the model-level ``stop``.
            **kwargs: Per-call request fields; override every model default.

        Returns:
            The ``inputs`` dict passed to the client's predict methods.
        """
        message_dicts = [
            ChatMlflow._convert_message_to_dict(message) for message in messages
        ]
        data: Dict[str, Any] = {
            "messages": message_dicts,
            "temperature": self.temperature,
            "n": self.n,
            **self.extra_params,
        }
        if self.max_tokens is not None:
            data["max_tokens"] = self.max_tokens
        # A per-call `stop` wins over the model-level default, like other
        # per-call arguments do.
        effective_stop = stop or self.stop
        if effective_stop:
            data["stop"] = effective_stop
        # Applied last so per-call kwargs (e.g. `max_tokens`) are not
        # silently overridden by model-level defaults.
        data.update(kwargs)
        return data

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run a single, non-streaming prediction against the endpoint."""
        data = self._prepare_inputs(
            messages,
            stop,
            **kwargs,
        )
        resp = self._client.predict(endpoint=self.endpoint, inputs=data)
        return ChatMlflow._create_chat_result(resp)

    def stream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Iterator[BaseMessageChunk]:
        """Stream the response, falling back to a single ``invoke`` result
        when the deployment client has no ``predict_stream`` method."""
        # We need to override `stream` to handle the case
        # that `self._client` does not implement `predict_stream`
        if not hasattr(self._client, "predict_stream"):
            # MLflow deployment client does not implement streaming,
            # so use default implementation
            yield cast(
                BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
            )
        else:
            yield from super().stream(input, config, stop=stop, **kwargs)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield one ``ChatGenerationChunk`` per streamed delta.

        Only the first choice of each streamed chunk is used. Chunks whose
        ``choices`` list is empty are skipped.
        """
        data = self._prepare_inputs(
            messages,
            stop,
            **kwargs,
        )
        # NOTE: `stream` guards against clients without `predict_stream`;
        # other entry points reaching `_stream` directly do not.
        chunk_iter = self._client.predict_stream(endpoint=self.endpoint, inputs=data)
        # Streams commonly send `role` only on the first delta; remember it so
        # later deltas that omit it get the same message type.
        first_chunk_role: Optional[str] = None
        for raw_chunk in chunk_iter:
            choices = raw_chunk["choices"]
            if not choices:
                continue
            choice = choices[0]
            delta = choice["delta"]
            if first_chunk_role is None:
                first_chunk_role = delta.get("role")
            message_chunk = ChatMlflow._convert_delta_to_message_chunk(
                delta, first_chunk_role or "assistant"
            )

            generation_info: Dict[str, Any] = {}
            if finish_reason := choice.get("finish_reason"):
                generation_info["finish_reason"] = finish_reason
            logprobs = choice.get("logprobs")
            if logprobs:
                generation_info["logprobs"] = logprobs

            generation_chunk = ChatGenerationChunk(
                message=message_chunk, generation_info=generation_info or None
            )
            if run_manager:
                run_manager.on_llm_new_token(
                    generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                )
            yield generation_chunk

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return self._default_params

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model FOR THE CALLBACKS."""
        return {
            **self._default_params,
            **super()._get_invocation_params(stop=stop, **kwargs),
        }

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "mlflow-chat"

    @staticmethod
    def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
        """Convert a response ``message`` dict into a LangChain message.

        A missing or ``None`` ``content`` is treated as an empty string.
        """
        role = _dict["role"]
        content = _dict.get("content") or ""
        if role == "user":
            return HumanMessage(content=content)
        elif role == "assistant":
            return AIMessage(content=content)
        elif role == "system":
            return SystemMessage(content=content)
        else:
            return ChatMessage(content=content, role=role)

    @staticmethod
    def _convert_delta_to_message_chunk(
        _dict: Mapping[str, Any], default_role: str = "assistant"
    ) -> BaseMessageChunk:
        """Convert a streamed ``delta`` dict into a message chunk.

        Args:
            _dict: The delta; ``role`` and ``content`` may both be absent.
            default_role: Role used when the delta carries none.

        Returns:
            A chunk whose type matches the role; ``None`` or missing content
            becomes ``""``.
        """
        role = _dict.get("role") or default_role
        content = _dict.get("content") or ""
        if role == "user":
            return HumanMessageChunk(content=content)
        elif role == "assistant":
            return AIMessageChunk(content=content)
        elif role == "system":
            return SystemMessageChunk(content=content)
        else:
            return ChatMessageChunk(content=content, role=role)

    @staticmethod
    def _raise_functions_not_supported() -> NoReturn:
        """Always raise ``ValueError``: function calling is unsupported."""
        raise ValueError(
            "Function messages are not supported by Databricks. Please"
            " create a feature request at https://github.com/mlflow/mlflow/issues."
        )

    @staticmethod
    def _convert_message_to_dict(message: BaseMessage) -> dict:
        """Convert a LangChain message into a ``{"role", "content"}`` dict.

        Raises:
            ValueError: For function messages, messages carrying a
                ``function_call``, or unknown message types.
        """
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, FunctionMessage):
            ChatMlflow._raise_functions_not_supported()
        else:
            raise ValueError(f"Got unknown message type: {message}")

        if "function_call" in message.additional_kwargs:
            ChatMlflow._raise_functions_not_supported()
        if message.additional_kwargs:
            logger.warning(
                "Additional message arguments are unsupported by Databricks"
                " and will be ignored: %s",
                message.additional_kwargs,
            )
        return message_dict

    @staticmethod
    def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
        """Build a ``ChatResult`` from a non-streaming endpoint response.

        Each choice's optional ``usage`` becomes that generation's
        ``generation_info``; the top-level ``usage`` becomes ``llm_output``.
        """
        generations = [
            ChatGeneration(
                message=ChatMlflow._convert_dict_to_message(choice["message"]),
                generation_info=choice.get("usage", {}),
            )
            for choice in response["choices"]
        ]
        return ChatResult(
            generations=generations, llm_output=response.get("usage", {})
        )