mirror of https://github.com/hwchase17/langchain
Add integration for MLflow AI Gateway (#7113)
<!-- Thank you for contributing to LangChain! Replace this comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @baskaryan - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @baskaryan - Memory: @hwchase17 - Agents / Tools / Toolkits: @hinthornw - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md --> - Adds integration for MLflow AI Gateway (this will be shipped in MLflow 2.5 this week). Manual testing: ```sh # Move to mlflow repo cd /path/to/mlflow # install langchain pip install git+https://github.com/harupy/langchain.git@gateway-integration # launch gateway service mlflow gateway start --config-path examples/gateway/openai/config.yaml # Then, run the examples in this PR ```pull/7939/head
parent
6792a3557d
commit
f6839a8682
@ -0,0 +1,116 @@
|
||||
# MLflow AI Gateway
|
||||
|
||||
The MLflow AI Gateway service is a powerful tool designed to streamline the usage and management of various large language model (LLM) providers, such as OpenAI and Anthropic, within an organization. It offers a high-level interface that simplifies the interaction with these services by providing a unified endpoint to handle specific LLM related requests. See [the MLflow AI Gateway documentation](https://mlflow.org/docs/latest/gateway/index.html) for more details.
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
Install `mlflow` with MLflow AI Gateway dependencies:
|
||||
|
||||
```sh
|
||||
pip install 'mlflow[gateway]'
|
||||
```
|
||||
|
||||
Set the OpenAI API key as an environment variable:
|
||||
|
||||
```sh
|
||||
export OPENAI_API_KEY=...
|
||||
```
|
||||
|
||||
Create a configuration file:
|
||||
|
||||
```yaml
|
||||
routes:
|
||||
- name: completions
|
||||
type: llm/v1/completions
|
||||
model:
|
||||
provider: openai
|
||||
name: text-davinci-003
|
||||
config:
|
||||
openai_api_key: $OPENAI_API_KEY
|
||||
|
||||
- name: embeddings
|
||||
type: llm/v1/embeddings
|
||||
model:
|
||||
provider: openai
|
||||
name: text-embedding-ada-002
|
||||
config:
|
||||
openai_api_key: $OPENAI_API_KEY
|
||||
```
|
||||
|
||||
Start the Gateway server:
|
||||
|
||||
```sh
|
||||
mlflow gateway start --config-path /path/to/config.yaml
|
||||
```
|
||||
|
||||
## Completions Example
|
||||
|
||||
```python
|
||||
import mlflow
|
||||
from langchain import LLMChain, PromptTemplate
|
||||
from langchain.llms import MlflowAIGateway
|
||||
|
||||
gateway = MlflowAIGateway(
|
||||
gateway_uri="http://127.0.0.1:5000",
|
||||
route="completions",
|
||||
params={
|
||||
"temperature": 0.0,
|
||||
"top_p": 0.1,
|
||||
},
|
||||
)
|
||||
|
||||
llm_chain = LLMChain(
|
||||
llm=gateway,
|
||||
prompt=PromptTemplate(
|
||||
input_variables=["adjective"],
|
||||
template="Tell me a {adjective} joke",
|
||||
),
|
||||
)
|
||||
result = llm_chain.run(adjective="funny")
|
||||
print(result)
|
||||
|
||||
with mlflow.start_run():
|
||||
model_info = mlflow.langchain.log_model(llm_chain, "model")
|
||||
|
||||
model = mlflow.pyfunc.load_model(model_info.model_uri)
|
||||
print(model.predict([{"adjective": "funny"}]))
|
||||
```
|
||||
|
||||
## Embeddings Example
|
||||
|
||||
```python
|
||||
from langchain.embeddings import MlflowAIGatewayEmbeddings
|
||||
|
||||
embeddings = MlflowAIGatewayEmbeddings(
|
||||
gateway_uri="http://127.0.0.1:5000",
|
||||
route="embeddings",
|
||||
)
|
||||
|
||||
print(embeddings.embed_query("hello"))
|
||||
print(embeddings.embed_documents(["hello"]))
|
||||
```
|
||||
|
||||
## Databricks MLflow AI Gateway
|
||||
|
||||
Databricks MLflow AI Gateway is in private preview.
|
||||
Please contact a Databricks representative to enroll in the preview.
|
||||
|
||||
```python
|
||||
from langchain import LLMChain, PromptTemplate
|
||||
from langchain.llms import MlflowAIGateway
|
||||
|
||||
gateway = MlflowAIGateway(
|
||||
gateway_uri="databricks",
|
||||
route="completions",
|
||||
)
|
||||
|
||||
llm_chain = LLMChain(
|
||||
llm=gateway,
|
||||
prompt=PromptTemplate(
|
||||
input_variables=["adjective"],
|
||||
template="Tell me a {adjective} joke",
|
||||
),
|
||||
)
|
||||
result = llm_chain.run(adjective="funny")
|
||||
print(result)
|
||||
```
|
@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterator, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
|
||||
for i in range(0, len(texts), size):
|
||||
yield texts[i : i + size]
|
||||
|
||||
|
||||
class MlflowAIGatewayEmbeddings(Embeddings, BaseModel):
    """Embeddings client for an MLflow AI Gateway embeddings route.

    Texts are sent to the configured gateway route in batches of 20; the
    gateway is expected to return one embedding vector per input text.
    """

    # Name of the gateway route to query (an llm/v1/embeddings route).
    route: str
    # Optional gateway URI; when set, it is registered globally via
    # ``mlflow.gateway.set_gateway_uri``.
    gateway_uri: Optional[str] = None

    def __init__(self, **kwargs: Any):
        # Fail fast with an actionable message if mlflow[gateway] is missing.
        try:
            import mlflow.gateway
        except ImportError as e:
            raise ImportError(
                "Could not import `mlflow.gateway` module. "
                "Please install it with `pip install mlflow[gateway]`."
            ) from e

        super().__init__(**kwargs)
        if self.gateway_uri:
            mlflow.gateway.set_gateway_uri(self.gateway_uri)

    def _query(self, texts: List[str]) -> List[List[float]]:
        """Query the gateway route and return one embedding per input text."""
        try:
            import mlflow.gateway
        except ImportError as e:
            raise ImportError(
                "Could not import `mlflow.gateway` module. "
                "Please install it with `pip install mlflow[gateway]`."
            ) from e

        embeddings: List[List[float]] = []
        # Query in chunks of 20 texts per request to keep payloads small.
        for txt in _chunk(texts, 20):
            resp = mlflow.gateway.query(self.route, data={"text": txt})
            # BUGFIX: use `extend`, not `append`. The response carries a list
            # of embeddings (one per text in the chunk); appending would nest
            # each chunk's list, breaking the one-vector-per-text contract
            # that `embed_query` relies on via `self._query([text])[0]`.
            embeddings.extend(resp["embeddings"])
        return embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per document in ``texts``."""
        return self._query(texts)

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding vector for a single query string."""
        return self._query([text])[0]
|
@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
|
||||
class Params(BaseModel, extra=Extra.allow):
    """Sampling parameters sent with each gateway completions request.

    ``Extra.allow`` permits arbitrary additional fields, so
    provider-specific parameters can be passed through to the route
    unchanged.
    """

    # Sampling temperature; 0.0 selects deterministic output.
    temperature: float = 0.0
    # Number of completion candidates to request.
    candidate_count: int = 1
    # Optional stop sequences; None means no explicit stop.
    stop: Optional[List[str]] = None
    # Optional completion length cap; None defers to the provider default.
    max_tokens: Optional[int] = None
|
||||
|
||||
|
||||
class MlflowAIGateway(LLM):
    """LLM wrapper that queries a completions route on an MLflow AI Gateway.

    Requires ``mlflow[gateway]`` to be installed; the gateway server is
    queried via ``mlflow.gateway.query``.
    """

    # Name of the gateway route to query (an llm/v1/completions route).
    route: str
    # Optional gateway URI; when set, it is registered globally via
    # ``mlflow.gateway.set_gateway_uri``.
    gateway_uri: Optional[str] = None
    # Optional sampling parameters forwarded with every request.
    params: Optional[Params] = None

    def __init__(self, **kwargs: Any):
        # Fail fast with an actionable message if mlflow[gateway] is missing.
        try:
            import mlflow.gateway
        except ImportError as e:
            raise ImportError(
                "Could not import `mlflow.gateway` module. "
                "Please install it with `pip install mlflow[gateway]`."
            ) from e

        super().__init__(**kwargs)
        if self.gateway_uri:
            mlflow.gateway.set_gateway_uri(self.gateway_uri)

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Gateway target plus any configured sampling parameters."""
        params: Dict[str, Any] = {
            "gateway_uri": self.gateway_uri,
            "route": self.route,
            **(self.params.dict() if self.params else {}),
        }
        return params

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Parameters that identify this LLM instance (used for caching)."""
        return self._default_params

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to the gateway route and return the first candidate's text.

        The ``stop`` argument, when given, overrides ``self.params.stop``.
        """
        # Import again at call time so a pickled/restored instance still gets
        # a clear error if the dependency is absent.
        try:
            import mlflow.gateway
        except ImportError as e:
            raise ImportError(
                "Could not import `mlflow.gateway` module. "
                "Please install it with `pip install mlflow[gateway]`."
            ) from e

        data: Dict[str, Any] = {
            "prompt": prompt,
            **(self.params.dict() if self.params else {}),
        }
        # Caller-supplied stop sequences take precedence over configured ones.
        # NOTE(review): when both are falsy, `data` may still carry
        # "stop": None spread in from params.dict() above — confirm the
        # gateway accepts a null stop field.
        if s := (stop or (self.params.stop if self.params else None)):
            data["stop"] = s
        resp = mlflow.gateway.query(self.route, data=data)
        return resp["candidates"][0]["text"]

    @property
    def _llm_type(self) -> str:
        """Unique LLM-type tag used by LangChain serialization."""
        return "mlflow-ai-gateway"
|
Loading…
Reference in New Issue