Harrison/lintai21 (#114)

pull/115/head^2
Harrison Chase 2 years ago committed by GitHub
parent d8734ce5ad
commit 9f878e43d8

@@ -9,6 +9,8 @@ from langchain.llms.base import LLM
 
 
 class AI21PenaltyData(BaseModel):
+    """Parameters for AI21 penalty data."""
+
     scale: int = 0
     applyToWhitespaces: bool = True
     applyToPunctuations: bool = True
@@ -20,7 +22,8 @@ class AI21PenaltyData(BaseModel):
 class AI21(BaseModel, LLM):
     """Wrapper around AI21 large language models.
 
-    To use, you should have the environment variable ``AI21_API_KEY`` set with your API key.
+    To use, you should have the environment variable ``AI21_API_KEY``
+    set with your API key.
 
     Example:
         .. code-block:: python
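
For readers unfamiliar with the wrapper, the docstring above implies a usage pattern roughly like the sketch below; the import path and the callable interface are assumptions based on this module, not part of the diff.

import os

# Assumes the key is set before the wrapper is constructed.
os.environ["AI21_API_KEY"] = "my-api-key"

from langchain.llms.ai21 import AI21

llm = AI21()  # reads AI21_API_KEY from the environment by default
# The base LLM class is assumed to expose a callable interface,
# as other wrappers in this repository do.
print(llm("Write a one-line poem about autumn."))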
@@ -56,7 +59,7 @@ class AI21(BaseModel, LLM):
     numResults: int = 1
     """How many completions to generate for each prompt."""
 
-    logitBias: Dict[str, float] = None
+    logitBias: Optional[Dict[str, float]] = None
     """Adjust the probability of specific tokens being generated."""
 
     ai21_api_key: Optional[str] = os.environ.get("AI21_API_KEY")
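
The logitBias change above is the core lint fix in this hunk: a field whose default is None must be annotated as Optional, otherwise mypy rejects the assignment. A minimal sketch of the point, using a purely illustrative Example model:

from typing import Dict, Optional

from pydantic import BaseModel


class Example(BaseModel):
    # logitBias: Dict[str, float] = None   # mypy: incompatible default for a non-Optional type
    logitBias: Optional[Dict[str, float]] = None  # accepted by both mypy and pydantic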
@@ -123,10 +126,13 @@ class AI21(BaseModel, LLM):
                 "prompt": prompt,
                 "stopSequences": stop,
                 **self._default_params,
-            }
+            },
         )
         if response.status_code != 200:
-            optional_detail = response.json().get('error')
-            raise ValueError(f'AI21 /complete call failed with status code {response.status_code}. Details: {optional_detail}')
-        response = response.json()
-        return response["completions"][0]["data"]["text"]
+            optional_detail = response.json().get("error")
+            raise ValueError(
+                f"AI21 /complete call failed with status code {response.status_code}."
+                f" Details: {optional_detail}"
+            )
+        response_json = response.json()
+        return response_json["completions"][0]["data"]["text"]
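
Taken together, this hunk is a formatting and naming cleanup of the HTTP call: a trailing comma after the JSON payload, double quotes throughout, a wrapped error message, and a dedicated response_json variable so the parsed body no longer overwrites the requests.Response object. A self-contained sketch of the same request/error-handling pattern, with a placeholder endpoint, headers, and payload rather than the wrapper's real values:

import requests

response = requests.post(
    "https://api.example.com/complete",          # placeholder endpoint
    headers={"Authorization": "Bearer my-key"},  # placeholder auth header
    json={"prompt": "Hello", "stopSequences": ["\n"]},
)
if response.status_code != 200:
    # Assumes the API returns a JSON body with an "error" field on failure.
    optional_detail = response.json().get("error")
    raise ValueError(
        f"Completion call failed with status code {response.status_code}."
        f" Details: {optional_detail}"
    )
response_json = response.json()
print(response_json["completions"][0]["data"]["text"])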

@@ -14,7 +14,7 @@ setup(
     version=__version__,
     packages=find_packages(),
     description="Building applications with LLMs through composability",
-    install_requires=["pydantic", "sqlalchemy", "numpy"],
+    install_requires=["pydantic", "sqlalchemy", "numpy", "requests"],
     long_description=long_description,
     license="MIT",
     url="https://github.com/hwchase17/langchain",

@@ -8,3 +8,4 @@ isort
 mypy
 flake8
 flake8-docstrings
+types-requests
