From fe6695b9e7b91770de7a0e33412faf71f94d9689 Mon Sep 17 00:00:00 2001
From: mrbean <43734688+sam-h-bean@users.noreply.github.com>
Date: Sat, 17 Dec 2022 10:00:04 -0500
Subject: [PATCH] Add HuggingFacePipeline LLM (#353)

https://github.com/hwchase17/langchain/issues/354

Add support for running your own HF pipeline locally. This gives you much
more flexibility over which HF features and models you support, since you
are no longer limited to what is hosted on the HF Hub. You can also use HF
Optimum to quantize your models and get fast inference even when running
on a laptop.
---
 langchain/__init__.py                         |   2 +
 langchain/llms/__init__.py                    |  11 +-
 langchain/llms/huggingface_pipeline.py        | 118 +++++++++++
 poetry.lock                                   | 192 ++++++++++++++----
 pyproject.toml                                |   5 +-
 .../llms/test_huggingface_pipeline.py         |  41 ++++
 tests/integration_tests/llms/utils.py         |   2 +-
 7 files changed, 332 insertions(+), 39 deletions(-)
 create mode 100644 langchain/llms/huggingface_pipeline.py
 create mode 100644 tests/integration_tests/llms/test_huggingface_pipeline.py

diff --git a/langchain/__init__.py b/langchain/__init__.py
index a32ffa95..c100d0e3 100644
--- a/langchain/__init__.py
+++ b/langchain/__init__.py
@@ -18,6 +18,7 @@ from langchain.chains import (
 )
 from langchain.docstore import InMemoryDocstore, Wikipedia
 from langchain.llms import Cohere, HuggingFaceHub, OpenAI
+from langchain.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain.logger import BaseLogger, StdOutLogger
 from langchain.prompts import (
     BasePromptTemplate,
@@ -50,6 +51,7 @@ __all__ = [
     "ReActChain",
     "Wikipedia",
     "HuggingFaceHub",
+    "HuggingFacePipeline",
     "SQLDatabase",
     "SQLDatabaseChain",
     "FAISS",
diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py
index 01e3c702..f32ef72d 100644
--- a/langchain/llms/__init__.py
+++ b/langchain/llms/__init__.py
@@ -5,10 +5,18 @@ from langchain.llms.ai21 import AI21
 from langchain.llms.base import LLM
 from langchain.llms.cohere import Cohere
 from langchain.llms.huggingface_hub import HuggingFaceHub
+from langchain.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain.llms.nlpcloud import NLPCloud
 from langchain.llms.openai import OpenAI
 
-__all__ = ["Cohere", "NLPCloud", "OpenAI", "HuggingFaceHub", "AI21"]
+__all__ = [
+    "Cohere",
+    "NLPCloud",
+    "OpenAI",
+    "HuggingFaceHub",
+    "HuggingFacePipeline",
+    "AI21",
+]
 
 type_to_cls_dict: Dict[str, Type[LLM]] = {
     "ai21": AI21,
@@ -16,4 +24,5 @@ type_to_cls_dict: Dict[str, Type[LLM]] = {
     "huggingface_hub": HuggingFaceHub,
     "nlpcloud": NLPCloud,
     "openai": OpenAI,
+    "huggingface_pipeline": HuggingFacePipeline,
 }
diff --git a/langchain/llms/huggingface_pipeline.py b/langchain/llms/huggingface_pipeline.py
new file mode 100644
index 00000000..6db9bcd0
--- /dev/null
+++ b/langchain/llms/huggingface_pipeline.py
@@ -0,0 +1,118 @@
+"""Wrapper around HuggingFace Pipeline APIs."""
+from typing import Any, List, Mapping, Optional
+
+from pydantic import BaseModel, Extra
+
+from langchain.llms.base import LLM
+from langchain.llms.utils import enforce_stop_tokens
+
+DEFAULT_MODEL_ID = "gpt2"
+DEFAULT_TASK = "text-generation"
+VALID_TASKS = ("text2text-generation", "text-generation")
+
+
+class HuggingFacePipeline(LLM, BaseModel):
+    """Wrapper around HuggingFace Pipeline API.
+
+    To use, you should have the ``transformers`` python package installed.
+
+    Only supports `text-generation` and `text2text-generation` for now.
+
+    Example using from_model_id:
+        .. code-block:: python
+
+            from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+            hf = HuggingFacePipeline.from_model_id(
+                model_id="gpt2", task="text-generation"
+            )
+
+    Example passing pipeline in directly:
+        .. code-block:: python
+
+            from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+
+            model_id = "gpt2"
+            tokenizer = AutoTokenizer.from_pretrained(model_id)
+            model = AutoModelForCausalLM.from_pretrained(model_id)
+            pipe = pipeline(
+                "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10
+            )
+            hf = HuggingFacePipeline(pipeline=pipe)
+    """
+
+    pipeline: Any  #: :meta private:
+    model_id: str = DEFAULT_MODEL_ID
+    """Model name to use."""
+    model_kwargs: Optional[dict] = None
+    """Keyword arguments to pass to the model."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @classmethod
+    def from_model_id(
+        cls,
+        model_id: str,
+        task: str,
+        model_kwargs: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> LLM:
+        """Construct the pipeline object from model_id and task."""
+        try:
+            from transformers import (
+                AutoModelForCausalLM,
+                AutoModelForSeq2SeqLM,
+                AutoTokenizer,
+            )
+            from transformers import pipeline as hf_pipeline
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "Please install it with `pip install transformers`."
+            )
+
+        _model_kwargs = model_kwargs or {}
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        # Each task needs the matching auto model class: a causal LM for
+        # `text-generation`, a seq2seq LM for `text2text-generation`.
+        if task == "text-generation":
+            model = AutoModelForCausalLM.from_pretrained(model_id)
+        elif task == "text2text-generation":
+            model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+        else:
+            raise ValueError(
+                f"Got invalid task {task}, "
+                f"currently only {VALID_TASKS} are supported"
+            )
+        pipeline = hf_pipeline(
+            task=task, model=model, tokenizer=tokenizer, **_model_kwargs
+        )
+        if pipeline.task not in VALID_TASKS:
+            raise ValueError(
+                f"Got invalid task {pipeline.task}, "
+                f"currently only {VALID_TASKS} are supported"
+            )
+        return cls(
+            pipeline=pipeline,
+            model_id=model_id,
+            model_kwargs=model_kwargs,
+            **kwargs,
+        )
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            **{"model_id": self.model_id},
+            **{"model_kwargs": self.model_kwargs},
+        }
+
+    @property
+    def _llm_type(self) -> str:
+        return "huggingface_pipeline"
+
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        response = self.pipeline(text_inputs=prompt)
+        if self.pipeline.task == "text-generation":
+            # Text generation output includes the prompt, so strip it off.
+            text = response[0]["generated_text"][len(prompt) :]
+        elif self.pipeline.task == "text2text-generation":
+            text = response[0]["generated_text"]
+        else:
+            raise ValueError(
+                f"Got invalid task {self.pipeline.task}, "
+                f"currently only {VALID_TASKS} are supported"
+            )
+        if stop is not None:
+            # This is a bit hacky, but it is the simplest way to enforce
+            # stop tokens on the pipeline's raw text output.
+            text = enforce_stop_tokens(text, stop)
+        return text
diff --git a/poetry.lock b/poetry.lock
index 71d31f41..9dcbd09b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -640,7 +640,7 @@ arrow = ">=0.15.0"
 
 [[package]]
 name = "isort"
-version = "5.11.1"
+version = "5.11.2"
 description = "A Python utility / library to sort Python imports."
category = "dev" optional = false @@ -1181,6 +1181,54 @@ category = "main" optional = false python-versions = ">=3.8" +[[package]] +name = "nvidia-cublas-cu11" +version = "11.10.3.66" +description = "CUBLAS native runtime libraries" +category = "main" +optional = true +python-versions = ">=3" + +[package.dependencies] +setuptools = "*" +wheel = "*" + +[[package]] +name = "nvidia-cuda-nvrtc-cu11" +version = "11.7.99" +description = "NVRTC native runtime libraries" +category = "main" +optional = true +python-versions = ">=3" + +[package.dependencies] +setuptools = "*" +wheel = "*" + +[[package]] +name = "nvidia-cuda-runtime-cu11" +version = "11.7.99" +description = "CUDA Runtime native Libraries" +category = "main" +optional = true +python-versions = ">=3" + +[package.dependencies] +setuptools = "*" +wheel = "*" + +[[package]] +name = "nvidia-cudnn-cu11" +version = "8.5.0.96" +description = "cuDNN runtime libraries" +category = "main" +optional = true +python-versions = ">=3" + +[package.dependencies] +setuptools = "*" +wheel = "*" + [[package]] name = "packaging" version = "22.0" @@ -1750,7 +1798,7 @@ python-versions = ">=3.6" [[package]] name = "spacy" -version = "3.4.3" +version = "3.4.4" description = "Industrial-strength Natural Language Processing (NLP) in Python" category = "main" optional = true @@ -1769,6 +1817,7 @@ preshed = ">=3.0.2,<3.1.0" pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<1.11.0" requests = ">=2.13.0,<3.0.0" setuptools = "*" +smart-open = ">=5.2.1,<7.0.0" spacy-legacy = ">=3.0.10,<3.1.0" spacy-loggers = ">=1.0.0,<2.0.0" srsly = ">=2.4.3,<3.0.0" @@ -1997,6 +2046,24 @@ category = "dev" optional = false python-versions = ">=3.7" +[[package]] +name = "torch" +version = "1.13.1" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +category = "main" +optional = true +python-versions = ">=3.7.0" + +[package.dependencies] +nvidia-cublas-cu11 = {version = "11.10.3.66", markers = "platform_system == \"Linux\""} +nvidia-cuda-nvrtc-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""} +nvidia-cuda-runtime-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""} +nvidia-cudnn-cu11 = {version = "8.5.0.96", markers = "platform_system == \"Linux\""} +typing-extensions = "*" + +[package.extras] +opt-einsum = ["opt-einsum (>=3.3)"] + [[package]] name = "tornado" version = "6.2" @@ -2227,6 +2294,17 @@ docs = ["Sphinx (>=3.4)", "sphinx-rtd-theme (>=0.5)"] optional = ["python-socks", "wsaccel"] test = ["websockets"] +[[package]] +name = "wheel" +version = "0.38.4" +description = "A built-package format for Python" +category = "main" +optional = true +python-versions = ">=3.7" + +[package.extras] +test = ["pytest (>=3.0.0)"] + [[package]] name = "widgetsnbextension" version = "4.0.4" @@ -2260,13 +2338,13 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] [extras] -all = ["manifest-ml", "elasticsearch", "faiss-cpu", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken"] -llms = ["manifest-ml"] +all = ["manifest-ml", "elasticsearch", "faiss-cpu", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch"] +llms = ["manifest-ml", "torch", "transformers"] [metadata] lock-version = "1.1" 
python-versions = ">=3.8.1,<4.0" -content-hash = "45b0665b465cf551601e1a18f5b99c9a66c99313de21292f2b897d2ec42d2f6a" +content-hash = "d9ed2c5e1b2c51d7f8a9f74c858ab5058db14b7d6ca542777d7ecd07ccef5ee8" [metadata.files] anyio = [ @@ -2757,8 +2835,8 @@ isoduration = [ {file = "isoduration-20.11.0.tar.gz", hash = "sha256:ac2f9015137935279eac671f94f89eb00584f940f5dc49462a0c4ee692ba1bd9"}, ] isort = [ - {file = "isort-5.11.1-py3-none-any.whl", hash = "sha256:bf02c95f1fe615ebbe13a619cfed1619ddfe8941274c9e3de3143adca406cb02"}, - {file = "isort-5.11.1.tar.gz", hash = "sha256:7c5bd998504826b6f1e6f2f98b533976b066baba29b8bae83fdeefd0b89c6b70"}, + {file = "isort-5.11.2-py3-none-any.whl", hash = "sha256:e486966fba83f25b8045f8dd7455b0a0d1e4de481e1d7ce4669902d9fb85e622"}, + {file = "isort-5.11.2.tar.gz", hash = "sha256:dd8bbc5c0990f2a095d754e50360915f73b4c26fc82733eb5bfc6b48396af4d2"}, ] jedi = [ {file = "jedi-0.18.2-py2.py3-none-any.whl", hash = "sha256:203c1fd9d969ab8f2119ec0a3342e0b49910045abe6af0a3ae83a5764d54639e"}, @@ -3084,6 +3162,23 @@ numpy = [ {file = "numpy-1.23.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:01dd17cbb340bf0fc23981e52e1d18a9d4050792e8fb8363cecbf066a84b827d"}, {file = "numpy-1.23.5.tar.gz", hash = "sha256:1b1766d6f397c18153d40015ddfc79ddb715cabadc04d2d228d4e5a8bc4ded1a"}, ] +nvidia-cublas-cu11 = [ + {file = "nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl", hash = "sha256:d32e4d75f94ddfb93ea0a5dda08389bcc65d8916a25cb9f37ac89edaeed3bded"}, + {file = "nvidia_cublas_cu11-11.10.3.66-py3-none-win_amd64.whl", hash = "sha256:8ac17ba6ade3ed56ab898a036f9ae0756f1e81052a317bf98f8c6d18dc3ae49e"}, +] +nvidia-cuda-nvrtc-cu11 = [ + {file = "nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:9f1562822ea264b7e34ed5930567e89242d266448e936b85bc97a3370feabb03"}, + {file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:f7d9610d9b7c331fa0da2d1b2858a4a8315e6d49765091d28711c8946e7425e7"}, + {file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:f2effeb1309bdd1b3854fc9b17eaf997808f8b25968ce0c7070945c4265d64a3"}, +] +nvidia-cuda-runtime-cu11 = [ + {file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:cc768314ae58d2641f07eac350f40f99dcb35719c4faff4bc458a7cd2b119e31"}, + {file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:bc77fa59a7679310df9d5c70ab13c4e34c64ae2124dd1efd7e5474b71be125c7"}, +] +nvidia-cudnn-cu11 = [ + {file = "nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:402f40adfc6f418f9dae9ab402e773cfed9beae52333f6d86ae3107a1b9527e7"}, + {file = "nvidia_cudnn_cu11-8.5.0.96-py3-none-manylinux1_x86_64.whl", hash = "sha256:71f8111eb830879ff2836db3cccf03bbd735df9b0d17cd93761732ac50a8a108"}, +] packaging = [ {file = "packaging-22.0-py3-none-any.whl", hash = "sha256:957e2148ba0e1a3b282772e791ef1d8083648bc131c8ab0c1feba110ce1146c3"}, {file = "packaging-22.0.tar.gz", hash = "sha256:2198ec20bd4c017b8f9717e00f0c8714076fc2fd93816750ab48e2c41de2cfd3"}, @@ -3622,34 +3717,34 @@ soupsieve = [ {file = "soupsieve-2.3.2.post1.tar.gz", hash = "sha256:fc53893b3da2c33de295667a0e19f078c14bf86544af307354de5fcf12a3f30d"}, ] spacy = [ - {file = "spacy-3.4.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e546b314f619502ae03e5eb9a0cfd09ca7a9db265bcdd8a3af83cfb0f1432e55"}, - {file = "spacy-3.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ded11aa8966236aab145b4d2d024b3eb61ac50078362d77d9ed7d8c240ef0f4a"}, - {file = 
"spacy-3.4.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:462e141f514d78cff85685b5b12eb8cadac0bad2f7820149cbe18d03ccb2e59c"}, - {file = "spacy-3.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c966d25b3f3e49f5de08546b3638928f49678c365cbbebd0eec28f74e0adb539"}, - {file = "spacy-3.4.3-cp310-cp310-win_amd64.whl", hash = "sha256:2ddba486c4c981abe6f1e3fd72648dc8811966e5f0e05808f9c9fab155c388d7"}, - {file = "spacy-3.4.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3c87117dd335fba44d1c0d77602f0763c3addf4e7ef9bdbe9a495466c3484c69"}, - {file = "spacy-3.4.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ce3938720f48eaeeb360a7f623f15a0d9efd1a688d5d740e3d4cdcd6f6da8a3"}, - {file = "spacy-3.4.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ad6bf5e4e7f0bc2ef94b7ff6fe59abd766f74c192bca2f17430a3b3cd5bda5a"}, - {file = "spacy-3.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6644c678bd7af567c6ce679f71d64119282e7d6f1a6f787162a91be3ea39333"}, - {file = "spacy-3.4.3-cp311-cp311-win_amd64.whl", hash = "sha256:e6b871de8857a6820140358db3943180fdbe03d44ed792155cee6cb95f4ac4ea"}, - {file = "spacy-3.4.3-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d211c2b8894354bf8d961af9a9dcab38f764e1dcddd7b80760e438fcd4c9fe43"}, - {file = "spacy-3.4.3-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ea41f9de30435456235c4182d8bc2eb54a0a64719856e66e780350bb4c8cfbe"}, - {file = "spacy-3.4.3-cp36-cp36m-win_amd64.whl", hash = "sha256:afaf6e716cbac4a0fbfa9e9bf95decff223936597ddd03ea869118a7576aa1b1"}, - {file = "spacy-3.4.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7115da36369b3c537caf2fe08e0b45528bd091c7f56ba3580af1e6fdfa9b1081"}, - {file = "spacy-3.4.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b3e629c889cac9656151286ec1232c6a948ce0d44a39f1ef5e60fed4f183a10"}, - {file = "spacy-3.4.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9277cd0fcb96ee5dd885f7e96c639f21afd96198d61ca32100446afbff4dfbef"}, - {file = "spacy-3.4.3-cp37-cp37m-win_amd64.whl", hash = "sha256:a36bd06a5a147350e5f5f6903c4777296c37b18199251bb41056c3a73aa4494f"}, - {file = "spacy-3.4.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bdafcd0823ca804c39d0bed9e677eb7d0235b1259563d0fd4d3a201c71108af8"}, - {file = "spacy-3.4.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0cdc23a48e6543402b4c56ebf2d36246001175c29fd56d3081efcec684651abc"}, - {file = "spacy-3.4.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:455c2fbd1de24b6fe34fa121d87525134d7498f9f458ebc8274d7940b473999e"}, - {file = "spacy-3.4.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1c85279fbb6b75d7fb8d7c59c2b734502e51271cad90926e8df1d21b67da5aa"}, - {file = "spacy-3.4.3-cp38-cp38-win_amd64.whl", hash = "sha256:5c0d65f39184f522b4e67b965a42d121a3b2d799362682fe8847b64b0ce5bc7c"}, - {file = "spacy-3.4.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a7b97ec21ed773edb2479ae5d6c7686b8034f418df6bccd9218f5c3c2b7cf888"}, - {file = "spacy-3.4.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:36a9a506029842795099fd97ad95f0da2845c319020fcc7164cbf33650726f83"}, - {file = "spacy-3.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ab293eb1423fa05c7ee71b2fedda57c2b4a4ca8dc054ce678809457287b01dc"}, - {file = 
"spacy-3.4.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb6d0f185126decc8392cde7d28eb6e85ba4bca15424713288cccc49c2a3c52b"}, - {file = "spacy-3.4.3-cp39-cp39-win_amd64.whl", hash = "sha256:676ab9ab2cf94ba48caa306f185a166e85bd35b388ec24512c8ba7dfcbc7517e"}, - {file = "spacy-3.4.3.tar.gz", hash = "sha256:22698cf5175e2b697e82699fcccee3092b42137a57d352df208d71657fd693bb"}, + {file = "spacy-3.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:07a10999a3e37f896758a92c2eed263638bcbf2747dc3a4aeea929aaa20ea28c"}, + {file = "spacy-3.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e6d98511dc8a88d3a96bcae13971a284459362076738c85053d1a3791f6cde92"}, + {file = "spacy-3.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2cad9c5543f03b3375c252e4dd45670ee8ed99c925dca15eadab5084fd1b033"}, + {file = "spacy-3.4.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ade19c1e676cac2546f268db22bc5eba08d12beafabe80f1b9f06028b3a0b52"}, + {file = "spacy-3.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:e782c8a7c4805cc1b34ed2b11f72a5cf2b9851e20f7afe3e97caf206f19f761b"}, + {file = "spacy-3.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aa027e69ef9fe42c8b02b940872e5bde0ce1bf66b6bf488c6493e3ce660c4b3a"}, + {file = "spacy-3.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ddeb5d725b6fa9c9009b1ff645db8f5caab9ed8956ee3a84b8379951caad1d36"}, + {file = "spacy-3.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29d6bb428a6bb19e026d8bbb9d4385c25b21e1ce51fcaabadfb5599b2390a79c"}, + {file = "spacy-3.4.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a21187ad4c44e166dc3deed23992ea1a74d731c9a6bdd9fca306d455181577fa"}, + {file = "spacy-3.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:10643c6d335a02805f6676738a3e992323cfd9438115cc253435e5053dc93824"}, + {file = "spacy-3.4.4-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:486228cfa7ced18ec99008388028bd2329262ab8108e7c19252c1a67b2801909"}, + {file = "spacy-3.4.4-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcb7a213178c298b95532075d6dddfb374bbe56ef8d2687212763b4583048da2"}, + {file = "spacy-3.4.4-cp36-cp36m-win_amd64.whl", hash = "sha256:15e5c41d408d1d30d8f3dd8e4eed9ed28e6174e011b8d61c1345981562e2e8f5"}, + {file = "spacy-3.4.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8979dbd3594c5c268cedad53f456a3ec3a0a2b78a1199788aacedcd68eef3a00"}, + {file = "spacy-3.4.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f4736fea2630e696422dfe38bfb3d0a7864bc6a9072d6e49a906af46870e36e"}, + {file = "spacy-3.4.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:498bf01e8c7ab601c3f8d6c51497817b40a3322a3967c032536b18ce9ea26d0a"}, + {file = "spacy-3.4.4-cp37-cp37m-win_amd64.whl", hash = "sha256:95f880c6fea57d51c448ad84f96d79d8758e5e18bdbaaee060c15af11641079b"}, + {file = "spacy-3.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9ccbede9be470c5d795168bf3be41fc86e18892a9247a742b394ba866c005391"}, + {file = "spacy-3.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2f1edbecfde9c11b17e87768bb5f2c33948fb1e3bf54b2197031ff9053607277"}, + {file = "spacy-3.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:66eaf4764e95699934cbd8f38717b283db185c896cfd3d1fb1ad5c6552e8b3c9"}, + {file = "spacy-3.4.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:0bb7d53f1a780bb8cc1b27a81e02e8b9bc71abb959f4dc13c21af4041fdd2c7a"}, + {file = "spacy-3.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:c1a5ce5c9b19cdfb4469079e710e72bb09c3cab855f21ef6a614b84c765e0311"}, + {file = "spacy-3.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f7044dca3542579ea1e3ac6cdd821640c2f65dd0c56230688f36e15aca1b8217"}, + {file = "spacy-3.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8a495b0fc00910fb5c1fbe64fdbfe1d3c11b09f421d1ae4e30cdb4c2388a91e4"}, + {file = "spacy-3.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:31e9a637960b60c1bb7a36a187271425717e97c14e9d1df613dc4efeffefcbec"}, + {file = "spacy-3.4.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71f9449ffadef85b048c9735ee235da5dca9d0a87038dba6d4ed20c5188e0f13"}, + {file = "spacy-3.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:1b7791a6c0592615b0566001596cc48c72325d1b97e46e574c91bff34f4e3f4c"}, + {file = "spacy-3.4.4.tar.gz", hash = "sha256:e500cf2cb5f1849461a7928fa269703756069bdfb71559065240af6d0208b08c"}, ] spacy-legacy = [ {file = "spacy-legacy-3.0.10.tar.gz", hash = "sha256:16104595d8ab1b7267f817a449ad1f986eb1f2a2edf1050748f08739a479679a"}, @@ -3837,6 +3932,29 @@ tomli = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] +torch = [ + {file = "torch-1.13.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:fd12043868a34a8da7d490bf6db66991108b00ffbeecb034228bfcbbd4197143"}, + {file = "torch-1.13.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d9fe785d375f2e26a5d5eba5de91f89e6a3be5d11efb497e76705fdf93fa3c2e"}, + {file = "torch-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:98124598cdff4c287dbf50f53fb455f0c1e3a88022b39648102957f3445e9b76"}, + {file = "torch-1.13.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:393a6273c832e047581063fb74335ff50b4c566217019cc6ace318cd79eb0566"}, + {file = "torch-1.13.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:0122806b111b949d21fa1a5f9764d1fd2fcc4a47cb7f8ff914204fd4fc752ed5"}, + {file = "torch-1.13.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:22128502fd8f5b25ac1cd849ecb64a418382ae81dd4ce2b5cebaa09ab15b0d9b"}, + {file = "torch-1.13.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:76024be052b659ac1304ab8475ab03ea0a12124c3e7626282c9c86798ac7bc11"}, + {file = "torch-1.13.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:ea8dda84d796094eb8709df0fcd6b56dc20b58fdd6bc4e8d7109930dafc8e419"}, + {file = "torch-1.13.1-cp37-cp37m-win_amd64.whl", hash = "sha256:2ee7b81e9c457252bddd7d3da66fb1f619a5d12c24d7074de91c4ddafb832c93"}, + {file = "torch-1.13.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:0d9b8061048cfb78e675b9d2ea8503bfe30db43d583599ae8626b1263a0c1380"}, + {file = "torch-1.13.1-cp37-none-macosx_11_0_arm64.whl", hash = "sha256:f402ca80b66e9fbd661ed4287d7553f7f3899d9ab54bf5c67faada1555abde28"}, + {file = "torch-1.13.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:727dbf00e2cf858052364c0e2a496684b9cb5aa01dc8a8bc8bbb7c54502bdcdd"}, + {file = "torch-1.13.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:df8434b0695e9ceb8cc70650afc1310d8ba949e6db2a0525ddd9c3b2b181e5fe"}, + {file = "torch-1.13.1-cp38-cp38-win_amd64.whl", hash = "sha256:5e1e722a41f52a3f26f0c4fcec227e02c6c42f7c094f32e49d4beef7d1e213ea"}, + {file = "torch-1.13.1-cp38-none-macosx_10_9_x86_64.whl", hash = 
"sha256:33e67eea526e0bbb9151263e65417a9ef2d8fa53cbe628e87310060c9dcfa312"}, + {file = "torch-1.13.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:eeeb204d30fd40af6a2d80879b46a7efbe3cf43cdbeb8838dd4f3d126cc90b2b"}, + {file = "torch-1.13.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:50ff5e76d70074f6653d191fe4f6a42fdbe0cf942fbe2a3af0b75eaa414ac038"}, + {file = "torch-1.13.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:2c3581a3fd81eb1f0f22997cddffea569fea53bafa372b2c0471db373b26aafc"}, + {file = "torch-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:0aa46f0ac95050c604bcf9ef71da9f1172e5037fdf2ebe051962d47b123848e7"}, + {file = "torch-1.13.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6930791efa8757cb6974af73d4996b6b50c592882a324b8fb0589c6a9ba2ddaf"}, + {file = "torch-1.13.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:e0df902a7c7dd6c795698532ee5970ce898672625635d885eade9976e5a04949"}, +] tornado = [ {file = "tornado-6.2-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:20f638fd8cc85f3cbae3c732326e96addff0a15e22d80f049e00121651e82e72"}, {file = "tornado-6.2-cp37-abi3-macosx_10_9_x86_64.whl", hash = "sha256:87dcafae3e884462f90c90ecc200defe5e580a7fbbb4365eda7c7c1eb809ebc9"}, @@ -3914,6 +4032,10 @@ websocket-client = [ {file = "websocket-client-1.4.2.tar.gz", hash = "sha256:d6e8f90ca8e2dd4e8027c4561adeb9456b54044312dba655e7cae652ceb9ae59"}, {file = "websocket_client-1.4.2-py3-none-any.whl", hash = "sha256:d6b06432f184438d99ac1f456eaf22fe1ade524c3dd16e661142dc54e9cba574"}, ] +wheel = [ + {file = "wheel-0.38.4-py3-none-any.whl", hash = "sha256:b60533f3f5d530e971d6737ca6d58681ee434818fab630c83a734bb10c083ce8"}, + {file = "wheel-0.38.4.tar.gz", hash = "sha256:965f5259b566725405b05e7cf774052044b1ed30119b5d586b2703aafe8719ac"}, +] widgetsnbextension = [ {file = "widgetsnbextension-4.0.4-py3-none-any.whl", hash = "sha256:fa0e840719ec95dd2ec85c3a48913f1a0c29d323eacbcdb0b29bfed0cc6da678"}, {file = "widgetsnbextension-4.0.4.tar.gz", hash = "sha256:44c69f18237af0f610557d6c1c7ef76853f5856a0e604c0a517f2320566bb775"}, diff --git a/pyproject.toml b/pyproject.toml index be4705dc..26866076 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ spacy = {version = "^3", optional = true} nltk = {version = "^3", optional = true} transformers = {version = "^4", optional = true} beautifulsoup4 = {version = "^4", optional = true} +torch = {version = "^1.13.1", optional = true} tiktoken = {version = "^0", optional = true, python="^3.9"} [tool.poetry.group.test.dependencies] @@ -49,8 +50,8 @@ jupyter = "^1.0.0" playwright = "^1.28.0" [tool.poetry.extras] -llms = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml"] -all = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "elasticsearch", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken"] +llms = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"] +all = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "elasticsearch", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch"] [tool.isort] profile = "black" diff --git a/tests/integration_tests/llms/test_huggingface_pipeline.py b/tests/integration_tests/llms/test_huggingface_pipeline.py new file mode 100644 index 00000000..7cf3b6d1 --- /dev/null +++ b/tests/integration_tests/llms/test_huggingface_pipeline.py @@ -0,0 +1,41 @@ 
+"""Test HuggingFace Pipeline wrapper.""" + +from pathlib import Path + +from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline + +from langchain.llms.huggingface_pipeline import HuggingFacePipeline +from langchain.llms.loading import load_llm +from tests.integration_tests.llms.utils import assert_llm_equality + + +def test_huggingface_pipeline_text_generation() -> None: + """Test valid call to HuggingFace text generation model.""" + llm = HuggingFacePipeline.from_model_id( + model_id="gpt2", task="text-generation", model_kwargs={"max_new_tokens": 10} + ) + output = llm("Say foo:") + assert isinstance(output, str) + + +def test_saving_loading_llm(tmp_path: Path) -> None: + """Test saving/loading an HuggingFaceHub LLM.""" + llm = HuggingFacePipeline.from_model_id( + model_id="gpt2", task="text-generation", model_kwargs={"max_new_tokens": 10} + ) + llm.save(file_path=tmp_path / "hf.yaml") + loaded_llm = load_llm(tmp_path / "hf.yaml") + assert_llm_equality(llm, loaded_llm) + + +def test_init_with_pipeline() -> None: + """Test initialization with a HF pipeline.""" + model_id = "gpt2" + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id) + pipe = pipeline( + "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10 + ) + llm = HuggingFacePipeline(pipeline=pipe) + output = llm("Say foo:") + assert isinstance(output, str) diff --git a/tests/integration_tests/llms/utils.py b/tests/integration_tests/llms/utils.py index 6047e979..c05445d4 100644 --- a/tests/integration_tests/llms/utils.py +++ b/tests/integration_tests/llms/utils.py @@ -10,7 +10,7 @@ def assert_llm_equality(llm: LLM, loaded_llm: LLM) -> None: # Client field can be session based, so hash is different despite # all other values being the same, so just assess all other fields for field in llm.__fields__.keys(): - if field != "client": + if field != "client" and field != "pipeline": val = getattr(llm, field) new_val = getattr(loaded_llm, field) assert new_val == val