community[patch]: Add streaming logic in ChatHuggingFace (#18784)

- Add `_stream` and `_astream` methods
- Connect them to `_generate` and `_agenerate`


- [x] **PR title**: "community: Add streaming logic in ChatHuggingFace"

- [x] **PR message**:
    - **Description:** Add `_stream` and `_astream` methods and connect them to `_generate` and `_agenerate`
    - **Issue:** #18782
    - **Dependencies:** none
    - **Twitter handle:** @lunara_x
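
A minimal usage sketch of the new behavior (the repo id is a placeholder and a Hugging Face Hub token is assumed to be configured; nothing here is pinned down by the PR itself):

```python
from langchain_community.chat_models.huggingface import ChatHuggingFace
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint

# Placeholder model; any text-generation repo served by HF inference works.
llm = HuggingFaceEndpoint(
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    max_new_tokens=256,
)
chat = ChatHuggingFace(llm=llm, streaming=True)

# .stream() now yields real token-by-token chunks, since _stream is defined.
for chunk in chat.stream("What is the capital of France?"):
    print(chunk.content, end="", flush=True)

# With streaming=True, plain invoke()/generate() also route through _stream
# (via generate_from_stream), so callback handlers see tokens as they arrive.
print(chat.invoke("What is the capital of France?").content)
```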
PR #20059 · commit b34f1086fe (parent c05c379b26)
Authored by Eun Hye Kim, committed by GitHub

```diff
@@ -1,19 +1,28 @@
 """Hugging Face Chat Wrapper."""
-from typing import Any, List, Optional
+from typing import Any, AsyncIterator, Iterator, List, Optional
 
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.language_models.chat_models import (
+    BaseChatModel,
+    agenerate_from_stream,
+    generate_from_stream,
+)
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     HumanMessage,
     SystemMessage,
 )
-from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
+from langchain_core.outputs import (
+    ChatGeneration,
+    ChatGenerationChunk,
+    ChatResult,
+    LLMResult,
+)
 from langchain_core.pydantic_v1 import root_validator
 
 from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
```
```diff
@@ -26,7 +35,8 @@ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."
 
 class ChatHuggingFace(BaseChatModel):
-    """Hugging Face LLMs as ChatModels.
+    """
+    Wrapper for using Hugging Face LLM's as ChatModels.
 
     Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
     and `HuggingFaceHub` LLMs.
```
```diff
@@ -44,6 +54,7 @@ class ChatHuggingFace(BaseChatModel):
     system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
     tokenizer: Any = None
     model_id: Optional[str] = None
+    streaming: bool = False
 
     def __init__(self, **kwargs: Any):
         super().__init__(**kwargs)
```
```diff
@@ -70,6 +81,37 @@ class ChatHuggingFace(BaseChatModel):
         )
 
         return values
 
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        request = self._to_chat_prompt(messages)
+
+        for data in self.llm.stream(request, **kwargs):
+            delta = data
+            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
+            if run_manager:
+                run_manager.on_llm_new_token(delta, chunk=chunk)
+            yield chunk
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        request = self._to_chat_prompt(messages)
+        async for data in self.llm.astream(request, **kwargs):
+            delta = data
+            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
+            if run_manager:
+                await run_manager.on_llm_new_token(delta, chunk=chunk)
+            yield chunk
+
     def _generate(
         self,
         messages: List[BaseMessage],
```
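
Each delta from the wrapped LLM is boxed into an `AIMessageChunk` inside a `ChatGenerationChunk`. These chunk types define `__add__`, which is what lets downstream helpers fold a token stream back into one message; a quick illustration using only `langchain_core` types:

```python
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk

# Chunk addition concatenates content, so a stream can be re-assembled.
first = ChatGenerationChunk(message=AIMessageChunk(content="Hello, "))
second = ChatGenerationChunk(message=AIMessageChunk(content="world!"))
merged = first + second
print(merged.text)  # -> Hello, world!
```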
```diff
@@ -77,6 +119,12 @@ class ChatHuggingFace(BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
+        if self.streaming:
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return generate_from_stream(stream_iter)
+
         llm_input = self._to_chat_prompt(messages)
         llm_result = self.llm._generate(
             prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
```
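
`generate_from_stream` drains the iterator and folds the chunks into a single `ChatResult`. Roughly, as a simplified sketch (not the verbatim `langchain_core` implementation, which additionally converts the accumulated chunk back into a plain message):

```python
from typing import Iterator

from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


def generate_from_stream_sketch(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
    """Fold a chunk stream into one ChatResult (simplified)."""
    generation = next(stream, None)
    if generation is None:
        raise ValueError("No generations found in stream.")
    for chunk in stream:
        generation += chunk  # __add__ concatenates message content
    return ChatResult(
        generations=[
            ChatGeneration(
                message=generation.message,
                generation_info=generation.generation_info,
            )
        ]
    )
```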
```diff
@@ -90,6 +138,12 @@ class ChatHuggingFace(BaseChatModel):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
+        if self.streaming:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+
         llm_input = self._to_chat_prompt(messages)
         llm_result = await self.llm._agenerate(
             prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
```
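
The async side is symmetric: `_agenerate` delegates to `_astream` plus `agenerate_from_stream`. A minimal consumption sketch, reusing the hypothetical `chat` instance from the example above:

```python
import asyncio


async def main() -> None:
    # chat is the ChatHuggingFace(streaming=True) instance sketched earlier.
    async for chunk in chat.astream("Write a haiku about streaming."):
        print(chunk.content, end="", flush=True)


asyncio.run(main())
```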
