Add HuggingFacePipeline LLM (#353)

https://github.com/hwchase17/langchain/issues/354

Add support for running your own HF pipeline locally. This gives you a lot more flexibility in which HF features and models you support, since you aren't beholden to what is hosted on the HF Hub. You could also use HF Optimum to quantize your models and get pretty fast inference even when running on a laptop.
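
To make that concrete, here is a minimal sketch of the two ways the new wrapper can be constructed (mirroring the docstring examples below; `gpt2` and `max_new_tokens=10` are just example values):

```python
from langchain.llms.huggingface_pipeline import HuggingFacePipeline

# Option 1: let the wrapper build the transformers pipeline from a model id.
local_llm = HuggingFacePipeline.from_model_id(
    model_id="gpt2",
    task="text-generation",
    model_kwargs={"max_new_tokens": 10},
)
print(local_llm("Say foo:"))

# Option 2: build the pipeline yourself (e.g. after quantizing/optimizing the
# model) and pass it in directly.
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10)
print(HuggingFacePipeline(pipeline=pipe)("Say foo:"))
```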
mrbean committed 1 year ago via GitHub
parent 2eef76ed3f
commit fe6695b9e7

langchain/__init__.py
@@ -18,6 +18,7 @@ from langchain.chains import (
)
from langchain.docstore import InMemoryDocstore, Wikipedia
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.logger import BaseLogger, StdOutLogger
from langchain.prompts import (
BasePromptTemplate,
@@ -50,6 +51,7 @@ __all__ = [
"ReActChain",
"Wikipedia",
"HuggingFaceHub",
"HuggingFacePipeline",
"SQLDatabase",
"SQLDatabaseChain",
"FAISS",

langchain/llms/__init__.py
@@ -5,10 +5,18 @@ from langchain.llms.ai21 import AI21
from langchain.llms.base import LLM
from langchain.llms.cohere import Cohere
from langchain.llms.huggingface_hub import HuggingFaceHub
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.llms.nlpcloud import NLPCloud
from langchain.llms.openai import OpenAI
__all__ = ["Cohere", "NLPCloud", "OpenAI", "HuggingFaceHub", "AI21"]
__all__ = [
"Cohere",
"NLPCloud",
"OpenAI",
"HuggingFaceHub",
"HuggingFacePipeline",
"AI21",
]
type_to_cls_dict: Dict[str, Type[LLM]] = {
"ai21": AI21,
@@ -16,4 +24,5 @@ type_to_cls_dict: Dict[str, Type[LLM]] = {
"huggingface_hub": HuggingFaceHub,
"nlpcloud": NLPCloud,
"openai": OpenAI,
"huggingface_pipeline": HuggingFacePipeline,
}

langchain/llms/huggingface_pipeline.py
@@ -0,0 +1,118 @@
"""Wrapper around HuggingFace Pipeline APIs."""
from typing import Any, List, Mapping, Optional
from pydantic import BaseModel, Extra
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
DEFAULT_MODEL_ID = "gpt2"
DEFAULT_TASK = "text-generation"
VALID_TASKS = ("text2text-generation", "text-generation")
class HuggingFacePipeline(LLM, BaseModel):
"""Wrapper around HuggingFace Pipeline API.
To use, you should have the ``transformers`` python package installed.
Only supports `text-generation` and `text2text-generation` for now.
Example using from_model_id:
.. code-block:: python
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
hf = HuggingFacePipeline.from_model_id(
model_id="gpt2", task="text-generation"
)
Example passing pipeline in directly:
.. code-block:: python
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
pipe = pipeline(
"text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10
)
hf = HuggingFacePipeline(pipeline=pipe)
"""
pipeline: Any #: :meta private:
model_id: str = DEFAULT_MODEL_ID
"""Model name to use."""
model_kwargs: Optional[dict] = None
"""Key word arguments to pass to the model."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@classmethod
def from_model_id(
cls,
model_id: str,
task: str,
model_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> LLM:
"""Construct the pipeline object from model_id and task."""
try:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline as hf_pipeline
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
pipeline = hf_pipeline(
task=task, model=model, tokenizer=tokenizer, **(model_kwargs or {})
)
if pipeline.task not in VALID_TASKS:
raise ValueError(
f"Got invalid task {pipeline.task}, "
f"currently only {VALID_TASKS} are supported"
)
return cls(
pipeline=pipeline,
model_id=model_id,
model_kwargs=model_kwargs,
**kwargs,
)
except ImportError:
raise ValueError(
"Could not import transformers python package. "
"Please it install it with `pip install transformers`."
)
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
**{"model_id": self.model_id},
**{"model_kwargs": self.model_kwargs},
}
@property
def _llm_type(self) -> str:
return "huggingface_pipeline"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
response = self.pipeline(text_inputs=prompt)
if self.pipeline.task == "text-generation":
# Text generation return includes the starter text.
text = response[0]["generated_text"][len(prompt) :]
elif self.pipeline.task == "text2text-generation":
text = response[0]["generated_text"]
else:
raise ValueError(
f"Got invalid task {self.pipeline.task}, "
f"currently only {VALID_TASKS} are supported"
)
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when calling a local HuggingFace pipeline.
text = enforce_stop_tokens(text, stop)
return text
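
For reference, `enforce_stop_tokens` (imported from `langchain.llms.utils` above) just cuts the generated text off at the first occurrence of any stop sequence. A rough, self-contained sketch of that behavior (the helper name here is hypothetical; the real function lives in `langchain.llms.utils`):

```python
import re
from typing import List

def truncate_at_stop(text: str, stop: List[str]) -> str:
    # Cut the text off as soon as any stop sequence occurs.
    return re.split("|".join(re.escape(s) for s in stop), text)[0]

# e.g. with an agent-style stop sequence:
assert truncate_at_stop("foo bar\nObservation: baz", ["\nObservation:"]) == "foo bar"
```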

poetry.lock generated

@@ -640,7 +640,7 @@ arrow = ">=0.15.0"
[[package]]
name = "isort"
version = "5.11.1"
version = "5.11.2"
description = "A Python utility / library to sort Python imports."
category = "dev"
optional = false
@@ -1181,6 +1181,54 @@ category = "main"
optional = false
python-versions = ">=3.8"
[[package]]
name = "nvidia-cublas-cu11"
version = "11.10.3.66"
description = "CUBLAS native runtime libraries"
category = "main"
optional = true
python-versions = ">=3"
[package.dependencies]
setuptools = "*"
wheel = "*"
[[package]]
name = "nvidia-cuda-nvrtc-cu11"
version = "11.7.99"
description = "NVRTC native runtime libraries"
category = "main"
optional = true
python-versions = ">=3"
[package.dependencies]
setuptools = "*"
wheel = "*"
[[package]]
name = "nvidia-cuda-runtime-cu11"
version = "11.7.99"
description = "CUDA Runtime native Libraries"
category = "main"
optional = true
python-versions = ">=3"
[package.dependencies]
setuptools = "*"
wheel = "*"
[[package]]
name = "nvidia-cudnn-cu11"
version = "8.5.0.96"
description = "cuDNN runtime libraries"
category = "main"
optional = true
python-versions = ">=3"
[package.dependencies]
setuptools = "*"
wheel = "*"
[[package]]
name = "packaging"
version = "22.0"
@@ -1750,7 +1798,7 @@ python-versions = ">=3.6"
[[package]]
name = "spacy"
version = "3.4.3"
version = "3.4.4"
description = "Industrial-strength Natural Language Processing (NLP) in Python"
category = "main"
optional = true
@@ -1769,6 +1817,7 @@ preshed = ">=3.0.2,<3.1.0"
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<1.11.0"
requests = ">=2.13.0,<3.0.0"
setuptools = "*"
smart-open = ">=5.2.1,<7.0.0"
spacy-legacy = ">=3.0.10,<3.1.0"
spacy-loggers = ">=1.0.0,<2.0.0"
srsly = ">=2.4.3,<3.0.0"
@@ -1997,6 +2046,24 @@ category = "dev"
optional = false
python-versions = ">=3.7"
[[package]]
name = "torch"
version = "1.13.1"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
category = "main"
optional = true
python-versions = ">=3.7.0"
[package.dependencies]
nvidia-cublas-cu11 = {version = "11.10.3.66", markers = "platform_system == \"Linux\""}
nvidia-cuda-nvrtc-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""}
nvidia-cuda-runtime-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""}
nvidia-cudnn-cu11 = {version = "8.5.0.96", markers = "platform_system == \"Linux\""}
typing-extensions = "*"
[package.extras]
opt-einsum = ["opt-einsum (>=3.3)"]
[[package]]
name = "tornado"
version = "6.2"
@@ -2227,6 +2294,17 @@ docs = ["Sphinx (>=3.4)", "sphinx-rtd-theme (>=0.5)"]
optional = ["python-socks", "wsaccel"]
test = ["websockets"]
[[package]]
name = "wheel"
version = "0.38.4"
description = "A built-package format for Python"
category = "main"
optional = true
python-versions = ">=3.7"
[package.extras]
test = ["pytest (>=3.0.0)"]
[[package]]
name = "widgetsnbextension"
version = "4.0.4"
@@ -2260,13 +2338,13 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker
testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
[extras]
all = ["manifest-ml", "elasticsearch", "faiss-cpu", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken"]
llms = ["manifest-ml"]
all = ["manifest-ml", "elasticsearch", "faiss-cpu", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch"]
llms = ["manifest-ml", "torch", "transformers"]
[metadata]
lock-version = "1.1"
python-versions = ">=3.8.1,<4.0"
content-hash = "45b0665b465cf551601e1a18f5b99c9a66c99313de21292f2b897d2ec42d2f6a"
content-hash = "d9ed2c5e1b2c51d7f8a9f74c858ab5058db14b7d6ca542777d7ecd07ccef5ee8"
[metadata.files]
anyio = [
@@ -2757,8 +2835,8 @@ isoduration = [
{file = "isoduration-20.11.0.tar.gz", hash = "sha256:ac2f9015137935279eac671f94f89eb00584f940f5dc49462a0c4ee692ba1bd9"},
]
isort = [
{file = "isort-5.11.1-py3-none-any.whl", hash = "sha256:bf02c95f1fe615ebbe13a619cfed1619ddfe8941274c9e3de3143adca406cb02"},
{file = "isort-5.11.1.tar.gz", hash = "sha256:7c5bd998504826b6f1e6f2f98b533976b066baba29b8bae83fdeefd0b89c6b70"},
{file = "isort-5.11.2-py3-none-any.whl", hash = "sha256:e486966fba83f25b8045f8dd7455b0a0d1e4de481e1d7ce4669902d9fb85e622"},
{file = "isort-5.11.2.tar.gz", hash = "sha256:dd8bbc5c0990f2a095d754e50360915f73b4c26fc82733eb5bfc6b48396af4d2"},
]
jedi = [
{file = "jedi-0.18.2-py2.py3-none-any.whl", hash = "sha256:203c1fd9d969ab8f2119ec0a3342e0b49910045abe6af0a3ae83a5764d54639e"},
@@ -3084,6 +3162,23 @@ numpy = [
{file = "numpy-1.23.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:01dd17cbb340bf0fc23981e52e1d18a9d4050792e8fb8363cecbf066a84b827d"},
{file = "numpy-1.23.5.tar.gz", hash = "sha256:1b1766d6f397c18153d40015ddfc79ddb715cabadc04d2d228d4e5a8bc4ded1a"},
]
nvidia-cublas-cu11 = [
{file = "nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl", hash = "sha256:d32e4d75f94ddfb93ea0a5dda08389bcc65d8916a25cb9f37ac89edaeed3bded"},
{file = "nvidia_cublas_cu11-11.10.3.66-py3-none-win_amd64.whl", hash = "sha256:8ac17ba6ade3ed56ab898a036f9ae0756f1e81052a317bf98f8c6d18dc3ae49e"},
]
nvidia-cuda-nvrtc-cu11 = [
{file = "nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:9f1562822ea264b7e34ed5930567e89242d266448e936b85bc97a3370feabb03"},
{file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:f7d9610d9b7c331fa0da2d1b2858a4a8315e6d49765091d28711c8946e7425e7"},
{file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:f2effeb1309bdd1b3854fc9b17eaf997808f8b25968ce0c7070945c4265d64a3"},
]
nvidia-cuda-runtime-cu11 = [
{file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:cc768314ae58d2641f07eac350f40f99dcb35719c4faff4bc458a7cd2b119e31"},
{file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:bc77fa59a7679310df9d5c70ab13c4e34c64ae2124dd1efd7e5474b71be125c7"},
]
nvidia-cudnn-cu11 = [
{file = "nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:402f40adfc6f418f9dae9ab402e773cfed9beae52333f6d86ae3107a1b9527e7"},
{file = "nvidia_cudnn_cu11-8.5.0.96-py3-none-manylinux1_x86_64.whl", hash = "sha256:71f8111eb830879ff2836db3cccf03bbd735df9b0d17cd93761732ac50a8a108"},
]
packaging = [
{file = "packaging-22.0-py3-none-any.whl", hash = "sha256:957e2148ba0e1a3b282772e791ef1d8083648bc131c8ab0c1feba110ce1146c3"},
{file = "packaging-22.0.tar.gz", hash = "sha256:2198ec20bd4c017b8f9717e00f0c8714076fc2fd93816750ab48e2c41de2cfd3"},
@@ -3622,34 +3717,34 @@ soupsieve = [
{file = "soupsieve-2.3.2.post1.tar.gz", hash = "sha256:fc53893b3da2c33de295667a0e19f078c14bf86544af307354de5fcf12a3f30d"},
]
spacy = [
{file = "spacy-3.4.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e546b314f619502ae03e5eb9a0cfd09ca7a9db265bcdd8a3af83cfb0f1432e55"},
{file = "spacy-3.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ded11aa8966236aab145b4d2d024b3eb61ac50078362d77d9ed7d8c240ef0f4a"},
{file = "spacy-3.4.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:462e141f514d78cff85685b5b12eb8cadac0bad2f7820149cbe18d03ccb2e59c"},
{file = "spacy-3.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c966d25b3f3e49f5de08546b3638928f49678c365cbbebd0eec28f74e0adb539"},
{file = "spacy-3.4.3-cp310-cp310-win_amd64.whl", hash = "sha256:2ddba486c4c981abe6f1e3fd72648dc8811966e5f0e05808f9c9fab155c388d7"},
{file = "spacy-3.4.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3c87117dd335fba44d1c0d77602f0763c3addf4e7ef9bdbe9a495466c3484c69"},
{file = "spacy-3.4.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ce3938720f48eaeeb360a7f623f15a0d9efd1a688d5d740e3d4cdcd6f6da8a3"},
{file = "spacy-3.4.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ad6bf5e4e7f0bc2ef94b7ff6fe59abd766f74c192bca2f17430a3b3cd5bda5a"},
{file = "spacy-3.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6644c678bd7af567c6ce679f71d64119282e7d6f1a6f787162a91be3ea39333"},
{file = "spacy-3.4.3-cp311-cp311-win_amd64.whl", hash = "sha256:e6b871de8857a6820140358db3943180fdbe03d44ed792155cee6cb95f4ac4ea"},
{file = "spacy-3.4.3-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d211c2b8894354bf8d961af9a9dcab38f764e1dcddd7b80760e438fcd4c9fe43"},
{file = "spacy-3.4.3-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ea41f9de30435456235c4182d8bc2eb54a0a64719856e66e780350bb4c8cfbe"},
{file = "spacy-3.4.3-cp36-cp36m-win_amd64.whl", hash = "sha256:afaf6e716cbac4a0fbfa9e9bf95decff223936597ddd03ea869118a7576aa1b1"},
{file = "spacy-3.4.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7115da36369b3c537caf2fe08e0b45528bd091c7f56ba3580af1e6fdfa9b1081"},
{file = "spacy-3.4.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b3e629c889cac9656151286ec1232c6a948ce0d44a39f1ef5e60fed4f183a10"},
{file = "spacy-3.4.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9277cd0fcb96ee5dd885f7e96c639f21afd96198d61ca32100446afbff4dfbef"},
{file = "spacy-3.4.3-cp37-cp37m-win_amd64.whl", hash = "sha256:a36bd06a5a147350e5f5f6903c4777296c37b18199251bb41056c3a73aa4494f"},
{file = "spacy-3.4.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bdafcd0823ca804c39d0bed9e677eb7d0235b1259563d0fd4d3a201c71108af8"},
{file = "spacy-3.4.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0cdc23a48e6543402b4c56ebf2d36246001175c29fd56d3081efcec684651abc"},
{file = "spacy-3.4.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:455c2fbd1de24b6fe34fa121d87525134d7498f9f458ebc8274d7940b473999e"},
{file = "spacy-3.4.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1c85279fbb6b75d7fb8d7c59c2b734502e51271cad90926e8df1d21b67da5aa"},
{file = "spacy-3.4.3-cp38-cp38-win_amd64.whl", hash = "sha256:5c0d65f39184f522b4e67b965a42d121a3b2d799362682fe8847b64b0ce5bc7c"},
{file = "spacy-3.4.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a7b97ec21ed773edb2479ae5d6c7686b8034f418df6bccd9218f5c3c2b7cf888"},
{file = "spacy-3.4.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:36a9a506029842795099fd97ad95f0da2845c319020fcc7164cbf33650726f83"},
{file = "spacy-3.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ab293eb1423fa05c7ee71b2fedda57c2b4a4ca8dc054ce678809457287b01dc"},
{file = "spacy-3.4.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb6d0f185126decc8392cde7d28eb6e85ba4bca15424713288cccc49c2a3c52b"},
{file = "spacy-3.4.3-cp39-cp39-win_amd64.whl", hash = "sha256:676ab9ab2cf94ba48caa306f185a166e85bd35b388ec24512c8ba7dfcbc7517e"},
{file = "spacy-3.4.3.tar.gz", hash = "sha256:22698cf5175e2b697e82699fcccee3092b42137a57d352df208d71657fd693bb"},
{file = "spacy-3.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:07a10999a3e37f896758a92c2eed263638bcbf2747dc3a4aeea929aaa20ea28c"},
{file = "spacy-3.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e6d98511dc8a88d3a96bcae13971a284459362076738c85053d1a3791f6cde92"},
{file = "spacy-3.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2cad9c5543f03b3375c252e4dd45670ee8ed99c925dca15eadab5084fd1b033"},
{file = "spacy-3.4.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ade19c1e676cac2546f268db22bc5eba08d12beafabe80f1b9f06028b3a0b52"},
{file = "spacy-3.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:e782c8a7c4805cc1b34ed2b11f72a5cf2b9851e20f7afe3e97caf206f19f761b"},
{file = "spacy-3.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aa027e69ef9fe42c8b02b940872e5bde0ce1bf66b6bf488c6493e3ce660c4b3a"},
{file = "spacy-3.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ddeb5d725b6fa9c9009b1ff645db8f5caab9ed8956ee3a84b8379951caad1d36"},
{file = "spacy-3.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29d6bb428a6bb19e026d8bbb9d4385c25b21e1ce51fcaabadfb5599b2390a79c"},
{file = "spacy-3.4.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a21187ad4c44e166dc3deed23992ea1a74d731c9a6bdd9fca306d455181577fa"},
{file = "spacy-3.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:10643c6d335a02805f6676738a3e992323cfd9438115cc253435e5053dc93824"},
{file = "spacy-3.4.4-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:486228cfa7ced18ec99008388028bd2329262ab8108e7c19252c1a67b2801909"},
{file = "spacy-3.4.4-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcb7a213178c298b95532075d6dddfb374bbe56ef8d2687212763b4583048da2"},
{file = "spacy-3.4.4-cp36-cp36m-win_amd64.whl", hash = "sha256:15e5c41d408d1d30d8f3dd8e4eed9ed28e6174e011b8d61c1345981562e2e8f5"},
{file = "spacy-3.4.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8979dbd3594c5c268cedad53f456a3ec3a0a2b78a1199788aacedcd68eef3a00"},
{file = "spacy-3.4.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f4736fea2630e696422dfe38bfb3d0a7864bc6a9072d6e49a906af46870e36e"},
{file = "spacy-3.4.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:498bf01e8c7ab601c3f8d6c51497817b40a3322a3967c032536b18ce9ea26d0a"},
{file = "spacy-3.4.4-cp37-cp37m-win_amd64.whl", hash = "sha256:95f880c6fea57d51c448ad84f96d79d8758e5e18bdbaaee060c15af11641079b"},
{file = "spacy-3.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9ccbede9be470c5d795168bf3be41fc86e18892a9247a742b394ba866c005391"},
{file = "spacy-3.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2f1edbecfde9c11b17e87768bb5f2c33948fb1e3bf54b2197031ff9053607277"},
{file = "spacy-3.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:66eaf4764e95699934cbd8f38717b283db185c896cfd3d1fb1ad5c6552e8b3c9"},
{file = "spacy-3.4.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0bb7d53f1a780bb8cc1b27a81e02e8b9bc71abb959f4dc13c21af4041fdd2c7a"},
{file = "spacy-3.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:c1a5ce5c9b19cdfb4469079e710e72bb09c3cab855f21ef6a614b84c765e0311"},
{file = "spacy-3.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f7044dca3542579ea1e3ac6cdd821640c2f65dd0c56230688f36e15aca1b8217"},
{file = "spacy-3.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8a495b0fc00910fb5c1fbe64fdbfe1d3c11b09f421d1ae4e30cdb4c2388a91e4"},
{file = "spacy-3.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:31e9a637960b60c1bb7a36a187271425717e97c14e9d1df613dc4efeffefcbec"},
{file = "spacy-3.4.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71f9449ffadef85b048c9735ee235da5dca9d0a87038dba6d4ed20c5188e0f13"},
{file = "spacy-3.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:1b7791a6c0592615b0566001596cc48c72325d1b97e46e574c91bff34f4e3f4c"},
{file = "spacy-3.4.4.tar.gz", hash = "sha256:e500cf2cb5f1849461a7928fa269703756069bdfb71559065240af6d0208b08c"},
]
spacy-legacy = [
{file = "spacy-legacy-3.0.10.tar.gz", hash = "sha256:16104595d8ab1b7267f817a449ad1f986eb1f2a2edf1050748f08739a479679a"},
@@ -3837,6 +3932,29 @@ tomli = [
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
]
torch = [
{file = "torch-1.13.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:fd12043868a34a8da7d490bf6db66991108b00ffbeecb034228bfcbbd4197143"},
{file = "torch-1.13.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d9fe785d375f2e26a5d5eba5de91f89e6a3be5d11efb497e76705fdf93fa3c2e"},
{file = "torch-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:98124598cdff4c287dbf50f53fb455f0c1e3a88022b39648102957f3445e9b76"},
{file = "torch-1.13.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:393a6273c832e047581063fb74335ff50b4c566217019cc6ace318cd79eb0566"},
{file = "torch-1.13.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:0122806b111b949d21fa1a5f9764d1fd2fcc4a47cb7f8ff914204fd4fc752ed5"},
{file = "torch-1.13.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:22128502fd8f5b25ac1cd849ecb64a418382ae81dd4ce2b5cebaa09ab15b0d9b"},
{file = "torch-1.13.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:76024be052b659ac1304ab8475ab03ea0a12124c3e7626282c9c86798ac7bc11"},
{file = "torch-1.13.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:ea8dda84d796094eb8709df0fcd6b56dc20b58fdd6bc4e8d7109930dafc8e419"},
{file = "torch-1.13.1-cp37-cp37m-win_amd64.whl", hash = "sha256:2ee7b81e9c457252bddd7d3da66fb1f619a5d12c24d7074de91c4ddafb832c93"},
{file = "torch-1.13.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:0d9b8061048cfb78e675b9d2ea8503bfe30db43d583599ae8626b1263a0c1380"},
{file = "torch-1.13.1-cp37-none-macosx_11_0_arm64.whl", hash = "sha256:f402ca80b66e9fbd661ed4287d7553f7f3899d9ab54bf5c67faada1555abde28"},
{file = "torch-1.13.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:727dbf00e2cf858052364c0e2a496684b9cb5aa01dc8a8bc8bbb7c54502bdcdd"},
{file = "torch-1.13.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:df8434b0695e9ceb8cc70650afc1310d8ba949e6db2a0525ddd9c3b2b181e5fe"},
{file = "torch-1.13.1-cp38-cp38-win_amd64.whl", hash = "sha256:5e1e722a41f52a3f26f0c4fcec227e02c6c42f7c094f32e49d4beef7d1e213ea"},
{file = "torch-1.13.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:33e67eea526e0bbb9151263e65417a9ef2d8fa53cbe628e87310060c9dcfa312"},
{file = "torch-1.13.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:eeeb204d30fd40af6a2d80879b46a7efbe3cf43cdbeb8838dd4f3d126cc90b2b"},
{file = "torch-1.13.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:50ff5e76d70074f6653d191fe4f6a42fdbe0cf942fbe2a3af0b75eaa414ac038"},
{file = "torch-1.13.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:2c3581a3fd81eb1f0f22997cddffea569fea53bafa372b2c0471db373b26aafc"},
{file = "torch-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:0aa46f0ac95050c604bcf9ef71da9f1172e5037fdf2ebe051962d47b123848e7"},
{file = "torch-1.13.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6930791efa8757cb6974af73d4996b6b50c592882a324b8fb0589c6a9ba2ddaf"},
{file = "torch-1.13.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:e0df902a7c7dd6c795698532ee5970ce898672625635d885eade9976e5a04949"},
]
tornado = [
{file = "tornado-6.2-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:20f638fd8cc85f3cbae3c732326e96addff0a15e22d80f049e00121651e82e72"},
{file = "tornado-6.2-cp37-abi3-macosx_10_9_x86_64.whl", hash = "sha256:87dcafae3e884462f90c90ecc200defe5e580a7fbbb4365eda7c7c1eb809ebc9"},
@@ -3914,6 +4032,10 @@ websocket-client = [
{file = "websocket-client-1.4.2.tar.gz", hash = "sha256:d6e8f90ca8e2dd4e8027c4561adeb9456b54044312dba655e7cae652ceb9ae59"},
{file = "websocket_client-1.4.2-py3-none-any.whl", hash = "sha256:d6b06432f184438d99ac1f456eaf22fe1ade524c3dd16e661142dc54e9cba574"},
]
wheel = [
{file = "wheel-0.38.4-py3-none-any.whl", hash = "sha256:b60533f3f5d530e971d6737ca6d58681ee434818fab630c83a734bb10c083ce8"},
{file = "wheel-0.38.4.tar.gz", hash = "sha256:965f5259b566725405b05e7cf774052044b1ed30119b5d586b2703aafe8719ac"},
]
widgetsnbextension = [
{file = "widgetsnbextension-4.0.4-py3-none-any.whl", hash = "sha256:fa0e840719ec95dd2ec85c3a48913f1a0c29d323eacbcdb0b29bfed0cc6da678"},
{file = "widgetsnbextension-4.0.4.tar.gz", hash = "sha256:44c69f18237af0f610557d6c1c7ef76853f5856a0e604c0a517f2320566bb775"},

pyproject.toml
@@ -22,6 +22,7 @@ spacy = {version = "^3", optional = true}
nltk = {version = "^3", optional = true}
transformers = {version = "^4", optional = true}
beautifulsoup4 = {version = "^4", optional = true}
torch = {version = "^1.13.1", optional = true}
tiktoken = {version = "^0", optional = true, python="^3.9"}
[tool.poetry.group.test.dependencies]
@@ -49,8 +50,8 @@ jupyter = "^1.0.0"
playwright = "^1.28.0"
[tool.poetry.extras]
llms = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml"]
all = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "elasticsearch", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken"]
llms = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]
all = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "elasticsearch", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch"]
[tool.isort]
profile = "black"

tests/integration_tests/llms/test_huggingface_pipeline.py
@@ -0,0 +1,41 @@
"""Test HuggingFace Pipeline wrapper."""
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.llms.loading import load_llm
from tests.integration_tests.llms.utils import assert_llm_equality
def test_huggingface_pipeline_text_generation() -> None:
"""Test valid call to HuggingFace text generation model."""
llm = HuggingFacePipeline.from_model_id(
model_id="gpt2", task="text-generation", model_kwargs={"max_new_tokens": 10}
)
output = llm("Say foo:")
assert isinstance(output, str)
def test_saving_loading_llm(tmp_path: Path) -> None:
"""Test saving/loading an HuggingFaceHub LLM."""
llm = HuggingFacePipeline.from_model_id(
model_id="gpt2", task="text-generation", model_kwargs={"max_new_tokens": 10}
)
llm.save(file_path=tmp_path / "hf.yaml")
loaded_llm = load_llm(tmp_path / "hf.yaml")
assert_llm_equality(llm, loaded_llm)
def test_init_with_pipeline() -> None:
"""Test initialization with a HF pipeline."""
model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
pipe = pipeline(
"text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10
)
llm = HuggingFacePipeline(pipeline=pipe)
output = llm("Say foo:")
assert isinstance(output, str)

tests/integration_tests/llms/utils.py
@@ -10,7 +10,7 @@ def assert_llm_equality(llm: LLM, loaded_llm: LLM) -> None:
# Client field can be session based, so hash is different despite
# all other values being the same, so just assess all other fields
for field in llm.__fields__.keys():
if field != "client":
if field != "client" and field != "pipeline":
val = getattr(llm, field)
new_val = getattr(loaded_llm, field)
assert new_val == val
