langchain/libs/langchain/tests/unit_tests/chat_models
Harutaka Kawamura 0d08a692a3
langchain[minor]: Migrate mlflow and databricks classes to deployments APIs. (#13699)
## Description

Related to https://github.com/mlflow/mlflow/pull/10420. MLflow AI
gateway will be deprecated and replaced by the `mlflow.deployments`
module. Happy to split this PR if it's too large.

```
pip install git+https://github.com/langchain-ai/langchain.git@refs/pull/13699/merge#subdirectory=libs/langchain
```

## Dependencies

Install mlflow from https://github.com/mlflow/mlflow/pull/10420:

```
pip install git+https://github.com/mlflow/mlflow.git@refs/pull/10420/merge
```

## Testing plan

The following code works fine on local and databricks:

<details><summary>Click</summary>
<p>

```python
"""
Setup
-----
mlflow deployments start-server --config-path examples/gateway/openai/config.yaml
databricks secrets create-scope <scope>
databricks secrets put-secret <scope> openai-api-key --string-value $OPENAI_API_KEY

Run
---
python /path/to/this/file.py secrets/<scope>/openai-api-key
"""
from langchain.chat_models import ChatMlflow, ChatDatabricks
from langchain.embeddings import MlflowEmbeddings, DatabricksEmbeddings
from langchain.llms import Databricks, Mlflow
from langchain.schema.messages import HumanMessage
from langchain.chains.loading import load_chain
from mlflow.deployments import get_deploy_client
import uuid
import sys
import tempfile
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

###############################
# MLflow
###############################
chat = ChatMlflow(
    target_uri="http://127.0.0.1:5000", endpoint="chat", params={"temperature": 0.1}
)
print(chat([HumanMessage(content="hello")]))

embeddings = MlflowEmbeddings(target_uri="http://127.0.0.1:5000", endpoint="embeddings")
print(embeddings.embed_query("hello")[:3])
print(embeddings.embed_documents(["hello", "world"])[0][:3])

llm = Mlflow(
    target_uri="http://127.0.0.1:5000",
    endpoint="completions",
    params={"temperature": 0.1},
)
print(llm("I am"))

llm_chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate(
        input_variables=["adjective"],
        template="Tell me a {adjective} joke",
    ),
)
print(llm_chain.run(adjective="funny"))

# serialization/deserialization
with tempfile.TemporaryDirectory() as tmpdir:
    print(tmpdir)
    path = f"{tmpdir}/llm.yaml"
    llm_chain.save(path)
    loaded_chain = load_chain(path)
    print(loaded_chain("funny"))

###############################
# Databricks
###############################
secret = sys.argv[1]
client = get_deploy_client("databricks")

# External - chat
name = f"chat-{uuid.uuid4()}"
client.create_endpoint(
    name=name,
    config={
        "served_entities": [
            {
                "name": "test",
                "external_model": {
                    "name": "gpt-4",
                    "provider": "openai",
                    "task": "llm/v1/chat",
                    "openai_config": {
                        "openai_api_key": "{{" + secret + "}}",
                    },
                },
            }
        ],
    },
)
try:
    chat = ChatDatabricks(
        target_uri="databricks", endpoint=name, params={"temperature": 0.1}
    )
    print(chat([HumanMessage(content="hello")]))
finally:
    client.delete_endpoint(endpoint=name)

# External - embeddings
name = f"embeddings-{uuid.uuid4()}"
client.create_endpoint(
    name=name,
    config={
        "served_entities": [
            {
                "name": "test",
                "external_model": {
                    "name": "text-embedding-ada-002",
                    "provider": "openai",
                    "task": "llm/v1/embeddings",
                    "openai_config": {
                        "openai_api_key": "{{" + secret + "}}",
                    },
                },
            }
        ],
    },
)
try:
    embeddings = DatabricksEmbeddings(target_uri="databricks", endpoint=name)
    print(embeddings.embed_query("hello")[:3])
    print(embeddings.embed_documents(["hello", "world"])[0][:3])
finally:
    client.delete_endpoint(endpoint=name)

# External - completions
name = f"completions-{uuid.uuid4()}"
client.create_endpoint(
    name=name,
    config={
        "served_entities": [
            {
                "name": "test",
                "external_model": {
                    "name": "gpt-3.5-turbo-instruct",
                    "provider": "openai",
                    "task": "llm/v1/completions",
                    "openai_config": {
                        "openai_api_key": "{{" + secret + "}}",
                    },
                },
            }
        ],
    },
)
try:
    llm = Databricks(
        endpoint_name=name,
        model_kwargs={"temperature": 0.1},
    )
    print(llm("I am"))
finally:
    client.delete_endpoint(endpoint=name)


# Foundation model - chat
chat = ChatDatabricks(
    endpoint="databricks-llama-2-70b-chat", params={"temperature": 0.1}
)
print(chat([HumanMessage(content="hello")]))

# Foundation model - embeddings
embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
print(embeddings.embed_query("hello")[:3])

# Foundation model - completions
llm = Databricks(
    endpoint_name="databricks-mpt-7b-instruct", model_kwargs={"temperature": 0.1}
)
print(llm("hello"))
llm_chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate(
        input_variables=["adjective"],
        template="Tell me a {adjective} joke",
    ),
)
print(llm_chain.run(adjective="funny"))

# serialization/deserialization
with tempfile.TemporaryDirectory() as tmpdir:
    print(tmpdir)
    path = f"{tmpdir}/llm.yaml"
    llm_chain.save(path)
    loaded_chain = load_chain(path)
    print(loaded_chain("funny"))

```

Output:

```
content='Hello! How can I assist you today?'
[-0.025058426, -0.01938856, -0.027781019]
[-0.025058426, -0.01938856, -0.027781019]
sorry, but I cannot continue the sentence as it is incomplete. Can you please provide more information or context?
Sure, here's a classic one for you:

Why don't scientists trust atoms?

Because they make up everything!
/var/folders/dz/cd_nvlf14g9g__n3ph0d_0pm0000gp/T/tmpx_4no6ad
{'adjective': 'funny', 'text': "Sure, here's a classic one for you:\n\nWhy don't scientists trust atoms?\n\nBecause they make up everything!"}
content='Hello! How can I assist you today?'
[-0.025058426, -0.01938856, -0.027781019]
[-0.025058426, -0.01938856, -0.027781019]
 a 23 year old female and I am currently studying for my master's degree
content="\nHello! It's nice to meet you. Is there something I can help you with or would you like to chat for a bit?"
[0.051055908203125, 0.007221221923828125, 0.003879547119140625]
[0.051055908203125, 0.007221221923828125, 0.003879547119140625]

hello back
 Well, I don't really know many jokes, but I do know this funny story...
/var/folders/dz/cd_nvlf14g9g__n3ph0d_0pm0000gp/T/tmp7_ds72ex
{'adjective': 'funny', 'text': " Well, I don't really know many jokes, but I do know this funny story..."}
```

</p>
</details>

The existing workflow doesn't break:

<details><summary>click</summary>
<p>

```python
import uuid

import mlflow
from mlflow.models import ModelSignature
from mlflow.types.schema import ColSpec, Schema


class MyModel(mlflow.pyfunc.PythonModel):
    def predict(self, context, model_input):
        return str(uuid.uuid4())


with mlflow.start_run():
    mlflow.pyfunc.log_model(
        "model",
        python_model=MyModel(),
        pip_requirements=["mlflow==2.8.1", "cloudpickle<3"],
        signature=ModelSignature(
            inputs=Schema(
                [
                    ColSpec("string", "prompt"),
                    ColSpec("string", "stop"),
                ]
            ),
            outputs=Schema(
                [
                    ColSpec(name=None, type="string"),
                ]
            ),
        ),
        registered_model_name=f"lang-{uuid.uuid4()}",
    )

# Manually create a serving endpoint with the registered model and run
from langchain.llms import Databricks

llm = Databricks(endpoint_name="<name>")
llm("hello")  # 9d0b2491-3d13-487c-bc02-1287f06ecae7
```

</p>
</details> 

## Follow-up tasks

(This PR is too large. I'll file a separate one for follow-up tasks.)

- Update `docs/docs/integrations/providers/mlflow_ai_gateway.mdx` and
`docs/docs/integrations/providers/databricks.md`.

---------

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
2023-11-30 15:06:58 -08:00
..
__init__.py (WIP) set up experimental (#7959) 2023-07-21 09:20:24 -07:00
test_anthropic.py REFACTOR: Refactor langchain_core (#13627) 2023-11-21 08:35:29 -08:00
test_azureml_endpoint.py Separate out langchain_core package (#13577) 2023-11-20 13:09:30 -08:00
test_azureopenai.py ChatOpenAI and AzureChatOpenAI openai>=1 compatible (#12948) 2023-11-06 13:24:18 -08:00
test_baichuan.py secretStr for baichuan chat model api key (#13946) 2023-11-28 21:20:23 -05:00
test_base.py langchain[patch], core[patch]: Make common utils public (#13932) 2023-11-27 15:34:46 -08:00
test_bedrock.py BUGFIX: anthropic models on bedrock (#13629) 2023-11-21 17:40:29 -08:00
test_ernie.py REFACTOR: Refactor langchain_core (#13627) 2023-11-21 08:35:29 -08:00
test_fireworks.py Separate out langchain_core package (#13577) 2023-11-20 13:09:30 -08:00
test_google_palm.py REFACTOR: Refactor langchain_core (#13627) 2023-11-21 08:35:29 -08:00
test_hunyuan.py REFACTOR: Refactor langchain_core (#13627) 2023-11-21 08:35:29 -08:00
test_imports.py langchain[minor]: Migrate mlflow and databricks classes to deployments APIs. (#13699) 2023-11-30 15:06:58 -08:00
test_javelin_ai_gateway.py Separate out langchain_core package (#13577) 2023-11-20 13:09:30 -08:00
test_openai.py REFACTOR: Refactor langchain_core (#13627) 2023-11-21 08:35:29 -08:00