"""Hugging Face Chat Wrapper.""" from typing import Any, AsyncIterator, Iterator, List, Optional from langchain_core.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain_core.language_models.chat_models import ( BaseChatModel, agenerate_from_stream, generate_from_stream, ) from langchain_core.messages import ( AIMessage, AIMessageChunk, BaseMessage, HumanMessage, SystemMessage, ) from langchain_core.outputs import ( ChatGeneration, ChatGenerationChunk, ChatResult, LLMResult, ) from langchain_core.pydantic_v1 import root_validator from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint from langchain_community.llms.huggingface_hub import HuggingFaceHub from langchain_community.llms.huggingface_text_gen_inference import ( HuggingFaceTextGenInference, ) DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant.""" class ChatHuggingFace(BaseChatModel): """ Wrapper for using Hugging Face LLM's as ChatModels. Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`, and `HuggingFaceHub` LLMs. Upon instantiating this class, the model_id is resolved from the url provided to the LLM, and the appropriate tokenizer is loaded from the HuggingFace Hub. Adapted from: https://python.langchain.com/docs/integrations/chat/llama2_chat """ llm: Any """LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint, or HuggingFaceHub.""" system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT) tokenizer: Any = None model_id: Optional[str] = None streaming: bool = False def __init__(self, **kwargs: Any): super().__init__(**kwargs) from transformers import AutoTokenizer self._resolve_model_id() self.tokenizer = ( AutoTokenizer.from_pretrained(self.model_id) if self.tokenizer is None else self.tokenizer ) @root_validator() def validate_llm(cls, values: dict) -> dict: if not isinstance( values["llm"], (HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub), ): raise TypeError( "Expected llm to be one of HuggingFaceTextGenInference, " f"HuggingFaceEndpoint, HuggingFaceHub, received {type(values['llm'])}" ) return values def _stream( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: request = self._to_chat_prompt(messages) for data in self.llm.stream(request, **kwargs): delta = data chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) if run_manager: run_manager.on_llm_new_token(delta, chunk=chunk) yield chunk async def _astream( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: request = self._to_chat_prompt(messages) async for data in self.llm.astream(request, **kwargs): delta = data chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) if run_manager: await run_manager.on_llm_new_token(delta, chunk=chunk) yield chunk def _generate( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: if self.streaming: stream_iter = self._stream( messages, stop=stop, run_manager=run_manager, **kwargs ) return generate_from_stream(stream_iter) llm_input = self._to_chat_prompt(messages) llm_result = self.llm._generate( prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs ) return self._to_chat_result(llm_result) async def _agenerate( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: if self.streaming: stream_iter = self._astream( messages, stop=stop, run_manager=run_manager, **kwargs ) return await agenerate_from_stream(stream_iter) llm_input = self._to_chat_prompt(messages) llm_result = await self.llm._agenerate( prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs ) return self._to_chat_result(llm_result) def _to_chat_prompt( self, messages: List[BaseMessage], ) -> str: """Convert a list of messages into a prompt format expected by wrapped LLM.""" if not messages: raise ValueError("At least one HumanMessage must be provided!") if not isinstance(messages[-1], HumanMessage): raise ValueError("Last message must be a HumanMessage!") messages_dicts = [self._to_chatml_format(m) for m in messages] return self.tokenizer.apply_chat_template( messages_dicts, tokenize=False, add_generation_prompt=True ) def _to_chatml_format(self, message: BaseMessage) -> dict: """Convert LangChain message to ChatML format.""" if isinstance(message, SystemMessage): role = "system" elif isinstance(message, AIMessage): role = "assistant" elif isinstance(message, HumanMessage): role = "user" else: raise ValueError(f"Unknown message type: {type(message)}") return {"role": role, "content": message.content} @staticmethod def _to_chat_result(llm_result: LLMResult) -> ChatResult: chat_generations = [] for g in llm_result.generations[0]: chat_generation = ChatGeneration( message=AIMessage(content=g.text), generation_info=g.generation_info ) chat_generations.append(chat_generation) return ChatResult( generations=chat_generations, llm_output=llm_result.llm_output ) def _resolve_model_id(self) -> None: """Resolve the model_id from the LLM's inference_server_url""" from huggingface_hub import list_inference_endpoints available_endpoints = list_inference_endpoints("*") if isinstance(self.llm, HuggingFaceHub) or ( hasattr(self.llm, "repo_id") and self.llm.repo_id ): self.model_id = self.llm.repo_id return elif isinstance(self.llm, HuggingFaceTextGenInference): endpoint_url: Optional[str] = self.llm.inference_server_url else: endpoint_url = self.llm.endpoint_url for endpoint in available_endpoints: if endpoint.url == endpoint_url: self.model_id = endpoint.repository if not self.model_id: raise ValueError( "Failed to resolve model_id:" f"Could not find model id for inference server: {endpoint_url}" "Make sure that your Hugging Face token has access to the endpoint." ) @property def _llm_type(self) -> str: return "huggingface-chat-wrapper"