diff --git a/README.md b/README.md
index 61c53c0e..4a5020f3 100644
--- a/README.md
+++ b/README.md
@@ -38,6 +38,9 @@ The following use cases require specific installs and environment variables:
 - *Cohere*:
   - Install requirements with `pip install cohere`
   - Set the following environment variable: `COHERE_API_KEY`
+- *HuggingFace Hub*:
+  - Install requirements with `pip install huggingface_hub`
+  - Set the following environment variable: `HUGGINGFACEHUB_API_TOKEN`
 - *SerpAPI*:
   - Install requirements with `pip install google-search-results`
   - Set the following environment variable: `SERPAPI_API_KEY`
diff --git a/examples/huggingface_hub.ipynb b/examples/huggingface_hub.ipynb
new file mode 100644
index 00000000..a9169b3c
--- /dev/null
+++ b/examples/huggingface_hub.ipynb
@@ -0,0 +1,53 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3acf0069",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain import Prompt, HuggingFaceHub, LLMChain\n",
+    "\n",
+    "template = \"\"\"Question: {question}\n",
+    "\n",
+    "Answer: Let's think step by step.\"\"\"\n",
+    "prompt = Prompt(template=template, input_variables=[\"question\"])\n",
+    "llm_chain = LLMChain(prompt=prompt, llm=HuggingFaceHub(repo_id=\"gpt2\", temperature=1e-10))\n",
+    "\n",
+    "question = \"What NFL team won the Super Bowl in the year Justin Bieber was born?\"\n",
+    "\n",
+    "llm_chain.predict(question=question)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ae4559c7",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/langchain/__init__.py b/langchain/__init__.py
index 8f3733fe..c5912311 100644
--- a/langchain/__init__.py
+++ b/langchain/__init__.py
@@ -12,7 +12,7 @@ from langchain.chains import (
     SelfAskWithSearchChain,
     SerpAPIChain,
 )
-from langchain.llms import Cohere, OpenAI
+from langchain.llms import Cohere, HuggingFaceHub, OpenAI
 from langchain.prompt import Prompt
 
 __all__ = [
@@ -24,4 +24,5 @@ __all__ = [
     "Cohere",
     "OpenAI",
     "Prompt",
+    "HuggingFaceHub",
 ]
diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py
index f8765718..3d83e5e7 100644
--- a/langchain/llms/__init__.py
+++ b/langchain/llms/__init__.py
@@ -1,5 +1,6 @@
 """Wrappers on top of large language models APIs."""
 from langchain.llms.cohere import Cohere
+from langchain.llms.huggingface_hub import HuggingFaceHub
 from langchain.llms.openai import OpenAI
 
-__all__ = ["Cohere", "OpenAI"]
+__all__ = ["Cohere", "OpenAI", "HuggingFaceHub"]
diff --git a/langchain/llms/cohere.py b/langchain/llms/cohere.py
index 4394a6dc..cf6b9d9b 100644
--- a/langchain/llms/cohere.py
+++ b/langchain/llms/cohere.py
@@ -5,14 +5,7 @@ from typing import Any, Dict, List, Optional
 from pydantic import BaseModel, Extra, root_validator
 
 from langchain.llms.base import LLM
-
-
-def remove_stop_tokens(text: str, stop: List[str]) -> str:
-    """Remove stop tokens, should they occur at end."""
-    for s in stop:
-        if text.endswith(s):
-            return text[: -len(s)]
-    return text
+from langchain.llms.utils import enforce_stop_tokens
 
 
 class Cohere(BaseModel, LLM):
@@ -104,5 +97,5 @@ class Cohere(BaseModel, LLM):
         # If stop tokens are provided, Cohere's endpoint returns them.
         # In order to make this consistent with other endpoints, we strip them.
         if stop is not None:
-            text = remove_stop_tokens(text, stop)
+            text = enforce_stop_tokens(text, stop)
         return text
diff --git a/langchain/llms/huggingface_hub.py b/langchain/llms/huggingface_hub.py
new file mode 100644
index 00000000..0f73e61e
--- /dev/null
+++ b/langchain/llms/huggingface_hub.py
@@ -0,0 +1,102 @@
+"""Wrapper around HuggingFace APIs."""
+import os
+from typing import Any, Dict, List, Mapping, Optional
+
+from pydantic import BaseModel, Extra, root_validator
+
+from langchain.llms.base import LLM
+from langchain.llms.utils import enforce_stop_tokens
+
+DEFAULT_REPO_ID = "gpt2"
+
+
+class HuggingFaceHub(BaseModel, LLM):
+    """Wrapper around HuggingFaceHub models.
+
+    To use, you should have the ``huggingface_hub`` python package installed, and the
+    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token.
+
+    Only supports task `text-generation` for now.
+
+    Example:
+        .. code-block:: python
+
+            from langchain import HuggingFaceHub
+            hf = HuggingFaceHub(repo_id="gpt2")
+    """
+
+    client: Any  #: :meta private:
+    repo_id: str = DEFAULT_REPO_ID
+    """Model name to use."""
+    temperature: float = 0.7
+    """What sampling temperature to use."""
+    max_new_tokens: int = 200
+    """The maximum number of tokens to generate in the completion."""
+    top_p: float = 1.0
+    """Total probability mass of tokens to consider at each step."""
+    num_return_sequences: int = 1
+    """How many completions to generate for each prompt."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exist in environment."""
+        if "HUGGINGFACEHUB_API_TOKEN" not in os.environ:
+            raise ValueError(
+                "Did not find HuggingFace API token, please add an environment variable"
+                " `HUGGINGFACEHUB_API_TOKEN` which contains it."
+            )
+        try:
+            from huggingface_hub.inference_api import InferenceApi
+
+            repo_id = values.get("repo_id", DEFAULT_REPO_ID)
+            values["client"] = InferenceApi(
+                repo_id=repo_id,
+                token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
+                task="text-generation",
+            )
+        except ImportError:
+            raise ValueError(
+                "Could not import huggingface_hub python package. "
+                "Please install it with `pip install huggingface_hub`."
+            )
+        return values
+
+    @property
+    def _default_params(self) -> Mapping[str, Any]:
+        """Get the default parameters for calling HuggingFace Hub API."""
+        return {
+            "temperature": self.temperature,
+            "max_new_tokens": self.max_new_tokens,
+            "top_p": self.top_p,
+            "num_return_sequences": self.num_return_sequences,
+        }
+
+    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        """Call out to HuggingFace Hub's inference endpoint.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = hf("Tell me a joke.")
+        """
+        response = self.client(inputs=prompt, params=self._default_params)
+        if "error" in response:
+            raise ValueError(f"Error raised by inference API: {response['error']}")
+        text = response[0]["generated_text"][len(prompt) :]
+        if stop is not None:
+            # The wrapper does not pass stop sequences to the Inference API, so
+            # enforce them here by truncating the generated text client-side.
+            text = enforce_stop_tokens(text, stop)
+        return text
diff --git a/langchain/llms/utils.py b/langchain/llms/utils.py
new file mode 100644
index 00000000..a42fd130
--- /dev/null
+++ b/langchain/llms/utils.py
@@ -0,0 +1,8 @@
+"""Common utility functions for working with LLM APIs."""
+import re
+from typing import List
+
+
+def enforce_stop_tokens(text: str, stop: List[str]) -> str:
+    """Cut off the text as soon as any stop words occur."""
+    return re.split("|".join(map(re.escape, stop)), text)[0]
diff --git a/requirements.txt b/requirements.txt
index c3429fff..4fedac87 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,15 @@
 -r test_requirements.txt
+# For linting
 black
 isort
 mypy
 flake8
 flake8-docstrings
+# For integrations
 cohere
 openai
 google-search-results
 playwright
+huggingface_hub
+# For development
+jupyter
diff --git a/tests/integration_tests/llms/test_huggingface_hub.py b/tests/integration_tests/llms/test_huggingface_hub.py
new file mode 100644
index 00000000..7583b59c
--- /dev/null
+++ b/tests/integration_tests/llms/test_huggingface_hub.py
@@ -0,0 +1,19 @@
+"""Test HuggingFace API wrapper."""
+
+import pytest
+
+from langchain.llms.huggingface_hub import HuggingFaceHub
+
+
+def test_huggingface_call() -> None:
+    """Test valid call to HuggingFace."""
+    llm = HuggingFaceHub(max_new_tokens=10)
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
+def test_huggingface_call_error() -> None:
+    """Test valid call to HuggingFace that errors."""
+    llm = HuggingFaceHub(max_new_tokens=-1)
+    with pytest.raises(ValueError):
+        llm("Say foo:")
diff --git a/tests/unit_tests/llms/test_cohere.py b/tests/unit_tests/llms/test_cohere.py
deleted file mode 100644
index 9e30c333..00000000
--- a/tests/unit_tests/llms/test_cohere.py
+++ /dev/null
@@ -1,17 +0,0 @@
-"""Test helper functions for Cohere API."""
-
-from langchain.llms.cohere import remove_stop_tokens
-
-
-def test_remove_stop_tokens() -> None:
-    """Test removing stop tokens when they occur."""
-    text = "foo bar baz"
-    output = remove_stop_tokens(text, ["moo", "baz"])
-    assert output == "foo bar "
-
-
-def test_remove_stop_tokens_none() -> None:
-    """Test removing stop tokens when they do not occur."""
-    text = "foo bar baz"
-    output = remove_stop_tokens(text, ["moo"])
-    assert output == "foo bar baz"
diff --git a/tests/unit_tests/llms/test_utils.py b/tests/unit_tests/llms/test_utils.py
new file mode 100644
index 00000000..5685f650
--- /dev/null
+++ b/tests/unit_tests/llms/test_utils.py
@@ -0,0 +1,19 @@
+"""Test LLM utility functions."""
+from langchain.llms.utils import enforce_stop_tokens
+
+
+def test_enforce_stop_tokens() -> None:
+    """Test removing stop tokens when they occur."""
+    text = "foo bar baz"
+    output = enforce_stop_tokens(text, ["moo", "baz"])
+    assert output == "foo bar "
+    text = "foo bar baz"
+    output = enforce_stop_tokens(text, ["moo", "baz", "bar"])
+    assert output == "foo "
+
+
+def test_enforce_stop_tokens_none() -> None:
+    """Test removing stop tokens when they do not occur."""
+    text = "foo bar baz"
+    output = enforce_stop_tokens(text, ["moo"])
+    assert output == "foo bar baz"
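
For reference, the truncation semantics of the new `enforce_stop_tokens` helper can be exercised standalone. Unlike the old `remove_stop_tokens`, which only stripped a single stop token off the very end of the text, the new helper cuts at the first occurrence of any stop sequence anywhere in the text:

```python
import re
from typing import List


def enforce_stop_tokens(text: str, stop: List[str]) -> str:
    """Cut off the text as soon as any stop words occur."""
    # Escape each stop sequence so regex metacharacters are treated literally,
    # then keep only the text before the first match.
    return re.split("|".join(map(re.escape, stop)), text)[0]


print(repr(enforce_stop_tokens("foo bar baz", ["moo", "baz"])))  # 'foo bar '
print(repr(enforce_stop_tokens("foo bar baz", ["moo"])))         # 'foo bar baz'
```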