@ -1,19 +1,28 @@
""" Hugging Face Chat Wrapper. """
from typing import Any , List , Optional
from typing import Any , AsyncIterator , Iterator , List , Optional
from langchain_core . callbacks . manager import (
AsyncCallbackManagerForLLMRun ,
CallbackManagerForLLMRun ,
)
from langchain_core . language_models . chat_models import BaseChatModel
from langchain_core . language_models . chat_models import (
BaseChatModel ,
agenerate_from_stream ,
generate_from_stream ,
)
from langchain_core . messages import (
AIMessage ,
AIMessageChunk ,
BaseMessage ,
HumanMessage ,
SystemMessage ,
)
from langchain_core . outputs import ChatGeneration , ChatResult , LLMResult
from langchain_core . outputs import (
ChatGeneration ,
ChatGenerationChunk ,
ChatResult ,
LLMResult ,
)
from langchain_core . pydantic_v1 import root_validator
from langchain_community . llms . huggingface_endpoint import HuggingFaceEndpoint
@ -26,7 +35,8 @@ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."
class ChatHuggingFace ( BaseChatModel ) :
""" Hugging Face LLMs as ChatModels.
"""
Wrapper for using Hugging Face LLM ' s as ChatModels.
Works with ` HuggingFaceTextGenInference ` , ` HuggingFaceEndpoint ` ,
and ` HuggingFaceHub ` LLMs .
@ -44,6 +54,7 @@ class ChatHuggingFace(BaseChatModel):
system_message : SystemMessage = SystemMessage ( content = DEFAULT_SYSTEM_PROMPT )
tokenizer : Any = None
model_id : Optional [ str ] = None
streaming : bool = False
def __init__ ( self , * * kwargs : Any ) :
super ( ) . __init__ ( * * kwargs )
@ -70,6 +81,37 @@ class ChatHuggingFace(BaseChatModel):
)
return values
def _stream (
self ,
messages : List [ BaseMessage ] ,
stop : Optional [ List [ str ] ] = None ,
run_manager : Optional [ CallbackManagerForLLMRun ] = None ,
* * kwargs : Any ,
) - > Iterator [ ChatGenerationChunk ] :
request = self . _to_chat_prompt ( messages )
for data in self . llm . stream ( request , * * kwargs ) :
delta = data
chunk = ChatGenerationChunk ( message = AIMessageChunk ( content = delta ) )
if run_manager :
run_manager . on_llm_new_token ( delta , chunk = chunk )
yield chunk
async def _astream (
self ,
messages : List [ BaseMessage ] ,
stop : Optional [ List [ str ] ] = None ,
run_manager : Optional [ AsyncCallbackManagerForLLMRun ] = None ,
* * kwargs : Any ,
) - > AsyncIterator [ ChatGenerationChunk ] :
request = self . _to_chat_prompt ( messages )
async for data in self . llm . astream ( request , * * kwargs ) :
delta = data
chunk = ChatGenerationChunk ( message = AIMessageChunk ( content = delta ) )
if run_manager :
await run_manager . on_llm_new_token ( delta , chunk = chunk )
yield chunk
def _generate (
self ,
messages : List [ BaseMessage ] ,
@ -77,6 +119,12 @@ class ChatHuggingFace(BaseChatModel):
run_manager : Optional [ CallbackManagerForLLMRun ] = None ,
* * kwargs : Any ,
) - > ChatResult :
if self . streaming :
stream_iter = self . _stream (
messages , stop = stop , run_manager = run_manager , * * kwargs
)
return generate_from_stream ( stream_iter )
llm_input = self . _to_chat_prompt ( messages )
llm_result = self . llm . _generate (
prompts = [ llm_input ] , stop = stop , run_manager = run_manager , * * kwargs
@ -90,6 +138,12 @@ class ChatHuggingFace(BaseChatModel):
run_manager : Optional [ AsyncCallbackManagerForLLMRun ] = None ,
* * kwargs : Any ,
) - > ChatResult :
if self . streaming :
stream_iter = self . _astream (
messages , stop = stop , run_manager = run_manager , * * kwargs
)
return await agenerate_from_stream ( stream_iter )
llm_input = self . _to_chat_prompt ( messages )
llm_result = await self . llm . _agenerate (
prompts = [ llm_input ] , stop = stop , run_manager = run_manager , * * kwargs