community[patch]: Make ChatDatabricks model support streaming responses (#19912)

**Description:** Make the ChatDatabricks model support streaming responses.
**Issue:** N/A
**Dependencies:** MLflow nightly build (the next MLflow version will be released soon)
**Twitter handle:** N/A

Manual test:

(Before testing, install the MLflow nightly build: `pip install git+https://github.com/mlflow/mlflow.git`)

```python
# Test Databricks Foundation LLM model
from langchain.chat_models import ChatDatabricks
from langchain_core.messages import AIMessageChunk  # streamed chunks are AIMessageChunk objects

chat_model = ChatDatabricks(
    endpoint="databricks-llama-2-70b-chat",
    max_tokens=500,
)

for chunk in chat_model.stream("What is mlflow?"):
    print(chunk.content, end="|")
```
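
If the installed MLflow deployments client does not expose `predict_stream`, `stream()` falls back to returning the whole response as a single chunk. A quick way to check which path you are on (a hypothetical snippet, not part of this PR; it assumes the `databricks` deployments target is configured):

```python
# Hypothetical check, not part of this PR: without `predict_stream` on the
# MLflow deployments client, ChatDatabricks.stream() falls back to a single
# non-streamed chunk.
from mlflow.deployments import get_deploy_client

client = get_deploy_client("databricks")
print(hasattr(client, "predict_stream"))
```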

- [x] **Add tests and docs**: If you're adding a new integration, please include
  1. a test for the integration, preferably unit tests that do not rely on network access, and
  2. an example notebook showing its use. It lives in the `docs/docs/integrations` directory.


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, hwchase17.

---------

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>

```diff
@@ -19,9 +19,16 @@ class ChatDatabricks(ChatMlflow):
 
             chat = ChatDatabricks(
                 target_uri="databricks",
-                endpoint="chat",
+                endpoint="databricks-llama-2-70b-chat",
                 temperature=0.1,
             )
+
+            # single input invocation
+            print(chat.invoke("What is MLflow?").content)
+
+            # single input invocation with streaming response
+            for chunk in chat.stream("What is MLflow?"):
+                print(chunk.content, end="|")
     """
 
     target_uri: str = "databricks"
```

```diff
@@ -1,24 +1,29 @@
 import logging
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional, cast
 from urllib.parse import urlparse
 
-from langchain_core.callbacks import (
-    CallbackManagerForLLMRun,
-)
+from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import BaseChatModel
+from langchain_core.language_models.base import LanguageModelInput
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
+    BaseMessageChunk,
     ChatMessage,
+    ChatMessageChunk,
     FunctionMessage,
     HumanMessage,
+    HumanMessageChunk,
     SystemMessage,
+    SystemMessageChunk,
 )
-from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import (
     Field,
     PrivateAttr,
 )
+from langchain_core.runnables import RunnableConfig
 
 logger = logging.getLogger(__name__)
 
```
```diff
@@ -98,13 +103,12 @@ class ChatMlflow(BaseChatModel):
         }
         return params
 
-    def _generate(
+    def _prepare_inputs(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> Dict[str, Any]:
         message_dicts = [
             ChatMlflow._convert_message_to_dict(message) for message in messages
         ]
```
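
`_prepare_inputs` now builds the request payload that both the blocking and the streaming call paths share. Illustratively (shape inferred from the hunks above and below; sampling parameters such as `temperature` come from the model's default params and are not shown in this diff), the payload for the manual test looks roughly like:

```python
# Illustrative only: approximate payload returned by _prepare_inputs.
# "messages" and "max_tokens" are visible in the diff ("stop" is added only
# when passed); other keys are inferred and may differ.
payload = {
    "messages": [{"role": "user", "content": "What is mlflow?"}],
    "temperature": 0.0,
    "max_tokens": 500,
}
```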
```diff
@@ -119,9 +123,76 @@ class ChatMlflow(BaseChatModel):
             data["stop"] = stop
         if self.max_tokens is not None:
             data["max_tokens"] = self.max_tokens
+        return data
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        data = self._prepare_inputs(
+            messages,
+            stop,
+            **kwargs,
+        )
         resp = self._client.predict(endpoint=self.endpoint, inputs=data)
         return ChatMlflow._create_chat_result(resp)
 
+    def stream(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Iterator[BaseMessageChunk]:
+        # We need to override `stream` to handle the case
+        # that `self._client` does not implement `predict_stream`
+        if not hasattr(self._client, "predict_stream"):
+            # MLflow deployment client does not implement streaming,
+            # so use default implementation
+            yield cast(
+                BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
+            )
+        else:
+            yield from super().stream(input, config, stop=stop, **kwargs)
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        data = self._prepare_inputs(
+            messages,
+            stop,
+            **kwargs,
+        )
+        # TODO: check if `_client.predict_stream` is available.
+        chunk_iter = self._client.predict_stream(endpoint=self.endpoint, inputs=data)
+        for chunk in chunk_iter:
+            choice = chunk["choices"][0]
+
+            chunk = ChatMlflow._convert_delta_to_message_chunk(choice["delta"])
+
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            if logprobs := choice.get("logprobs"):
+                generation_info["logprobs"] = logprobs
+
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
+
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
+
+            yield chunk
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         return self._default_params
```
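
`_stream` expects `predict_stream` to yield OpenAI-style chat-completion chunks, i.e. dicts with `choices[0]["delta"]` and an optional `finish_reason`. A minimal fake client in that shape (hypothetical, not part of this PR) is enough to exercise `_stream` in the network-free unit tests the checklist asks for:

```python
# Hypothetical fake deployments client for unit tests: `predict_stream` yields
# chunks in the exact shape `_stream` reads above.
from typing import Any, Dict, Iterator


class FakeStreamingDeployClient:
    def predict_stream(
        self, endpoint: str, inputs: Dict[str, Any]
    ) -> Iterator[Dict[str, Any]]:
        for text, finish in [("MLflow is ", None), ("an ML platform.", "stop")]:
            yield {
                "choices": [
                    {
                        "delta": {"role": "assistant", "content": text},
                        "finish_reason": finish,
                    }
                ]
            }
```

With `ChatMlflow._client` monkeypatched to such a fake, the streaming path can be tested without any network access.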
```diff
@@ -153,6 +224,19 @@ class ChatMlflow(BaseChatModel):
         else:
             return ChatMessage(content=content, role=role)
 
+    @staticmethod
+    def _convert_delta_to_message_chunk(_dict: Mapping[str, Any]) -> BaseMessageChunk:
+        role = _dict["role"]
+        content = _dict["content"]
+        if role == "user":
+            return HumanMessageChunk(content=content)
+        elif role == "assistant":
+            return AIMessageChunk(content=content)
+        elif role == "system":
+            return SystemMessageChunk(content=content)
+        else:
+            return ChatMessageChunk(content=content, role=role)
+
     @staticmethod
     def _raise_functions_not_supported() -> None:
         raise ValueError(
```
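
The new chunk message types matter because, unlike plain messages, LangChain message chunks can be concatenated with `+`, which is how a streamed answer is aggregated back into a single message:

```python
# Message chunks support `+`, so streamed deltas can be accumulated.
from langchain_core.messages import AIMessageChunk

merged = AIMessageChunk(content="MLflow is ")
merged += AIMessageChunk(content="an ML platform.")
print(merged.content)  # -> "MLflow is an ML platform."
```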
