From 9f878e43d85f16f040d7f1a54d85bc5fdb36f833 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 10 Nov 2022 08:46:35 -0800 Subject: [PATCH] Harrison/lintai21 (#114) --- langchain/llms/ai21.py | 20 +++++++++++++------- setup.py | 2 +- test_requirements.txt | 1 + 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/langchain/llms/ai21.py b/langchain/llms/ai21.py index 591c3d6b..b1dac08b 100644 --- a/langchain/llms/ai21.py +++ b/langchain/llms/ai21.py @@ -9,6 +9,8 @@ from langchain.llms.base import LLM class AI21PenaltyData(BaseModel): + """Parameters for AI21 penalty data.""" + scale: int = 0 applyToWhitespaces: bool = True applyToPunctuations: bool = True @@ -20,7 +22,8 @@ class AI21PenaltyData(BaseModel): class AI21(BaseModel, LLM): """Wrapper around AI21 large language models. - To use, you should have the environment variable ``AI21_API_KEY`` set with your API key. + To use, you should have the environment variable ``AI21_API_KEY`` + set with your API key. Example: .. code-block:: python @@ -56,7 +59,7 @@ class AI21(BaseModel, LLM): numResults: int = 1 """How many completions to generate for each prompt.""" - logitBias: Dict[str, float] = None + logitBias: Optional[Dict[str, float]] = None """Adjust the probability of specific tokens being generated.""" ai21_api_key: Optional[str] = os.environ.get("AI21_API_KEY") @@ -123,10 +126,13 @@ class AI21(BaseModel, LLM): "prompt": prompt, "stopSequences": stop, **self._default_params, - } + }, ) if response.status_code != 200: - optional_detail = response.json().get('error') - raise ValueError(f'AI21 /complete call failed with status code {response.status_code}. Details: {optional_detail}') - response = response.json() - return response["completions"][0]["data"]["text"] + optional_detail = response.json().get("error") + raise ValueError( + f"AI21 /complete call failed with status code {response.status_code}." + f" Details: {optional_detail}" + ) + response_json = response.json() + return response_json["completions"][0]["data"]["text"] diff --git a/setup.py b/setup.py index 68972b59..cdda6f9c 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ setup( version=__version__, packages=find_packages(), description="Building applications with LLMs through composability", - install_requires=["pydantic", "sqlalchemy", "numpy"], + install_requires=["pydantic", "sqlalchemy", "numpy", "requests"], long_description=long_description, license="MIT", url="https://github.com/hwchase17/langchain", diff --git a/test_requirements.txt b/test_requirements.txt index 55e36b04..4318012c 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -8,3 +8,4 @@ isort mypy flake8 flake8-docstrings +types-requests