Community: Fuse HuggingFace Endpoint-related classes into one (#17254)

## Description
Fuse HuggingFace Endpoint-related classes into one:
-
[HuggingFaceHub](5ceaf784f3/libs/community/langchain_community/llms/huggingface_hub.py)
-
[HuggingFaceTextGenInference](5ceaf784f3/libs/community/langchain_community/llms/huggingface_text_gen_inference.py)
- and
[HuggingFaceEndpoint](5ceaf784f3/libs/community/langchain_community/llms/huggingface_endpoint.py)

Are fused into
- HuggingFaceEndpoint

## Issue
The deduplication of classes was creating a lack of clarity, and
additional effort to develop classes leads to issues like [this
hack](5ceaf784f3/libs/community/langchain_community/llms/huggingface_endpoint.py (L159)).

## Dependancies

None, this removes dependancies.

## Twitter handle

If you want to post about this: @AymericRoucher

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/17750/head
Aymeric Roucher 3 months ago committed by GitHub
parent 8009be862e
commit 0d294760e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,238 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Huggingface Endpoints\n",
"\n",
">The [Hugging Face Hub](https://huggingface.co/docs/hub/index) is a platform with over 120k models, 20k datasets, and 50k demo apps (Spaces), all open source and publicly available, in an online platform where people can easily collaborate and build ML together.\n",
"\n",
"The `Hugging Face Hub` also offers various endpoints to build ML applications.\n",
"This example showcases how to connect to the different Endpoints types.\n",
"\n",
"In particular, text generation inference is powered by [Text Generation Inference](https://github.com/huggingface/text-generation-inference): a custom-built Rust, Python and gRPC server for blazing-faset text generation inference."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms import HuggingFaceEndpoint"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Installation and Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use, you should have the ``huggingface_hub`` python [package installed](https://huggingface.co/docs/huggingface_hub/installation)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade --quiet huggingface_hub"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get a token: https://huggingface.co/docs/api-inference/quicktour#get-your-api-token\n",
"\n",
"from getpass import getpass\n",
"\n",
"HUGGINGFACEHUB_API_TOKEN = getpass()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"HUGGINGFACEHUB_API_TOKEN\"] = HUGGINGFACEHUB_API_TOKEN"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare Examples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms import HuggingFaceEndpoint"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import LLMChain\n",
"from langchain.prompts import PromptTemplate"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"question = \"Who won the FIFA World Cup in the year 1994? \"\n",
"\n",
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"\n",
"prompt = PromptTemplate.from_template(template)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Examples\n",
"\n",
"Here is an example of how you can access `HuggingFaceEndpoint` integration of the free [Serverless Endpoints](https://huggingface.co/inference-endpoints/serverless) API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"repo_id = \"mistralai/Mistral-7B-Instruct-v0.2\"\n",
"\n",
"llm = HuggingFaceEndpoint(\n",
" repo_id=repo_id, max_length=128, temperature=0.5, token=HUGGINGFACEHUB_API_TOKEN\n",
")\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"print(llm_chain.run(question))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dedicated Endpoint\n",
"\n",
"\n",
"The free serverless API lets you implement solutions and iterate in no time, but it may be rate limited for heavy use cases, since the loads are shared with other requests.\n",
"\n",
"For enterprise workloads, the best is to use [Inference Endpoints - Dedicated](https://huggingface.co/inference-endpoints/dedicated).\n",
"This gives access to a fully managed infrastructure that offer more flexibility and speed. These resoucres come with continuous support and uptime guarantees, as well as options like AutoScaling\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set the url to your Inference Endpoint below\n",
"your_endpoint_url = \"https://fayjubiy2xqn36z0.us-east-1.aws.endpoints.huggingface.cloud\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm = HuggingFaceEndpoint(\n",
" endpoint_url=f\"{your_endpoint_url}\",\n",
" max_new_tokens=512,\n",
" top_k=10,\n",
" top_p=0.95,\n",
" typical_p=0.95,\n",
" temperature=0.01,\n",
" repetition_penalty=1.03,\n",
")\n",
"llm(\"What did foo say about bar?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Streaming"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"from langchain_community.llms import HuggingFaceEndpoint\n",
"\n",
"llm = HuggingFaceEndpoint(\n",
" endpoint_url=f\"{your_endpoint_url}\",\n",
" max_new_tokens=512,\n",
" top_k=10,\n",
" top_p=0.95,\n",
" typical_p=0.95,\n",
" temperature=0.01,\n",
" repetition_penalty=1.03,\n",
" streaming=True,\n",
")\n",
"llm(\"What did foo say about bar?\", callbacks=[StreamingStdOutCallbackHandler()])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "agents",
"language": "python",
"name": "agents"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@ -1,466 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "959300d4",
"metadata": {},
"source": [
"# Hugging Face Hub\n",
"\n",
">The [Hugging Face Hub](https://huggingface.co/docs/hub/index) is a platform with over 120k models, 20k datasets, and 50k demo apps (Spaces), all open source and publicly available, in an online platform where people can easily collaborate and build ML together.\n",
"\n",
"This example showcases how to connect to the `Hugging Face Hub` and use different models."
]
},
{
"cell_type": "markdown",
"id": "1ddafc6d-7d7c-48fa-838f-0e7f50895ce3",
"metadata": {},
"source": [
"## Installation and Setup"
]
},
{
"cell_type": "markdown",
"id": "4c1b8450-5eaf-4d34-8341-2d785448a1ff",
"metadata": {
"tags": []
},
"source": [
"To use, you should have the ``huggingface_hub`` python [package installed](https://huggingface.co/docs/huggingface_hub/installation)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d772b637-de00-4663-bd77-9bc96d798db2",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"%pip install --upgrade --quiet huggingface_hub"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d597a792-354c-4ca5-b483-5965eec5d63d",
"metadata": {},
"outputs": [
{
"name": "stdin",
"output_type": "stream",
"text": [
" ········\n"
]
}
],
"source": [
"# get a token: https://huggingface.co/docs/api-inference/quicktour#get-your-api-token\n",
"\n",
"from getpass import getpass\n",
"\n",
"HUGGINGFACEHUB_API_TOKEN = getpass()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b8c5b88c-e4b8-4d0d-9a35-6e8f106452c2",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"HUGGINGFACEHUB_API_TOKEN\"] = HUGGINGFACEHUB_API_TOKEN"
]
},
{
"cell_type": "markdown",
"id": "84dd44c1-c428-41f3-a911-520281386c94",
"metadata": {},
"source": [
"## Prepare Examples"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3fe7d1d1-241d-426a-acff-e208f1088871",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms import HuggingFaceHub"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "6620f39b-3d32-4840-8931-ff7d2c3e47e8",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import LLMChain\n",
"from langchain.prompts import PromptTemplate"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "44adc1a0-9c0a-4f1e-af5a-fe04222e78d7",
"metadata": {},
"outputs": [],
"source": [
"question = \"Who won the FIFA World Cup in the year 1994? \"\n",
"\n",
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"\n",
"prompt = PromptTemplate.from_template(template)"
]
},
{
"cell_type": "markdown",
"id": "ddaa06cf-95ec-48ce-b0ab-d892a7909693",
"metadata": {},
"source": [
"## Examples\n",
"\n",
"Below are some examples of models you can access through the `Hugging Face Hub` integration."
]
},
{
"cell_type": "markdown",
"id": "4c16fded-70d1-42af-8bfa-6ddda9f0bc63",
"metadata": {},
"source": [
"### `Flan`, by `Google`"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "39c7eeac-01c4-486b-9480-e828a9e73e78",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"repo_id = \"google/flan-t5-xxl\" # See https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads for some other options"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "3acf0069",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The FIFA World Cup was held in the year 1994. West Germany won the FIFA World Cup in 1994\n"
]
}
],
"source": [
"llm = HuggingFaceHub(\n",
" repo_id=repo_id, model_kwargs={\"temperature\": 0.5, \"max_length\": 64}\n",
")\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"\n",
"print(llm_chain.run(question))"
]
},
{
"cell_type": "markdown",
"id": "1a5c97af-89bc-4e59-95c1-223742a9160b",
"metadata": {},
"source": [
"### `Dolly`, by `Databricks`\n",
"\n",
"See [Databricks](https://huggingface.co/databricks) organization page for a list of available models."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "521fcd2b-8e38-4920-b407-5c7d330411c9",
"metadata": {},
"outputs": [],
"source": [
"repo_id = \"databricks/dolly-v2-3b\""
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9907ec3a-fe0c-4543-81c4-d42f9453f16c",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" First of all, the world cup was won by the Germany. Then the Argentina won the world cup in 2022. So, the Argentina won the world cup in 1994.\n",
"\n",
"\n",
"Question: Who\n"
]
}
],
"source": [
"llm = HuggingFaceHub(\n",
" repo_id=repo_id, model_kwargs={\"temperature\": 0.5, \"max_length\": 64}\n",
")\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"print(llm_chain.run(question))"
]
},
{
"cell_type": "markdown",
"id": "03f6ae52-b5f9-4de6-832c-551cb3fa11ae",
"metadata": {},
"source": [
"### `Camel`, by `Writer`\n",
"\n",
"See [Writer's](https://huggingface.co/Writer) organization page for a list of available models."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "257a091d-750b-4910-ac08-fe1c7b3fd98b",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"repo_id = \"Writer/camel-5b-hf\" # See https://huggingface.co/Writer for other options"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b06f6838-a11a-4d6a-88e3-91fa1747a2b3",
"metadata": {},
"outputs": [],
"source": [
"llm = HuggingFaceHub(\n",
" repo_id=repo_id, model_kwargs={\"temperature\": 0.5, \"max_length\": 64}\n",
")\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"print(llm_chain.run(question))"
]
},
{
"cell_type": "markdown",
"id": "2bf838eb-1083-402f-b099-b07c452418c8",
"metadata": {},
"source": [
"### `XGen`, by `Salesforce`\n",
"\n",
"See [more information](https://github.com/salesforce/xgen)."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "18c78880-65d7-41d0-9722-18090efb60e9",
"metadata": {},
"outputs": [],
"source": [
"repo_id = \"Salesforce/xgen-7b-8k-base\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b1150b4-ec30-4674-849e-6a41b085aa2b",
"metadata": {},
"outputs": [],
"source": [
"llm = HuggingFaceHub(\n",
" repo_id=repo_id, model_kwargs={\"temperature\": 0.5, \"max_length\": 64}\n",
")\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"print(llm_chain.run(question))"
]
},
{
"cell_type": "markdown",
"id": "0aca9f9e-f333-449c-97b2-10d1dbf17e75",
"metadata": {},
"source": [
"### `Falcon`, by `Technology Innovation Institute (TII)`\n",
"\n",
"See [more information](https://huggingface.co/tiiuae/falcon-40b)."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "496b35ac-5ee2-4b68-a6ce-232608f56c03",
"metadata": {},
"outputs": [],
"source": [
"repo_id = \"tiiuae/falcon-40b\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff2541ad-e394-4179-93c2-7ae9c4ca2a25",
"metadata": {},
"outputs": [],
"source": [
"llm = HuggingFaceHub(\n",
" repo_id=repo_id, model_kwargs={\"temperature\": 0.5, \"max_length\": 64}\n",
")\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"print(llm_chain.run(question))"
]
},
{
"cell_type": "markdown",
"id": "7e15849b-5561-4bb9-86ec-6412ca10196a",
"metadata": {},
"source": [
"### `InternLM-Chat`, by `Shanghai AI Laboratory`\n",
"\n",
"See [more information](https://huggingface.co/internlm/internlm-7b)."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "3b533461-59f8-406e-907b-000841fa60a7",
"metadata": {},
"outputs": [],
"source": [
"repo_id = \"internlm/internlm-chat-7b\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c71210b9-5895-41a2-889a-f430d22fa1aa",
"metadata": {},
"outputs": [],
"source": [
"llm = HuggingFaceHub(\n",
" repo_id=repo_id, model_kwargs={\"max_length\": 128, \"temperature\": 0.8}\n",
")\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"print(llm_chain.run(question))"
]
},
{
"cell_type": "markdown",
"id": "4f2e5132-1713-42d7-919a-8c313744ce95",
"metadata": {},
"source": [
"### `Qwen`, by `Alibaba Cloud`\n",
"\n",
">`Tongyi Qianwen-7B` (`Qwen-7B`) is a model with a scale of 7 billion parameters in the `Tongyi Qianwen` large model series developed by `Alibaba Cloud`. `Qwen-7B` is a large language model based on Transformer, which is trained on ultra-large-scale pre-training data.\n",
"\n",
"See [more information on HuggingFace](https://huggingface.co/Qwen/Qwen-7B) of on [GitHub](https://github.com/QwenLM/Qwen-7B).\n",
"\n",
"See here a [big example for LangChain integration and Qwen](https://github.com/QwenLM/Qwen-7B/blob/main/examples/langchain_tooluse.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f598b1ca-77c7-40f1-a83f-c21ea9910c88",
"metadata": {},
"outputs": [],
"source": [
"repo_id = \"Qwen/Qwen-7B\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c97f4e2-d401-44fb-9da7-b60b2e2cc663",
"metadata": {},
"outputs": [],
"source": [
"llm = HuggingFaceHub(\n",
" repo_id=repo_id, model_kwargs={\"max_length\": 128, \"temperature\": 0.5}\n",
")\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"print(llm_chain.run(question))"
]
},
{
"cell_type": "markdown",
"id": "e3871376-ed0e-49a8-8d9b-7e60dbbd2b35",
"metadata": {},
"source": [
"### `Yi` series models, by `01.ai`\n",
"\n",
">The `Yi` series models are large language models trained from scratch by developers at [01.ai](https://01.ai/). The first public release contains two bilingual(English/Chinese) base models with the parameter sizes of 6B(`Yi-6B`) and 34B(`Yi-34B`). Both of them are trained with 4K sequence length and can be extended to 32K during inference time. The `Yi-6B-200K` and `Yi-34B-200K` are base model with 200K context length.\n",
"\n",
"Here we test the [Yi-34B](https://huggingface.co/01-ai/Yi-34B) model."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "1c9d3125-3f50-48b8-93b6-b50847207afa",
"metadata": {},
"outputs": [],
"source": [
"repo_id = \"01-ai/Yi-34B\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b661069-8229-4850-9f13-c4ca28c0c96b",
"metadata": {},
"outputs": [],
"source": [
"llm = HuggingFaceHub(\n",
" repo_id=repo_id, model_kwargs={\"max_length\": 128, \"temperature\": 0.5}\n",
")\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"print(llm_chain.run(question))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd6f3edc-9f97-47a6-ab2c-116756babbe6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -1,108 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Huggingface TextGen Inference\n",
"\n",
"[Text Generation Inference](https://github.com/huggingface/text-generation-inference) is a Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co/) to power LLMs api-inference widgets.\n",
"\n",
"This notebooks goes over how to use a self hosted LLM using `Text Generation Inference`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use, you should have the `text_generation` python package installed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# !pip3 install text_generation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms import HuggingFaceTextGenInference\n",
"\n",
"llm = HuggingFaceTextGenInference(\n",
" inference_server_url=\"http://localhost:8010/\",\n",
" max_new_tokens=512,\n",
" top_k=10,\n",
" top_p=0.95,\n",
" typical_p=0.95,\n",
" temperature=0.01,\n",
" repetition_penalty=1.03,\n",
")\n",
"llm(\"What did foo say about bar?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Streaming"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"from langchain_community.llms import HuggingFaceTextGenInference\n",
"\n",
"llm = HuggingFaceTextGenInference(\n",
" inference_server_url=\"http://localhost:8010/\",\n",
" max_new_tokens=512,\n",
" top_k=10,\n",
" top_p=0.95,\n",
" typical_p=0.95,\n",
" temperature=0.01,\n",
" repetition_penalty=1.03,\n",
" streaming=True,\n",
")\n",
"llm(\"What did foo say about bar?\", callbacks=[StreamingStdOutCallbackHandler()])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@ -1,5 +1,13 @@
{
"redirects": [
{
"source": "/docs/integrations/llms/huggingface_textgen_inference",
"destination": "/docs/integrations/llms/huggingface_endpoint"
},
{
"source": "/docs/integrations/llms/huggingface_hub",
"destination": "/docs/integrations/llms/huggingface_endpoint"
},
{
"source": "/docs/integrations/llms/watsonxllm",
"destination": "/docs/integrations/llms/ibm_watsonx"

@ -1,4 +1,5 @@
"""Hugging Face Chat Wrapper."""
from typing import Any, List, Optional, Union
from langchain_core.callbacks.manager import (
@ -52,6 +53,7 @@ class ChatHuggingFace(BaseChatModel):
from transformers import AutoTokenizer
self._resolve_model_id()
self.tokenizer = (
AutoTokenizer.from_pretrained(self.model_id)
if self.tokenizer is None
@ -90,10 +92,10 @@ class ChatHuggingFace(BaseChatModel):
) -> str:
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
if not messages:
raise ValueError("at least one HumanMessage must be provided")
raise ValueError("At least one HumanMessage must be provided!")
if not isinstance(messages[-1], HumanMessage):
raise ValueError("last message must be a HumanMessage")
raise ValueError("Last message must be a HumanMessage!")
messages_dicts = [self._to_chatml_format(m) for m in messages]
@ -135,20 +137,15 @@ class ChatHuggingFace(BaseChatModel):
from huggingface_hub import list_inference_endpoints
available_endpoints = list_inference_endpoints("*")
if isinstance(self.llm, HuggingFaceTextGenInference):
endpoint_url = self.llm.inference_server_url
elif isinstance(self.llm, HuggingFaceEndpoint):
endpoint_url = self.llm.endpoint_url
elif isinstance(self.llm, HuggingFaceHub):
# no need to look up model_id for HuggingFaceHub LLM
if isinstance(self.llm, HuggingFaceHub) or (
hasattr(self.llm, "repo_id") and self.llm.repo_id
):
self.model_id = self.llm.repo_id
return
elif isinstance(self.llm, HuggingFaceTextGenInference):
endpoint_url: Optional[str] = self.llm.inference_server_url
else:
raise ValueError(f"Unknown LLM type: {type(self.llm)}")
endpoint_url = self.llm.endpoint_url
for endpoint in available_endpoints:
if endpoint.url == endpoint_url:
@ -156,8 +153,8 @@ class ChatHuggingFace(BaseChatModel):
if not self.model_id:
raise ValueError(
"Failed to resolve model_id"
f"Could not find model id for inference server provided: {endpoint_url}"
"Failed to resolve model_id:"
f"Could not find model id for inference server: {endpoint_url}"
"Make sure that your Hugging Face token has access to the endpoint."
)

@ -1,12 +1,17 @@
from typing import Any, Dict, List, Mapping, Optional
import json
import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
from langchain_community.llms.utils import enforce_stop_tokens
logger = logging.getLogger(__name__)
VALID_TASKS = (
"text2text-generation",
@ -17,70 +22,198 @@ VALID_TASKS = (
class HuggingFaceEndpoint(LLM):
"""HuggingFace Endpoint models.
To use, you should have the ``huggingface_hub`` python package installed, and the
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
it as a named parameter to the constructor.
"""
HuggingFace Endpoint.
Only supports `text-generation` and `text2text-generation` for now.
To use this class, you should have installed the ``huggingface_hub`` package, and
the environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token,
or given as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain_community.llms import HuggingFaceEndpoint
endpoint_url = (
"https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud"
# Basic Example (no streaming)
llm = HuggingFaceEndpoint(
endpoint_url="http://localhost:8010/",
max_new_tokens=512,
top_k=10,
top_p=0.95,
typical_p=0.95,
temperature=0.01,
repetition_penalty=1.03,
huggingfacehub_api_token="my-api-key"
)
hf = HuggingFaceEndpoint(
endpoint_url=endpoint_url,
print(llm("What is Deep Learning?"))
# Streaming response example
from langchain_community.callbacks import streaming_stdout
callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
llm = HuggingFaceEndpoint(
endpoint_url="http://localhost:8010/",
max_new_tokens=512,
top_k=10,
top_p=0.95,
typical_p=0.95,
temperature=0.01,
repetition_penalty=1.03,
callbacks=callbacks,
streaming=True,
huggingfacehub_api_token="my-api-key"
)
print(llm("What is Deep Learning?"))
"""
endpoint_url: str = ""
endpoint_url: Optional[str] = None
"""Endpoint URL to use."""
repo_id: Optional[str] = None
"""Repo to use."""
huggingfacehub_api_token: Optional[str] = None
max_new_tokens: int = 512
"""Maximum number of generated tokens"""
top_k: Optional[int] = None
"""The number of highest probability vocabulary tokens to keep for
top-k-filtering."""
top_p: Optional[float] = 0.95
"""If set to < 1, only the smallest set of most probable tokens with probabilities
that add up to `top_p` or higher are kept for generation."""
typical_p: Optional[float] = 0.95
"""Typical Decoding mass. See [Typical Decoding for Natural Language
Generation](https://arxiv.org/abs/2202.00666) for more information."""
temperature: Optional[float] = 0.8
"""The value used to module the logits distribution."""
repetition_penalty: Optional[float] = None
"""The parameter for repetition penalty. 1.0 means no penalty.
See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details."""
return_full_text: bool = False
"""Whether to prepend the prompt to the generated text"""
truncate: Optional[int] = None
"""Truncate inputs tokens to the given size"""
stop_sequences: List[str] = Field(default_factory=list)
"""Stop generating tokens if a member of `stop_sequences` is generated"""
seed: Optional[int] = None
"""Random sampling seed"""
inference_server_url: str = ""
"""text-generation-inference instance base url"""
timeout: int = 120
"""Timeout in seconds"""
streaming: bool = False
"""Whether to generate a stream of tokens asynchronously"""
do_sample: bool = False
"""Activate logits sampling"""
watermark: bool = False
"""Watermarking with [A Watermark for Large Language Models]
(https://arxiv.org/abs/2301.10226)"""
server_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any text-generation-inference server parameters not explicitly specified"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `call` not explicitly specified"""
model: str
client: Any
async_client: Any
task: Optional[str] = None
"""Task to call the model with.
Should be a task that returns `generated_text` or `summary_text`."""
model_kwargs: Optional[dict] = None
"""Keyword arguments to pass to the model."""
huggingfacehub_api_token: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
logger.warning(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please make sure that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
if "endpoint_url" not in values and "repo_id" not in values:
raise ValueError(
"Please specify an `endpoint_url` or `repo_id` for the model."
)
if "endpoint_url" in values and "repo_id" in values:
raise ValueError(
"Please specify either an `endpoint_url` OR a `repo_id`, not both."
)
values["model"] = values.get("endpoint_url") or values.get("repo_id")
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
"""Validate that package is installed and that the API token is valid."""
try:
from huggingface_hub.hf_api import HfApi
try:
HfApi(
endpoint="https://huggingface.co", # Can be a Private Hub endpoint.
token=huggingfacehub_api_token,
).whoami()
except Exception as e:
raise ValueError(
"Could not authenticate with huggingface_hub. "
"Please check your API token."
) from e
from huggingface_hub import login
except ImportError:
raise ImportError(
"Could not import huggingface_hub python package. "
"Please install it with `pip install huggingface_hub`."
)
values["huggingfacehub_api_token"] = huggingfacehub_api_token
try:
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
login(token=huggingfacehub_api_token)
except Exception as e:
raise ValueError(
"Could not authenticate with huggingface_hub. "
"Please check your API token."
) from e
from huggingface_hub import AsyncInferenceClient, InferenceClient
values["client"] = InferenceClient(
model=values["model"],
timeout=values["timeout"],
token=huggingfacehub_api_token,
**values["server_kwargs"],
)
values["async_client"] = AsyncInferenceClient(
model=values["model"],
timeout=values["timeout"],
token=huggingfacehub_api_token,
**values["server_kwargs"],
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling text generation inference API."""
return {
"max_new_tokens": self.max_new_tokens,
"top_k": self.top_k,
"top_p": self.top_p,
"typical_p": self.typical_p,
"temperature": self.temperature,
"repetition_penalty": self.repetition_penalty,
"return_full_text": self.return_full_text,
"truncate": self.truncate,
"stop_sequences": self.stop_sequences,
"seed": self.seed,
"do_sample": self.do_sample,
"watermark": self.watermark,
**self.model_kwargs,
}
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
@ -95,6 +228,13 @@ class HuggingFaceEndpoint(LLM):
"""Return type of llm."""
return "huggingface_endpoint"
def _invocation_params(
self, runtime_stop: Optional[List[str]], **kwargs: Any
) -> Dict[str, Any]:
params = {**self._default_params, **kwargs}
params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
return params
def _call(
self,
prompt: str,
@ -102,62 +242,129 @@ class HuggingFaceEndpoint(LLM):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to HuggingFace Hub's inference endpoint.
"""Call out to HuggingFace Hub's inference endpoint."""
invocation_params = self._invocation_params(stop, **kwargs)
if self.streaming:
completion = ""
for chunk in self._stream(prompt, stop, run_manager, **invocation_params):
completion += chunk.text
return completion
else:
invocation_params["stop"] = invocation_params[
"stop_sequences"
] # porting 'stop_sequences' into the 'stop' argument
response = self.client.post(
json={"inputs": prompt, "parameters": invocation_params},
stream=False,
task=self.task,
)
response_text = json.loads(response.decode())[0]["generated_text"]
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
# Maybe the generation has stopped at one of the stop sequences:
# then we remove this stop sequence from the end of the generated text
for stop_seq in invocation_params["stop_sequences"]:
if response_text[-len(stop_seq) :] == stop_seq:
response_text = response_text[: -len(stop_seq)]
return response_text
Returns:
The string generated by the model.
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
invocation_params = self._invocation_params(stop, **kwargs)
if self.streaming:
completion = ""
async for chunk in self._astream(
prompt, stop, run_manager, **invocation_params
):
completion += chunk.text
return completion
else:
invocation_params["stop"] = invocation_params["stop_sequences"]
response = await self.async_client.post(
json={"inputs": prompt, "parameters": invocation_params},
stream=False,
task=self.task,
)
response_text = json.loads(response.decode())[0]["generated_text"]
Example:
.. code-block:: python
# Maybe the generation has stopped at one of the stop sequences:
# then remove this stop sequence from the end of the generated text
for stop_seq in invocation_params["stop_sequences"]:
if response_text[-len(stop_seq) :] == stop_seq:
response_text = response_text[: -len(stop_seq)]
return response_text
response = hf("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
invocation_params = self._invocation_params(stop, **kwargs)
# payload samples
params = {**_model_kwargs, **kwargs}
parameter_payload = {"inputs": prompt, "parameters": params}
for response in self.client.text_generation(
prompt, **invocation_params, stream=True
):
# identify stop sequence in generated text, if any
stop_seq_found: Optional[str] = None
for stop_seq in invocation_params["stop_sequences"]:
if stop_seq in response:
stop_seq_found = stop_seq
# HTTP headers for authorization
headers = {
"Authorization": f"Bearer {self.huggingfacehub_api_token}",
"Content-Type": "application/json",
}
# identify text to yield
text: Optional[str] = None
if stop_seq_found:
text = response[: response.index(stop_seq_found)]
else:
text = response
# send request
try:
response = requests.post(
self.endpoint_url, headers=headers, json=parameter_payload
)
except requests.exceptions.RequestException as e: # This is the correct syntax
raise ValueError(f"Error raised by inference endpoint: {e}")
generated_text = response.json()
if "error" in generated_text:
raise ValueError(
f"Error raised by inference API: {generated_text['error']}"
)
if self.task == "text-generation":
text = generated_text[0]["generated_text"]
# Remove prompt if included in generated text.
if text.startswith(prompt):
text = text[len(prompt) :]
elif self.task == "text2text-generation":
text = generated_text[0]["generated_text"]
elif self.task == "summarization":
text = generated_text[0]["summary_text"]
elif self.task == "conversational":
text = generated_text["response"][1]
else:
raise ValueError(
f"Got invalid task {self.task}, "
f"currently only {VALID_TASKS} are supported"
)
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
text = enforce_stop_tokens(text, stop)
return text
# yield text, if any
if text:
chunk = GenerationChunk(text=text)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text)
# break if stop sequence found
if stop_seq_found:
break
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
invocation_params = self._invocation_params(stop, **kwargs)
async for response in await self.async_client.text_generation(
prompt, **invocation_params, stream=True
):
# identify stop sequence in generated text, if any
stop_seq_found: Optional[str] = None
for stop_seq in invocation_params["stop_sequences"]:
if stop_seq in response:
stop_seq_found = stop_seq
# identify text to yield
text: Optional[str] = None
if stop_seq_found:
text = response[: response.index(stop_seq_found)]
else:
text = response
# yield text, if any
if text:
chunk = GenerationChunk(text=text)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.text)
# break if stop sequence found
if stop_seq_found:
break

@ -1,6 +1,7 @@
import json
from typing import Any, Dict, List, Mapping, Optional
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, root_validator
@ -19,8 +20,10 @@ VALID_TASKS_DICT = {
}
@deprecated("0.0.21", removal="0.2.0", alternative="HuggingFaceEndpoint")
class HuggingFaceHub(LLM):
"""HuggingFaceHub models.
! This class is deprecated, you should use HuggingFaceEndpoint instead.
To use, you should have the ``huggingface_hub`` python package installed, and the
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass

@ -9,8 +9,6 @@ from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.pydantic_v1 import Extra
from langchain_community.llms.utils import enforce_stop_tokens
DEFAULT_MODEL_ID = "gpt2"
DEFAULT_TASK = "text-generation"
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
@ -201,7 +199,12 @@ class HuggingFacePipeline(BaseLLM):
batch_prompts = prompts[i : i + self.batch_size]
# Process batch of prompts
responses = self.pipeline(batch_prompts, **pipeline_kwargs)
responses = self.pipeline(
batch_prompts,
stop_sequence=stop,
return_full_text=False,
**pipeline_kwargs,
)
# Process each response in the batch
for j, response in enumerate(responses):
@ -210,23 +213,7 @@ class HuggingFacePipeline(BaseLLM):
response = response[0]
if self.pipeline.task == "text-generation":
try:
from transformers.pipelines.text_generation import ReturnType
remove_prompt = (
self.pipeline._postprocess_params.get("return_type")
!= ReturnType.NEW_TEXT
)
except Exception as e:
logger.warning(
f"Unable to extract pipeline return_type. "
f"Received error:\n\n{e}"
)
remove_prompt = True
if remove_prompt:
text = response["generated_text"][len(batch_prompts[j]) :]
else:
text = response["generated_text"]
text = response["generated_text"]
elif self.pipeline.task == "text2text-generation":
text = response["generated_text"]
elif self.pipeline.task == "summarization":
@ -236,9 +223,6 @@ class HuggingFacePipeline(BaseLLM):
f"Got invalid task {self.pipeline.task}, "
f"currently only {VALID_TASKS} are supported"
)
if stop:
# Enforce stop tokens
text = enforce_stop_tokens(text, stop)
# Append the processed text to results
text_generations.append(text)

@ -1,6 +1,7 @@
import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@ -13,9 +14,11 @@ from langchain_core.utils import get_pydantic_field_names
logger = logging.getLogger(__name__)
@deprecated("0.0.21", removal="0.2.0", alternative="HuggingFaceEndpoint")
class HuggingFaceTextGenInference(LLM):
"""
HuggingFace text generation API.
! This class is deprecated, you should use HuggingFaceEndpoint instead !
To use, you should have the `text-generation` python package installed and
a text-generation server running.

@ -1,6 +1,5 @@
"""Test HuggingFace API wrapper."""
"""Test HuggingFace Endpoints."""
import unittest
from pathlib import Path
import pytest
@ -10,51 +9,73 @@ from langchain_community.llms.loading import load_llm
from tests.integration_tests.llms.utils import assert_llm_equality
@unittest.skip(
"This test requires an inference endpoint. Tested with Hugging Face endpoints"
)
def test_huggingface_endpoint_text_generation() -> None:
"""Test valid call to HuggingFace text generation model."""
def test_huggingface_endpoint_call_error() -> None:
"""Test valid call to HuggingFace that errors."""
llm = HuggingFaceEndpoint(endpoint_url="", model_kwargs={"max_new_tokens": -1})
with pytest.raises(ValueError):
llm("Say foo:")
def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
"""Test saving/loading an HuggingFaceHub LLM."""
llm = HuggingFaceEndpoint(
endpoint_url="", task="text-generation", model_kwargs={"max_new_tokens": 10}
)
llm.save(file_path=tmp_path / "hf.yaml")
loaded_llm = load_llm(tmp_path / "hf.yaml")
assert_llm_equality(llm, loaded_llm)
def test_huggingface_text_generation() -> None:
"""Test valid call to HuggingFace text generation model."""
llm = HuggingFaceEndpoint(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})
output = llm("Say foo:")
print(output) # noqa: T201
assert isinstance(output, str)
@unittest.skip(
"This test requires an inference endpoint. Tested with Hugging Face endpoints"
)
def test_huggingface_endpoint_text2text_generation() -> None:
def test_huggingface_text2text_generation() -> None:
"""Test valid call to HuggingFace text2text model."""
llm = HuggingFaceEndpoint(endpoint_url="", task="text2text-generation")
llm = HuggingFaceEndpoint(repo_id="google/flan-t5-xl")
output = llm("The capital of New York is")
assert output == "Albany"
@unittest.skip(
"This test requires an inference endpoint. Tested with Hugging Face endpoints"
)
def test_huggingface_endpoint_summarization() -> None:
def test_huggingface_summarization() -> None:
"""Test valid call to HuggingFace summarization model."""
llm = HuggingFaceEndpoint(endpoint_url="", task="summarization")
llm = HuggingFaceEndpoint(repo_id="facebook/bart-large-cnn")
output = llm("Say foo:")
assert isinstance(output, str)
def test_huggingface_endpoint_call_error() -> None:
def test_huggingface_call_error() -> None:
"""Test valid call to HuggingFace that errors."""
llm = HuggingFaceEndpoint(model_kwargs={"max_new_tokens": -1})
llm = HuggingFaceEndpoint(repo_id="gpt2", model_kwargs={"max_new_tokens": -1})
with pytest.raises(ValueError):
llm("Say foo:")
def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
"""Test saving/loading an HuggingFaceHub LLM."""
llm = HuggingFaceEndpoint(
endpoint_url="", task="text-generation", model_kwargs={"max_new_tokens": 10}
)
def test_saving_loading_llm(tmp_path: Path) -> None:
"""Test saving/loading an HuggingFaceEndpoint LLM."""
llm = HuggingFaceEndpoint(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})
llm.save(file_path=tmp_path / "hf.yaml")
loaded_llm = load_llm(tmp_path / "hf.yaml")
assert_llm_equality(llm, loaded_llm)
def test_invocation_params_stop_sequences() -> None:
llm = HuggingFaceEndpoint()
assert llm._default_params["stop_sequences"] == []
runtime_stop = None
assert llm._invocation_params(runtime_stop)["stop_sequences"] == []
assert llm._default_params["stop_sequences"] == []
runtime_stop = ["stop"]
assert llm._invocation_params(runtime_stop)["stop_sequences"] == ["stop"]
assert llm._default_params["stop_sequences"] == []
llm = HuggingFaceEndpoint(stop_sequences=["."])
runtime_stop = ["stop"]
assert llm._invocation_params(runtime_stop)["stop_sequences"] == [".", "stop"]
assert llm._default_params["stop_sequences"] == ["."]

Loading…
Cancel
Save