community[minor]: add hf chat wrapper (#14736)

Builds on #14040 with community refactor merged and notebook updated.

Note that with this refactor, models will be imported from
`langchain_community.chat_models.huggingface` rather than the main
`langchain` repo.

---------

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: ugm2 <unaigaraymaestre@gmail.com>
Signed-off-by: Yuchen Liang <yuchenl3@andrew.cmu.edu>
Co-authored-by: Andrew Reed <andrew.reed.r@gmail.com>
Co-authored-by: Andrew Reed <areed1242@gmail.com>
Co-authored-by: A-Roucher <aymeric.roucher@gmail.com>
Co-authored-by: Aymeric Roucher <69208727+A-Roucher@users.noreply.github.com>
pull/15021/head^2
Jacob Lee 9 months ago committed by GitHub
parent b99274c9d8
commit 1b01ee0e3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,456 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Hugging Face Chat Wrapper\n",
"\n",
"This notebook shows how to get started using Hugging Face LLM's as chat models.\n",
"\n",
"In particular, we will:\n",
"1. Utilize the [HuggingFaceTextGenInference](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/llms/huggingface_text_gen_inference.py), [HuggingFaceEndpoint](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/llms/huggingface_endpoint.py), or [HuggingFaceHub](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/llms/huggingface_hub.py) integrations to instantiate an `LLM`.\n",
"2. Utilize the `ChatHuggingFace` class to enable any of these LLMs to interface with LangChain's [Chat Messages](https://python.langchain.com/docs/modules/model_io/chat/#messages) abstraction.\n",
"3. Demonstrate how to use an open-source LLM to power an `ChatAgent` pipeline\n",
"\n",
"\n",
"> Note: To get started, you'll need to have a [Hugging Face Access Token](https://huggingface.co/docs/hub/security-tokens) saved as an environment variable: `HUGGINGFACEHUB_API_TOKEN`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: You are using pip version 22.0.4; however, version 23.3.1 is available.\n",
"You should consider upgrading via the '/Users/jacoblee/langchain/langchain/libs/langchain/.venv/bin/python -m pip install --upgrade pip' command.\u001b[0m\u001b[33m\n",
"\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install -q text-generation transformers google-search-results numexpr langchainhub sentencepiece jinja2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Instantiate an LLM\n",
"\n",
"There are three LLM options to choose from."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `HuggingFaceTextGenInference`"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jacoblee/langchain/langchain/libs/langchain/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import os\n",
"\n",
"from langchain_community.llms import HuggingFaceTextGenInference\n",
"\n",
"ENDPOINT_URL = \"<YOUR_ENDPOINT_URL_HERE>\"\n",
"HF_TOKEN = os.getenv(\"HUGGINGFACEHUB_API_TOKEN\")\n",
"\n",
"llm = HuggingFaceTextGenInference(\n",
" inference_server_url=ENDPOINT_URL,\n",
" max_new_tokens=512,\n",
" top_k=50,\n",
" temperature=0.1,\n",
" repetition_penalty=1.03,\n",
" server_kwargs={\n",
" \"headers\": {\n",
" \"Authorization\": f\"Bearer {HF_TOKEN}\",\n",
" \"Content-Type\": \"application/json\",\n",
" }\n",
" },\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `HuggingFaceEndpoint`"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms import HuggingFaceEndpoint\n",
"\n",
"ENDPOINT_URL = \"<YOUR_ENDPOINT_URL_HERE>\"\n",
"llm = HuggingFaceEndpoint(\n",
" endpoint_url=ENDPOINT_URL,\n",
" task=\"text-generation\",\n",
" model_kwargs={\n",
" \"max_new_tokens\": 512,\n",
" \"top_k\": 50,\n",
" \"temperature\": 0.1,\n",
" \"repetition_penalty\": 1.03,\n",
" },\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `HuggingFaceHub`"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jacoblee/langchain/langchain/libs/langchain/.venv/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py:127: FutureWarning: '__init__' (from 'huggingface_hub.inference_api') is deprecated and will be removed from version '1.0'. `InferenceApi` client is deprecated in favor of the more feature-complete `InferenceClient`. Check out this guide to learn how to convert your script to use it: https://huggingface.co/docs/huggingface_hub/guides/inference#legacy-inferenceapi-client.\n",
" warnings.warn(warning_message, FutureWarning)\n"
]
}
],
"source": [
"from langchain_community.llms import HuggingFaceHub\n",
"\n",
"llm = HuggingFaceHub(\n",
" repo_id=\"HuggingFaceH4/zephyr-7b-beta\",\n",
" task=\"text-generation\",\n",
" model_kwargs={\n",
" \"max_new_tokens\": 512,\n",
" \"top_k\": 30,\n",
" \"temperature\": 0.1,\n",
" \"repetition_penalty\": 1.03,\n",
" },\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Instantiate the `ChatHuggingFace` to apply chat templates"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instantiate the chat model and some messages to pass."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING! repo_id is not default parameter.\n",
" repo_id was transferred to model_kwargs.\n",
" Please confirm that repo_id is what you intended.\n",
"WARNING! task is not default parameter.\n",
" task was transferred to model_kwargs.\n",
" Please confirm that task is what you intended.\n",
"WARNING! huggingfacehub_api_token is not default parameter.\n",
" huggingfacehub_api_token was transferred to model_kwargs.\n",
" Please confirm that huggingfacehub_api_token is what you intended.\n",
"None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
]
}
],
"source": [
"from langchain.schema import (\n",
" HumanMessage,\n",
" SystemMessage,\n",
")\n",
"from langchain_community.chat_models.huggingface import ChatHuggingFace\n",
"\n",
"messages = [\n",
" SystemMessage(content=\"You're a helpful assistant\"),\n",
" HumanMessage(\n",
" content=\"What happens when an unstoppable force meets an immovable object?\"\n",
" ),\n",
"]\n",
"\n",
"chat_model = ChatHuggingFace(llm=llm)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inspect which model and corresponding chat template is being used."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'HuggingFaceH4/zephyr-7b-beta'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat_model.model_id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inspect how the chat messages are formatted for the LLM call."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"<|system|>\\nYou're a helpful assistant</s>\\n<|user|>\\nWhat happens when an unstoppable force meets an immovable object?</s>\\n<|assistant|>\\n\""
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat_model._to_chat_prompt(messages)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Call the model."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"According to a popular philosophical paradox, when an unstoppable force meets an immovable object, it is impossible to determine which one will prevail because both are defined as being completely unyielding and unmovable. The paradox suggests that the very concepts of \"unstoppable force\" and \"immovable object\" are inherently contradictory, and therefore, it is illogical to imagine a scenario where they would meet and interact. However, in practical terms, it is highly unlikely for such a scenario to occur in the real world, as the concepts of \"unstoppable force\" and \"immovable object\" are often used metaphorically to describe hypothetical situations or abstract concepts, rather than physical objects or forces.\n"
]
}
],
"source": [
"res = chat_model.invoke(messages)\n",
"print(res.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Take it for a spin as an agent!\n",
"\n",
"Here we'll test out `Zephyr-7B-beta` as a zero-shot ReAct Agent. The example below is taken from [here](https://python.langchain.com/docs/modules/agents/agent_types/react#using-chat-models).\n",
"\n",
"> Note: To run this section, you'll need to have a [SerpAPI Token](https://serpapi.com/) saved as an environment variable: `SERPAPI_API_KEY`"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from langchain import hub\n",
"from langchain.agents import AgentExecutor, load_tools\n",
"from langchain.agents.format_scratchpad import format_log_to_str\n",
"from langchain.agents.output_parsers import (\n",
" ReActJsonSingleInputOutputParser,\n",
")\n",
"from langchain.tools.render import render_text_description\n",
"from langchain.utilities import SerpAPIWrapper"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Configure the agent with a `react-json` style prompt and access to a search engine and calculator."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# setup tools\n",
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)\n",
"\n",
"# setup ReAct style prompt\n",
"prompt = hub.pull(\"hwchase17/react-json\")\n",
"prompt = prompt.partial(\n",
" tools=render_text_description(tools),\n",
" tool_names=\", \".join([t.name for t in tools]),\n",
")\n",
"\n",
"# define the agent\n",
"chat_model_with_stop = chat_model.bind(stop=[\"\\nObservation\"])\n",
"agent = (\n",
" {\n",
" \"input\": lambda x: x[\"input\"],\n",
" \"agent_scratchpad\": lambda x: format_log_to_str(x[\"intermediate_steps\"]),\n",
" }\n",
" | prompt\n",
" | chat_model_with_stop\n",
" | ReActJsonSingleInputOutputParser()\n",
")\n",
"\n",
"# instantiate AgentExecutor\n",
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mQuestion: Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\n",
"\n",
"Thought: I need to use the Search tool to find out who Leo DiCaprio's current girlfriend is. Then, I can use the Calculator tool to raise her current age to the power of 0.43.\n",
"\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"Search\",\n",
" \"action_input\": \"leo dicaprio girlfriend\"\n",
"}\n",
"```\n",
"\u001b[0m\u001b[36;1m\u001b[1;3mLeonardo DiCaprio may have found The One in Vittoria Ceretti. “They are in love,” a source exclusively reveals in the latest issue of Us Weekly. “Leo was clearly very proud to be showing Vittoria off and letting everyone see how happy they are together.”\u001b[0m\u001b[32;1m\u001b[1;3mNow that we know Leo DiCaprio's current girlfriend is Vittoria Ceretti, let's find out her current age.\n",
"\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"Search\",\n",
" \"action_input\": \"vittoria ceretti age\"\n",
"}\n",
"```\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m25 years\u001b[0m\u001b[32;1m\u001b[1;3mNow that we know Vittoria Ceretti's current age is 25, let's use the Calculator tool to raise it to the power of 0.43.\n",
"\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"Calculator\",\n",
" \"action_input\": \"25^0.43\"\n",
"}\n",
"```\n",
"\u001b[0m\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mFinal Answer: Vittoria Ceretti, Leo DiCaprio's current girlfriend, when raised to the power of 0.43 is approximately 4.0 rounded to two decimal places. Her current age is 25 years old.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\",\n",
" 'output': \"Vittoria Ceretti, Leo DiCaprio's current girlfriend, when raised to the power of 0.43 is approximately 4.0 rounded to two decimal places. Her current age is 25 years old.\"}"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.invoke(\n",
" {\n",
" \"input\": \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Wahoo! Our open-source 7b parameter Zephyr model was able to:\n",
"\n",
"1. Plan out a series of actions: `I need to use the Search tool to find out who Leo DiCaprio's current girlfriend is. Then, I can use the Calculator tool to raise her current age to the power of 0.43.`\n",
"2. Then execute a search using the SerpAPI tool to find who Leo DiCaprio's current girlfriend is\n",
"3. Execute another search to find her age\n",
"4. And finally use a calculator tool to calculate her age raised to the power of 0.43\n",
"\n",
"It's exciting to see how far open-source LLM's can go as general purpose reasoning agents. Give it a try yourself!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@ -32,6 +32,7 @@ from langchain_community.chat_models.fireworks import ChatFireworks
from langchain_community.chat_models.gigachat import GigaChat
from langchain_community.chat_models.google_palm import ChatGooglePalm
from langchain_community.chat_models.gpt_router import GPTRouter
from langchain_community.chat_models.huggingface import ChatHuggingFace
from langchain_community.chat_models.human import HumanInputChatModel
from langchain_community.chat_models.hunyuan import ChatHunyuan
from langchain_community.chat_models.javelin_ai_gateway import ChatJavelinAIGateway
@ -65,6 +66,7 @@ __all__ = [
"ChatOllama",
"ChatVertexAI",
"JinaChat",
"ChatHuggingFace",
"HumanInputChatModel",
"MiniMaxChat",
"ChatAnyscale",

@ -0,0 +1,166 @@
"""Hugging Face Chat Wrapper."""
from typing import Any, List, Optional, Union
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatResult,
LLMResult,
)
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_community.llms.huggingface_hub import HuggingFaceHub
from langchain_community.llms.huggingface_text_gen_inference import (
HuggingFaceTextGenInference,
)
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""
class ChatHuggingFace(BaseChatModel):
"""
Wrapper for using Hugging Face LLM's as ChatModels.
Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
and `HuggingFaceHub` LLMs.
Upon instantiating this class, the model_id is resolved from the url
provided to the LLM, and the appropriate tokenizer is loaded from
the HuggingFace Hub.
Adapted from: https://python.langchain.com/docs/integrations/chat/llama2_chat
"""
llm: Union[HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub]
system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
tokenizer: Any = None
model_id: str = None # type: ignore
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
from transformers import AutoTokenizer
self._resolve_model_id()
self.tokenizer = (
AutoTokenizer.from_pretrained(self.model_id)
if self.tokenizer is None
else self.tokenizer
)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_input = self._to_chat_prompt(messages)
llm_result = self.llm._generate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_input = self._to_chat_prompt(messages)
llm_result = await self.llm._agenerate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
def _to_chat_prompt(
self,
messages: List[BaseMessage],
) -> str:
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
if not messages:
raise ValueError("at least one HumanMessage must be provided")
if not isinstance(messages[-1], HumanMessage):
raise ValueError("last message must be a HumanMessage")
messages_dicts = [self._to_chatml_format(m) for m in messages]
return self.tokenizer.apply_chat_template(
messages_dicts, tokenize=False, add_generation_prompt=True
)
def _to_chatml_format(self, message: BaseMessage) -> dict:
"""Convert LangChain message to ChatML format."""
if isinstance(message, SystemMessage):
role = "system"
elif isinstance(message, AIMessage):
role = "assistant"
elif isinstance(message, HumanMessage):
role = "user"
else:
raise ValueError(f"Unknown message type: {type(message)}")
return {"role": role, "content": message.content}
@staticmethod
def _to_chat_result(llm_result: LLMResult) -> ChatResult:
chat_generations = []
for g in llm_result.generations[0]:
chat_generation = ChatGeneration(
message=AIMessage(content=g.text), generation_info=g.generation_info
)
chat_generations.append(chat_generation)
return ChatResult(
generations=chat_generations, llm_output=llm_result.llm_output
)
def _resolve_model_id(self) -> None:
"""Resolve the model_id from the LLM's inference_server_url"""
from huggingface_hub import list_inference_endpoints
available_endpoints = list_inference_endpoints("*")
if isinstance(self.llm, HuggingFaceTextGenInference):
endpoint_url = self.llm.inference_server_url
elif isinstance(self.llm, HuggingFaceEndpoint):
endpoint_url = self.llm.endpoint_url
elif isinstance(self.llm, HuggingFaceHub):
# no need to look up model_id for HuggingFaceHub LLM
self.model_id = self.llm.repo_id
return
else:
raise ValueError(f"Unknown LLM type: {type(self.llm)}")
for endpoint in available_endpoints:
if endpoint.url == endpoint_url:
self.model_id = endpoint.repository
if not self.model_id:
raise ValueError(
"Failed to resolve model_id"
f"Could not find model id for inference server provided: {endpoint_url}"
"Make sure that your Hugging Face token has access to the endpoint."
)
@property
def _llm_type(self) -> str:
return "huggingface-chat-wrapper"

@ -0,0 +1,11 @@
"""Test HuggingFace Chat wrapper."""
from importlib import import_module
def test_import_class() -> None:
"""Test that the class can be imported."""
module_name = "langchain_community.chat_models.huggingface"
class_name = "ChatHuggingFace"
module = import_module(module_name)
assert hasattr(module, class_name)

@ -11,6 +11,7 @@ EXPECTED_ALL = [
"ChatCohere",
"ChatDatabricks",
"ChatGooglePalm",
"ChatHuggingFace",
"ChatMlflow",
"ChatMLflowAIGateway",
"ChatOllama",

Loading…
Cancel
Save