Harrison/lintai21 (#114)

harrison/prompt_examples
Harrison Chase 2 years ago committed by GitHub
parent d8734ce5ad
commit 9f878e43d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,6 +9,8 @@ from langchain.llms.base import LLM
class AI21PenaltyData(BaseModel):
"""Parameters for AI21 penalty data."""
scale: int = 0
applyToWhitespaces: bool = True
applyToPunctuations: bool = True
@ -20,7 +22,8 @@ class AI21PenaltyData(BaseModel):
class AI21(BaseModel, LLM):
"""Wrapper around AI21 large language models.
To use, you should have the environment variable ``AI21_API_KEY`` set with your API key.
To use, you should have the environment variable ``AI21_API_KEY``
set with your API key.
Example:
.. code-block:: python
@ -56,7 +59,7 @@ class AI21(BaseModel, LLM):
numResults: int = 1
"""How many completions to generate for each prompt."""
logitBias: Dict[str, float] = None
logitBias: Optional[Dict[str, float]] = None
"""Adjust the probability of specific tokens being generated."""
ai21_api_key: Optional[str] = os.environ.get("AI21_API_KEY")
@ -123,10 +126,13 @@ class AI21(BaseModel, LLM):
"prompt": prompt,
"stopSequences": stop,
**self._default_params,
}
},
)
if response.status_code != 200:
optional_detail = response.json().get('error')
raise ValueError(f'AI21 /complete call failed with status code {response.status_code}. Details: {optional_detail}')
response = response.json()
return response["completions"][0]["data"]["text"]
optional_detail = response.json().get("error")
raise ValueError(
f"AI21 /complete call failed with status code {response.status_code}."
f" Details: {optional_detail}"
)
response_json = response.json()
return response_json["completions"][0]["data"]["text"]

@ -14,7 +14,7 @@ setup(
version=__version__,
packages=find_packages(),
description="Building applications with LLMs through composability",
install_requires=["pydantic", "sqlalchemy", "numpy"],
install_requires=["pydantic", "sqlalchemy", "numpy", "requests"],
long_description=long_description,
license="MIT",
url="https://github.com/hwchase17/langchain",

@ -8,3 +8,4 @@ isort
mypy
flake8
flake8-docstrings
types-requests

Loading…
Cancel
Save