|
|
|
@ -9,6 +9,8 @@ from langchain.llms.base import LLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AI21PenaltyData(BaseModel):
|
|
|
|
|
"""Parameters for AI21 penalty data."""
|
|
|
|
|
|
|
|
|
|
scale: int = 0
|
|
|
|
|
applyToWhitespaces: bool = True
|
|
|
|
|
applyToPunctuations: bool = True
|
|
|
|
@ -20,7 +22,8 @@ class AI21PenaltyData(BaseModel):
|
|
|
|
|
class AI21(BaseModel, LLM):
|
|
|
|
|
"""Wrapper around AI21 large language models.
|
|
|
|
|
|
|
|
|
|
To use, you should have the environment variable ``AI21_API_KEY`` set with your API key.
|
|
|
|
|
To use, you should have the environment variable ``AI21_API_KEY``
|
|
|
|
|
set with your API key.
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
.. code-block:: python
|
|
|
|
@ -56,7 +59,7 @@ class AI21(BaseModel, LLM):
|
|
|
|
|
numResults: int = 1
|
|
|
|
|
"""How many completions to generate for each prompt."""
|
|
|
|
|
|
|
|
|
|
logitBias: Dict[str, float] = None
|
|
|
|
|
logitBias: Optional[Dict[str, float]] = None
|
|
|
|
|
"""Adjust the probability of specific tokens being generated."""
|
|
|
|
|
|
|
|
|
|
ai21_api_key: Optional[str] = os.environ.get("AI21_API_KEY")
|
|
|
|
@ -123,10 +126,13 @@ class AI21(BaseModel, LLM):
|
|
|
|
|
"prompt": prompt,
|
|
|
|
|
"stopSequences": stop,
|
|
|
|
|
**self._default_params,
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
if response.status_code != 200:
|
|
|
|
|
optional_detail = response.json().get('error')
|
|
|
|
|
raise ValueError(f'AI21 /complete call failed with status code {response.status_code}. Details: {optional_detail}')
|
|
|
|
|
response = response.json()
|
|
|
|
|
return response["completions"][0]["data"]["text"]
|
|
|
|
|
optional_detail = response.json().get("error")
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"AI21 /complete call failed with status code {response.status_code}."
|
|
|
|
|
f" Details: {optional_detail}"
|
|
|
|
|
)
|
|
|
|
|
response_json = response.json()
|
|
|
|
|
return response_json["completions"][0]["data"]["text"]
|
|
|
|
|