From b34f1086fe723c283e48bca63e7917600152b263 Mon Sep 17 00:00:00 2001
From: Eun Hye Kim
Date: Wed, 17 Apr 2024 11:17:03 +0900
Subject: [PATCH] community[patch]: Add streaming logic in ChatHuggingFace
 (#18784)

- Add functions (_stream, _astream)
- Connect to _generate and _agenerate

- [x] **PR title**: "community: Add streaming logic in ChatHuggingFace"
- [x] **PR message**:
    - **Description:** Adds the `_stream` and `_astream` functions and
      connects them to `_generate` and `_agenerate`
    - **Issue:** #18782
    - **Dependencies:** none
    - **Twitter handle:** @lunara_x
---
 .../chat_models/huggingface.py | 64 +++++++++++++++++--
 1 file changed, 59 insertions(+), 5 deletions(-)

diff --git a/libs/community/langchain_community/chat_models/huggingface.py b/libs/community/langchain_community/chat_models/huggingface.py
index 598a8ebc49..2bf5b90948 100644
--- a/libs/community/langchain_community/chat_models/huggingface.py
+++ b/libs/community/langchain_community/chat_models/huggingface.py
@@ -1,19 +1,28 @@
 """Hugging Face Chat Wrapper."""
-
-from typing import Any, List, Optional
+from typing import Any, AsyncIterator, Iterator, List, Optional
 
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.language_models.chat_models import (
+    BaseChatModel,
+    agenerate_from_stream,
+    generate_from_stream,
+)
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     HumanMessage,
     SystemMessage,
 )
-from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
+from langchain_core.outputs import (
+    ChatGeneration,
+    ChatGenerationChunk,
+    ChatResult,
+    LLMResult,
+)
 from langchain_core.pydantic_v1 import root_validator
 
 from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
@@ -26,7 +35,8 @@ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."
 
 
 class ChatHuggingFace(BaseChatModel):
-    """Hugging Face LLMs as ChatModels.
+    """
+    Wrapper for using Hugging Face LLMs as ChatModels.
 
     Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
     and `HuggingFaceHub` LLMs.
@@ -44,6 +54,7 @@ class ChatHuggingFace(BaseChatModel):
     system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
     tokenizer: Any = None
     model_id: Optional[str] = None
+    streaming: bool = False
 
     def __init__(self, **kwargs: Any):
         super().__init__(**kwargs)
@@ -70,6 +81,37 @@ class ChatHuggingFace(BaseChatModel):
             )
         return values
 
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        request = self._to_chat_prompt(messages)
+
+        for data in self.llm.stream(request, **kwargs):
+            delta = data
+            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
+            if run_manager:
+                run_manager.on_llm_new_token(delta, chunk=chunk)
+            yield chunk
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        request = self._to_chat_prompt(messages)
+        async for data in self.llm.astream(request, **kwargs):
+            delta = data
+            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
+            if run_manager:
+                await run_manager.on_llm_new_token(delta, chunk=chunk)
+            yield chunk
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -77,6 +119,12 @@
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
    ) -> ChatResult:
+        if self.streaming:
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return generate_from_stream(stream_iter)
+
         llm_input = self._to_chat_prompt(messages)
         llm_result = self.llm._generate(
             prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
@@ -90,6 +138,12 @@
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
+        if self.streaming:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+
         llm_input = self._to_chat_prompt(messages)
         llm_result = await self.llm._agenerate(
             prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
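---

A minimal sketch of the new synchronous streaming path. This is illustrative, not part of the patch: it assumes a `langchain-community` install that includes this change, `transformers` available for the chat template, a valid `HUGGINGFACEHUB_API_TOKEN` in the environment, and the `repo_id` and generation parameters below are placeholders.

```python
from langchain_community.chat_models.huggingface import ChatHuggingFace
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core.messages import HumanMessage

# Placeholder endpoint configuration; any hosted text-generation repo works.
llm = HuggingFaceEndpoint(
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    max_new_tokens=128,
)

# streaming=True makes _generate delegate to _stream via generate_from_stream.
chat = ChatHuggingFace(llm=llm, streaming=True)

messages = [HumanMessage(content="Tell me a joke about tokenizers.")]

# Token-by-token output through the standard Runnable .stream() interface,
# which drives the new _stream method and yields AIMessageChunk contents.
for chunk in chat.stream(messages):
    print(chunk.content, end="", flush=True)
```

With `streaming=True`, a plain `chat.invoke(messages)` still returns a single final `AIMessage`, but it is now assembled from streamed chunks, so callbacks receive `on_llm_new_token` events along the way.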
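The asynchronous path mirrors this through `_astream` and `agenerate_from_stream`; a sketch under the same assumptions:

```python
import asyncio

from langchain_community.chat_models.huggingface import ChatHuggingFace
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core.messages import HumanMessage


async def main() -> None:
    # Same placeholder endpoint as in the synchronous example above.
    llm = HuggingFaceEndpoint(
        repo_id="HuggingFaceH4/zephyr-7b-beta",
        max_new_tokens=128,
    )
    chat = ChatHuggingFace(llm=llm, streaming=True)

    # .astream() drives the new _astream method; each chunk wraps an
    # AIMessageChunk, and async callbacks get awaited on_llm_new_token calls.
    async for chunk in chat.astream([HumanMessage(content="Tell me a joke.")]):
        print(chunk.content, end="", flush=True)


asyncio.run(main())
```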