mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
4a77a3ab19
This PR moves the validation of the decorator to a better place to avoid creating bugs while deprecating code. It prevents issues like this one from arising: https://github.com/langchain-ai/langchain/issues/22510. At some point we should replace this check with a linter that just does static analysis.
235 lines
7.7 KiB
Python
235 lines
7.7 KiB
Python
"""Hugging Face Chat Wrapper."""
|
|
|
|
from typing import Any, AsyncIterator, Iterator, List, Optional
|
|
|
|
from langchain_core._api.deprecation import deprecated
|
|
from langchain_core.callbacks.manager import (
|
|
AsyncCallbackManagerForLLMRun,
|
|
CallbackManagerForLLMRun,
|
|
)
|
|
from langchain_core.language_models.chat_models import (
|
|
BaseChatModel,
|
|
agenerate_from_stream,
|
|
generate_from_stream,
|
|
)
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
AIMessageChunk,
|
|
BaseMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
)
|
|
from langchain_core.outputs import (
|
|
ChatGeneration,
|
|
ChatGenerationChunk,
|
|
ChatResult,
|
|
LLMResult,
|
|
)
|
|
from langchain_core.pydantic_v1 import root_validator
|
|
|
|
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
|
|
from langchain_community.llms.huggingface_hub import HuggingFaceHub
|
|
from langchain_community.llms.huggingface_text_gen_inference import (
|
|
HuggingFaceTextGenInference,
|
|
)
|
|
|
|
# Text of the default system prompt; ChatHuggingFace wraps it in a SystemMessage
# to build the default value of its `system_message` field.
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful, and honest assistant."
|
|
|
|
|
|
@deprecated(
    since="0.0.37",
    removal="0.3",
    alternative_import="langchain_huggingface.ChatHuggingFace",
)
class ChatHuggingFace(BaseChatModel):
    """
    Wrapper for using Hugging Face LLM's as ChatModels.

    Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
    and `HuggingFaceHub` LLMs.

    Upon instantiating this class, the model_id is resolved from the url
    provided to the LLM, and the appropriate tokenizer is loaded from
    the HuggingFace Hub.

    Adapted from: https://python.langchain.com/docs/integrations/chat/llama2_chat
    """

    llm: Any
    """LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint, or
    HuggingFaceHub."""
    # Default system message, built from DEFAULT_SYSTEM_PROMPT.
    system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
    # Tokenizer used to render the chat template; loaded in __init__ when None.
    tokenizer: Any = None
    # Model repository id; filled in by _resolve_model_id().
    model_id: Optional[str] = None
    # When True, _generate/_agenerate are served through the streaming path.
    streaming: bool = False

    def __init__(self, **kwargs: Any):
        """Validate the fields, resolve ``model_id`` and load a tokenizer.

        Args:
            **kwargs: Field values forwarded to the base model constructor.

        Raises:
            ValueError: Propagated from ``_resolve_model_id`` when no model id
                can be determined for the wrapped LLM.
        """
        super().__init__(**kwargs)

        # Imported lazily so the module can be imported without transformers.
        from transformers import AutoTokenizer

        self._resolve_model_id()

        # A caller-supplied tokenizer wins; otherwise load one for model_id.
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

    @root_validator()
    def validate_llm(cls, values: dict) -> dict:
        """Check that ``llm`` is one of the supported Hugging Face LLM classes.

        Args:
            values: Field values collected by the validator machinery.

        Returns:
            The unchanged ``values`` dict.

        Raises:
            TypeError: If ``values["llm"]`` is not an instance of
                HuggingFaceTextGenInference, HuggingFaceEndpoint or
                HuggingFaceHub.
        """
        if not isinstance(
            values["llm"],
            (HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub),
        ):
            raise TypeError(
                "Expected llm to be one of HuggingFaceTextGenInference, "
                f"HuggingFaceEndpoint, HuggingFaceHub, received {type(values['llm'])}"
            )
        return values

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the wrapped LLM's output as chat generation chunks.

        Args:
            messages: Conversation to render into a prompt.
            stop: Optional stop sequences, forwarded to the wrapped LLM.
            run_manager: Optional callback manager notified of each new token.
            **kwargs: Extra arguments forwarded to ``llm.stream``.

        Yields:
            One ``ChatGenerationChunk`` per piece of text produced by the LLM.
        """
        request = self._to_chat_prompt(messages)

        # Forward ``stop`` so the streaming path honours stop sequences just
        # like the non-streaming path in _generate does.
        for delta in self.llm.stream(request, stop=stop, **kwargs):
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
            if run_manager:
                run_manager.on_llm_new_token(delta, chunk=chunk)
            yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Asynchronously stream the wrapped LLM's output as chat chunks.

        Args:
            messages: Conversation to render into a prompt.
            stop: Optional stop sequences, forwarded to the wrapped LLM.
            run_manager: Optional async callback manager notified of each token.
            **kwargs: Extra arguments forwarded to ``llm.astream``.

        Yields:
            One ``ChatGenerationChunk`` per piece of text produced by the LLM.
        """
        request = self._to_chat_prompt(messages)

        # Forward ``stop`` for parity with _agenerate's non-streaming path.
        async for delta in self.llm.astream(request, stop=stop, **kwargs):
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
            if run_manager:
                await run_manager.on_llm_new_token(delta, chunk=chunk)
            yield chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result for ``messages``.

        When ``self.streaming`` is true the result is assembled from
        ``_stream``; otherwise the rendered prompt is sent to the wrapped
        LLM's ``_generate`` as a single-prompt batch.

        Args:
            messages: Conversation to render into a prompt.
            stop: Optional stop sequences.
            run_manager: Optional callback manager.
            **kwargs: Extra arguments forwarded to the wrapped LLM.

        Returns:
            The chat result built from the LLM output.
        """
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        llm_input = self._to_chat_prompt(messages)
        llm_result = self.llm._generate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous counterpart of ``_generate``.

        When ``self.streaming`` is true the result is assembled from
        ``_astream``; otherwise the rendered prompt is sent to the wrapped
        LLM's ``_agenerate`` as a single-prompt batch.

        Args:
            messages: Conversation to render into a prompt.
            stop: Optional stop sequences.
            run_manager: Optional async callback manager.
            **kwargs: Extra arguments forwarded to the wrapped LLM.

        Returns:
            The chat result built from the LLM output.
        """
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        llm_input = self._to_chat_prompt(messages)
        llm_result = await self.llm._agenerate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    def _to_chat_prompt(
        self,
        messages: List[BaseMessage],
    ) -> str:
        """Convert a list of messages into a prompt format expected by wrapped LLM.

        Args:
            messages: Conversation; must be non-empty and end with a
                ``HumanMessage``.

        Returns:
            The prompt produced by the tokenizer's chat template, untokenized
            and with the generation prompt appended.

        Raises:
            ValueError: If ``messages`` is empty, if the last message is not a
                ``HumanMessage``, or if a message has an unsupported type.
        """
        if not messages:
            raise ValueError("At least one HumanMessage must be provided!")

        if not isinstance(messages[-1], HumanMessage):
            raise ValueError("Last message must be a HumanMessage!")

        messages_dicts = [self._to_chatml_format(m) for m in messages]

        return self.tokenizer.apply_chat_template(
            messages_dicts, tokenize=False, add_generation_prompt=True
        )

    def _to_chatml_format(self, message: BaseMessage) -> dict:
        """Convert LangChain message to ChatML format.

        Args:
            message: A system, AI or human message.

        Returns:
            A dict with the keys ``"role"`` and ``"content"``.

        Raises:
            ValueError: If ``message`` is none of the supported message types.
        """
        if isinstance(message, SystemMessage):
            role = "system"
        elif isinstance(message, AIMessage):
            role = "assistant"
        elif isinstance(message, HumanMessage):
            role = "user"
        else:
            raise ValueError(f"Unknown message type: {type(message)}")

        return {"role": role, "content": message.content}

    @staticmethod
    def _to_chat_result(llm_result: LLMResult) -> ChatResult:
        """Convert the first prompt's generations of an LLM result to a ChatResult.

        Args:
            llm_result: Result returned by the wrapped LLM; only
                ``generations[0]`` is read because a single prompt is sent.

        Returns:
            A ``ChatResult`` with one ``ChatGeneration`` per generation and the
            original ``llm_output``.
        """
        chat_generations = [
            ChatGeneration(
                message=AIMessage(content=g.text), generation_info=g.generation_info
            )
            for g in llm_result.generations[0]
        ]

        return ChatResult(
            generations=chat_generations, llm_output=llm_result.llm_output
        )

    def _resolve_model_id(self) -> None:
        """Resolve the model_id from the LLM's inference_server_url

        If the wrapped LLM is a ``HuggingFaceHub`` or has a truthy ``repo_id``,
        that ``repo_id`` becomes the model id. Otherwise the LLM's endpoint URL
        is looked up among the inference endpoints returned by
        ``list_inference_endpoints("*")``.

        Raises:
            ValueError: If ``model_id`` is still unset after the lookup.
        """
        from huggingface_hub import list_inference_endpoints

        if isinstance(self.llm, HuggingFaceHub) or (
            hasattr(self.llm, "repo_id") and self.llm.repo_id
        ):
            self.model_id = self.llm.repo_id
            return
        elif isinstance(self.llm, HuggingFaceTextGenInference):
            endpoint_url: Optional[str] = self.llm.inference_server_url
        else:
            endpoint_url = self.llm.endpoint_url

        # List the endpoints only when they are needed: the repo_id branch
        # above returns without ever using them.
        for endpoint in list_inference_endpoints("*"):
            if endpoint.url == endpoint_url:
                self.model_id = endpoint.repository
                break

        if not self.model_id:
            raise ValueError(
                "Failed to resolve model_id: "
                f"Could not find model id for inference server: {endpoint_url}. "
                "Make sure that your Hugging Face token has access to the endpoint."
            )

    @property
    def _llm_type(self) -> str:
        """Identifier string for this chat model type."""
        return "huggingface-chat-wrapper"
|