Bagatur/gateway chat (#8198)

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Co-authored-by: dbczumar <corey.zumar@databricks.com>
Bagatur 2023-07-24 12:17:00 -07:00 committed by GitHub
parent ae28568e2a
commit 1a7d8667c8
5 changed files with 266 additions and 2 deletions


@@ -90,6 +90,31 @@ print(embeddings.embed_query("hello"))
print(embeddings.embed_documents(["hello"]))
```
## Chat Example

```python
from langchain.chat_models import ChatMLflowAIGateway
from langchain.schema import HumanMessage, SystemMessage

chat = ChatMLflowAIGateway(
    gateway_uri="http://127.0.0.1:5000",
    route="chat",
    params={
        "temperature": 0.1
    }
)

messages = [
    SystemMessage(
        content="You are a helpful assistant that translates English to French."
    ),
    HumanMessage(
        content="Translate this sentence from English to French: I love programming."
    ),
]
print(chat(messages))
```
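
Since `ChatMLflowAIGateway` also implements `_agenerate`, the model can be driven asynchronously through LangChain's standard `agenerate` API; the blocking gateway call runs in a thread executor. A minimal sketch, reusing `chat` and `messages` from the example above:

```python
import asyncio

async def main() -> None:
    # `agenerate` takes a batch of message lists and returns an LLMResult.
    result = await chat.agenerate([messages])
    print(result.generations[0][0].message.content)

asyncio.run(main())
```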
## Databricks MLflow AI Gateway

Databricks MLflow AI Gateway is in private preview.


@@ -4,6 +4,7 @@ from langchain.chat_models.fake import FakeListChatModel
from langchain.chat_models.google_palm import ChatGooglePalm
from langchain.chat_models.human import HumanInputChatModel
from langchain.chat_models.jinachat import JinaChat
from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
from langchain.chat_models.openai import ChatOpenAI
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
from langchain.chat_models.vertexai import ChatVertexAI
@@ -15,6 +16,7 @@ __all__ = [
    "PromptLayerChatOpenAI",
    "ChatAnthropic",
    "ChatGooglePalm",
    "ChatMLflowAIGateway",
    "ChatVertexAI",
    "JinaChat",
    "HumanInputChatModel",


@@ -0,0 +1,204 @@
import asyncio
import logging
from functools import partial
from typing import Any, Dict, List, Mapping, Optional

from pydantic import BaseModel, Extra

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
    ChatGeneration,
    ChatResult,
)
from langchain.schema.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)

logger = logging.getLogger(__name__)


class ChatParams(BaseModel, extra=Extra.allow):
    """Parameters for the MLflow AI Gateway LLM."""

    temperature: float = 0.0
    candidate_count: int = 1
    """The number of candidates to return."""
    stop: Optional[List[str]] = None
    max_tokens: Optional[int] = None


class ChatMLflowAIGateway(BaseChatModel):
    """
    Wrapper around chat LLMs in the MLflow AI Gateway.

    To use, you should have the ``mlflow[gateway]`` python package installed.
    For more information, see https://mlflow.org/docs/latest/gateway/index.html.

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatMLflowAIGateway

            chat = ChatMLflowAIGateway(
                gateway_uri="<your-mlflow-ai-gateway-uri>",
                route="<your-mlflow-ai-gateway-chat-route>",
                params={
                    "temperature": 0.1
                }
            )
    """

    def __init__(self, **kwargs: Any):
        try:
            import mlflow.gateway
        except ImportError as e:
            raise ImportError(
                "Could not import `mlflow.gateway` module. "
                "Please install it with `pip install mlflow[gateway]`."
            ) from e

        super().__init__(**kwargs)
        if self.gateway_uri:
            mlflow.gateway.set_gateway_uri(self.gateway_uri)

    route: str
    gateway_uri: Optional[str] = None
    params: Optional[ChatParams] = None

    @property
    def _default_params(self) -> Dict[str, Any]:
        params: Dict[str, Any] = {
            "gateway_uri": self.gateway_uri,
            "route": self.route,
            **(self.params.dict() if self.params else {}),
        }
        return params

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        try:
            import mlflow.gateway
        except ImportError as e:
            raise ImportError(
                "Could not import `mlflow.gateway` module. "
                "Please install it with `pip install mlflow[gateway]`."
            ) from e

        message_dicts = [
            ChatMLflowAIGateway._convert_message_to_dict(message)
            for message in messages
        ]
        data: Dict[str, Any] = {
            "messages": message_dicts,
            **(self.params.dict() if self.params else {}),
        }
        resp = mlflow.gateway.query(self.route, data=data)
        return ChatMLflowAIGateway._create_chat_result(resp)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # The gateway client is synchronous, so run the blocking `_generate`
        # call in the default thread executor.
        func = partial(
            self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
        )
        return await asyncio.get_event_loop().run_in_executor(None, func)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return self._default_params

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model FOR THE CALLBACKS."""
        return {
            **self._default_params,
            **super()._get_invocation_params(stop=stop, **kwargs),
        }

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "mlflow-ai-gateway-chat"

    @staticmethod
    def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
        role = _dict["role"]
        content = _dict["content"]
        if role == "user":
            return HumanMessage(content=content)
        elif role == "assistant":
            return AIMessage(content=content)
        elif role == "system":
            return SystemMessage(content=content)
        else:
            return ChatMessage(content=content, role=role)

    @staticmethod
    def _raise_functions_not_supported() -> None:
        raise ValueError(
            "Function messages are not supported by the MLflow AI Gateway. Please"
            " create a feature request at https://github.com/mlflow/mlflow/issues."
        )

    @staticmethod
    def _convert_message_to_dict(message: BaseMessage) -> dict:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, FunctionMessage):
            raise ValueError(
                "Function messages are not supported by the MLflow AI Gateway. Please"
                " create a feature request at https://github.com/mlflow/mlflow/issues."
            )
        else:
            raise ValueError(f"Got unknown message type: {message}")

        if "function_call" in message.additional_kwargs:
            ChatMLflowAIGateway._raise_functions_not_supported()
        if message.additional_kwargs:
            logger.warning(
                "Additional message arguments are unsupported by MLflow AI Gateway"
                " and will be ignored: %s",
                message.additional_kwargs,
            )
        return message_dict

    @staticmethod
    def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
        generations = []
        for candidate in response["candidates"]:
            message = ChatMLflowAIGateway._convert_dict_to_message(
                candidate["message"]
            )
            message_metadata = candidate.get("metadata", {})
            gen = ChatGeneration(
                message=message,
                generation_info=dict(message_metadata),
            )
            generations.append(gen)

        response_metadata = response.get("metadata", {})
        return ChatResult(generations=generations, llm_output=response_metadata)
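
For orientation, `_create_chat_result` expects a mapping with a `candidates` list, where each candidate holds an OpenAI-style `message` dict and optional `metadata`. The payload below is illustrative only (the `finish_reason` field is an assumption, not a documented gateway response):

```python
response = {
    "candidates": [
        {
            "message": {"role": "assistant", "content": "J'adore la programmation."},
            "metadata": {"finish_reason": "stop"},  # assumed field, for illustration
        }
    ],
    "metadata": {},
}

result = ChatMLflowAIGateway._create_chat_result(response)
print(result.generations[0].message.content)  # J'adore la programmation.
```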


@@ -13,7 +13,22 @@ def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
class MlflowAIGatewayEmbeddings(Embeddings, BaseModel):
    """
    Wrapper around embeddings LLMs in the MLflow AI Gateway.

    To use, you should have the ``mlflow[gateway]`` python package installed.
    For more information, see https://mlflow.org/docs/latest/gateway/index.html.

    Example:
        .. code-block:: python

            from langchain.embeddings import MlflowAIGatewayEmbeddings

            embeddings = MlflowAIGatewayEmbeddings(
                gateway_uri="<your-mlflow-ai-gateway-uri>",
                route="<your-mlflow-ai-gateway-embeddings-route>"
            )
    """

    route: str
    """The route to use for the MLflow AI Gateway API."""


@@ -19,7 +19,25 @@ class Params(BaseModel, extra=Extra.allow):
class MlflowAIGateway(LLM):
    """
    Wrapper around completions LLMs in the MLflow AI Gateway.

    To use, you should have the ``mlflow[gateway]`` python package installed.
    For more information, see https://mlflow.org/docs/latest/gateway/index.html.

    Example:
        .. code-block:: python

            from langchain.llms import MlflowAIGateway

            completions = MlflowAIGateway(
                gateway_uri="<your-mlflow-ai-gateway-uri>",
                route="<your-mlflow-ai-gateway-completions-route>",
                params={
                    "temperature": 0.1
                }
            )
    """

    route: str
    gateway_uri: Optional[str] = None
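
A corresponding sketch for the completions wrapper; again the URI and route name are placeholders, and the `completions(...)` call uses the `LLM.__call__` convenience API of this LangChain version:

```python
from langchain.llms import MlflowAIGateway

completions = MlflowAIGateway(
    gateway_uri="http://127.0.0.1:5000",  # placeholder gateway URI
    route="completions",  # placeholder route name
    params={"temperature": 0.1},
)
print(completions("Tell me a joke."))
```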