Mirror of https://github.com/hwchase17/langchain, synced 2024-11-06 03:20:49 +00:00

Bagatur/gateway chat (#8198)

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Co-authored-by: dbczumar <corey.zumar@databricks.com>

parent ae28568e2a
commit 1a7d8667c8
@@ -90,6 +90,31 @@ print(embeddings.embed_query("hello"))
 print(embeddings.embed_documents(["hello"]))
 ```
 
+## Chat Example
+
+```python
+from langchain.chat_models import ChatMLflowAIGateway
+from langchain.schema import HumanMessage, SystemMessage
+
+chat = ChatMLflowAIGateway(
+    gateway_uri="http://127.0.0.1:5000",
+    route="chat",
+    params={
+        "temperature": 0.1
+    }
+)
+
+messages = [
+    SystemMessage(
+        content="You are a helpful assistant that translates English to French."
+    ),
+    HumanMessage(
+        content="Translate this sentence from English to French: I love programming."
+    ),
+]
+print(chat(messages))
+```
+
 ## Databricks MLflow AI Gateway
 
 Databricks MLflow AI Gateway is in private preview.
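The docs example above assumes an MLflow AI Gateway server is already running locally with a chat route named `chat`. The `params` mapping accepts the fields declared on the `ChatParams` model added by this commit, plus provider-specific extras, since the model is declared with `Extra.allow`. A minimal sketch using all four declared fields; the URI, route name, and values are placeholders:

```python
from langchain.chat_models import ChatMLflowAIGateway

chat = ChatMLflowAIGateway(
    gateway_uri="http://127.0.0.1:5000",  # placeholder: your gateway URI
    route="chat",                         # placeholder: your chat route name
    params={
        "temperature": 0.0,
        "candidate_count": 1,  # number of candidate responses to return
        "stop": ["\n\n"],
        "max_tokens": 64,
    },
)
```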
In `libs/langchain/langchain/chat_models/__init__.py`:

@@ -4,6 +4,7 @@ from langchain.chat_models.fake import FakeListChatModel
 from langchain.chat_models.google_palm import ChatGooglePalm
 from langchain.chat_models.human import HumanInputChatModel
 from langchain.chat_models.jinachat import JinaChat
+from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
 from langchain.chat_models.openai import ChatOpenAI
 from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
 from langchain.chat_models.vertexai import ChatVertexAI

@@ -15,6 +16,7 @@ __all__ = [
     "PromptLayerChatOpenAI",
     "ChatAnthropic",
     "ChatGooglePalm",
+    "ChatMLflowAIGateway",
     "ChatVertexAI",
     "JinaChat",
     "HumanInputChatModel",
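With the re-export and `__all__` entry in place, the class is importable from the package root. A quick sanity check; note that importing the module does not require `mlflow` (the dependency check runs at construction time, as the new file below shows):

```python
import langchain.chat_models as chat_models

# The new class is both re-exported and listed in __all__.
assert "ChatMLflowAIGateway" in chat_models.__all__
print(chat_models.ChatMLflowAIGateway)
```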
libs/langchain/langchain/chat_models/mlflow_ai_gateway.py (new file, 204 lines)

@@ -0,0 +1,204 @@
+import asyncio
+import logging
+from functools import partial
+from typing import Any, Dict, List, Mapping, Optional
+
+from pydantic import BaseModel, Extra
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import (
+    ChatGeneration,
+    ChatResult,
+)
+from langchain.schema.messages import (
+    AIMessage,
+    BaseMessage,
+    ChatMessage,
+    FunctionMessage,
+    HumanMessage,
+    SystemMessage,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ChatParams(BaseModel, extra=Extra.allow):
+    """Parameters for the MLflow AI Gateway LLM."""
+
+    temperature: float = 0.0
+    candidate_count: int = 1
+    """The number of candidates to return."""
+    stop: Optional[List[str]] = None
+    max_tokens: Optional[int] = None
+
+
+class ChatMLflowAIGateway(BaseChatModel):
+    """
+    Wrapper around chat LLMs in the MLflow AI Gateway.
+
+    To use, you should have the ``mlflow[gateway]`` python package installed.
+    For more information, see https://mlflow.org/docs/latest/gateway/index.html.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chat_models import ChatMLflowAIGateway
+
+            chat = ChatMLflowAIGateway(
+                gateway_uri="<your-mlflow-ai-gateway-uri>",
+                route="<your-mlflow-ai-gateway-chat-route>",
+                params={
+                    "temperature": 0.1
+                }
+            )
+    """
+
+    def __init__(self, **kwargs: Any):
+        try:
+            import mlflow.gateway
+        except ImportError as e:
+            raise ImportError(
+                "Could not import `mlflow.gateway` module. "
+                "Please install it with `pip install mlflow[gateway]`."
+            ) from e
+
+        super().__init__(**kwargs)
+        if self.gateway_uri:
+            mlflow.gateway.set_gateway_uri(self.gateway_uri)
+
+    route: str
+    gateway_uri: Optional[str] = None
+    params: Optional[ChatParams] = None
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        params: Dict[str, Any] = {
+            "gateway_uri": self.gateway_uri,
+            "route": self.route,
+            **(self.params.dict() if self.params else {}),
+        }
+        return params
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        try:
+            import mlflow.gateway
+        except ImportError as e:
+            raise ImportError(
+                "Could not import `mlflow.gateway` module. "
+                "Please install it with `pip install mlflow[gateway]`."
+            ) from e
+
+        message_dicts = [
+            ChatMLflowAIGateway._convert_message_to_dict(message)
+            for message in messages
+        ]
+        data: Dict[str, Any] = {
+            "messages": message_dicts,
+            **(self.params.dict() if self.params else {}),
+        }
+
+        resp = mlflow.gateway.query(self.route, data=data)
+        return ChatMLflowAIGateway._create_chat_result(resp)
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        func = partial(
+            self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
+        )
+        return await asyncio.get_event_loop().run_in_executor(None, func)
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        return self._default_params
+
+    def _get_invocation_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> Dict[str, Any]:
+        """Get the parameters used to invoke the model FOR THE CALLBACKS."""
+        return {
+            **self._default_params,
+            **super()._get_invocation_params(stop=stop, **kwargs),
+        }
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of chat model."""
+        return "mlflow-ai-gateway-chat"
+
+    @staticmethod
+    def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+        role = _dict["role"]
+        content = _dict["content"]
+        if role == "user":
+            return HumanMessage(content=content)
+        elif role == "assistant":
+            return AIMessage(content=content)
+        elif role == "system":
+            return SystemMessage(content=content)
+        else:
+            return ChatMessage(content=content, role=role)
+
+    @staticmethod
+    def _raise_functions_not_supported() -> None:
+        raise ValueError(
+            "Function messages are not supported by the MLflow AI Gateway. Please"
+            " create a feature request at https://github.com/mlflow/mlflow/issues."
+        )
+
+    @staticmethod
+    def _convert_message_to_dict(message: BaseMessage) -> dict:
+        if isinstance(message, ChatMessage):
+            message_dict = {"role": message.role, "content": message.content}
+        elif isinstance(message, HumanMessage):
+            message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemMessage):
+            message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, FunctionMessage):
+            raise ValueError(
+                "Function messages are not supported by the MLflow AI Gateway. Please"
+                " create a feature request at https://github.com/mlflow/mlflow/issues."
+            )
+        else:
+            raise ValueError(f"Got unknown message type: {message}")
+
+        if "function_call" in message.additional_kwargs:
+            ChatMLflowAIGateway._raise_functions_not_supported()
+        if message.additional_kwargs:
+            logger.warning(
+                "Additional message arguments are unsupported by MLflow AI Gateway "
+                " and will be ignored: %s",
+                message.additional_kwargs,
+            )
+        return message_dict
+
+    @staticmethod
+    def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
+        generations = []
+        for candidate in response["candidates"]:
+            message = ChatMLflowAIGateway._convert_dict_to_message(candidate["message"])
+            message_metadata = candidate.get("metadata", {})
+            gen = ChatGeneration(
+                message=message,
+                generation_info=dict(message_metadata),
+            )
+            generations.append(gen)
+
+        response_metadata = response.get("metadata", {})
+        return ChatResult(generations=generations, llm_output=response_metadata)
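The two static helpers at the end of the file do the request/response translation; here is a small sketch exercising `_convert_message_to_dict` and `_create_chat_result` directly. The response dict is hand-built to mimic the `{"candidates": [...], "metadata": {...}}` shape the code consumes; its values are illustrative, not real gateway output, and no `mlflow` install is needed since these helpers never contact the gateway:

```python
from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
from langchain.schema.messages import ChatMessage, HumanMessage

# Message -> dict conversion used when building the request body.
convert = ChatMLflowAIGateway._convert_message_to_dict
print(convert(HumanMessage(content="hi")))              # {'role': 'user', 'content': 'hi'}
print(convert(ChatMessage(role="tool", content="hi")))  # other roles pass through as-is

# Response -> ChatResult assembly (the values below are made up).
fake_response = {
    "candidates": [
        {
            "message": {"role": "assistant", "content": "J'adore la programmation."},
            "metadata": {"finish_reason": "stop"},
        }
    ],
    "metadata": {"model": "example-chat-model"},
}
result = ChatMLflowAIGateway._create_chat_result(fake_response)
print(result.generations[0].message.content)  # J'adore la programmation.
print(result.llm_output)                      # {'model': 'example-chat-model'}
```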
In the MLflow AI Gateway embeddings module:

@@ -13,7 +13,22 @@ def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
 
 
 class MlflowAIGatewayEmbeddings(Embeddings, BaseModel):
-    """MLflow AI Gateway Embeddings APIs."""
+    """
+    Wrapper around embeddings LLMs in the MLflow AI Gateway.
+
+    To use, you should have the ``mlflow[gateway]`` python package installed.
+    For more information, see https://mlflow.org/docs/latest/gateway/index.html.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.embeddings import MlflowAIGatewayEmbeddings
+
+            embeddings = MlflowAIGatewayEmbeddings(
+                gateway_uri="<your-mlflow-ai-gateway-uri>",
+                route="<your-mlflow-ai-gateway-embeddings-route>"
+            )
+    """
 
     route: str
     """The route to use for the MLflow AI Gateway API."""
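The embeddings wrapper follows the same pattern as the chat class; a minimal usage sketch against a locally running gateway, mirroring the docs example earlier in this commit (the URI and route name are placeholders):

```python
from langchain.embeddings import MlflowAIGatewayEmbeddings

embeddings = MlflowAIGatewayEmbeddings(
    gateway_uri="http://127.0.0.1:5000",  # placeholder: your gateway URI
    route="embeddings",                   # placeholder: your embeddings route name
)
print(embeddings.embed_query("hello"))
print(embeddings.embed_documents(["hello"]))
```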
In the MLflow AI Gateway completions LLM module:

@@ -19,7 +19,25 @@ class Params(BaseModel, extra=Extra.allow):
 
 
 class MlflowAIGateway(LLM):
-    """The MLflow AI Gateway models."""
+    """
+    Wrapper around completions LLMs in the MLflow AI Gateway.
+
+    To use, you should have the ``mlflow[gateway]`` python package installed.
+    For more information, see https://mlflow.org/docs/latest/gateway/index.html.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import MlflowAIGateway
+
+            completions = MlflowAIGateway(
+                gateway_uri="<your-mlflow-ai-gateway-uri>",
+                route="<your-mlflow-ai-gateway-completions-route>",
+                params={
+                    "temperature": 0.1
+                }
+            )
+    """
 
     route: str
     gateway_uri: Optional[str] = None
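And likewise for completions; a minimal sketch with placeholder URI and route name (in this version of LangChain an `LLM` instance can be called directly with a prompt string; the prompt here is illustrative):

```python
from langchain.llms import MlflowAIGateway

completions = MlflowAIGateway(
    gateway_uri="http://127.0.0.1:5000",  # placeholder: your gateway URI
    route="completions",                  # placeholder: your completions route name
    params={"temperature": 0.1},
)
print(completions("Tell me a joke"))
```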