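# NOTE: the lines below are web-viewer residue captured with the file, kept as comments.
# You cannot select more than 25 topics. Topics must start with a letter or number,
# can include dashes ('-') and can be up to 35 characters long.
# langchain/langchain/llms/aleph_alpha.py
#
# 237 lines
# 9.0 KiB
# Python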

"""Wrapper around Aleph Alpha APIs."""
from typing import Any, Dict, List, Optional, Sequence
from pydantic import BaseModel, Extra, root_validator
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
class AlephAlpha(LLM, BaseModel):
    """Wrapper around Aleph Alpha large language models.

    To use, you should have the ``aleph_alpha_client`` python package installed,
    and the environment variable ``ALEPH_ALPHA_API_KEY`` set with your API key,
    or pass it as a named parameter to the constructor.

    Parameters are explained more in depth here:
    https://github.com/Aleph-Alpha/aleph-alpha-client/blob/c14b7dd2b4325c7da0d6a119f6e76385800e097b/aleph_alpha_client/completion.py#L10

    Example:
        .. code-block:: python

            from langchain.llms import AlephAlpha
            aleph_alpha = AlephAlpha(aleph_alpha_api_key="my-api-key")
    """

    client: Any  #: :meta private:

    model: Optional[str] = "luminous-base"
    """Model name to use."""

    maximum_tokens: int = 64
    """The maximum number of tokens to be generated."""

    temperature: float = 0.0
    """A non-negative float that tunes the degree of randomness in generation."""

    top_k: int = 0
    """Number of most likely tokens to consider at each step."""

    top_p: float = 0.0
    """Total probability mass of tokens to consider at each step."""

    presence_penalty: float = 0.0
    """Penalizes repeated tokens."""

    frequency_penalty: float = 0.0
    """Penalizes repeated tokens according to frequency."""

    repetition_penalties_include_prompt: Optional[bool] = False
    """Flag deciding whether presence penalty or frequency penalty are
    updated from the prompt."""

    use_multiplicative_presence_penalty: Optional[bool] = False
    """Flag deciding whether presence penalty is applied
    multiplicatively (True) or additively (False)."""

    penalty_bias: Optional[str] = None
    """Penalty bias for the completion."""

    penalty_exceptions: Optional[List[str]] = None
    """List of strings that may be generated without penalty,
    regardless of other penalty settings."""

    penalty_exceptions_include_stop_sequences: Optional[bool] = None
    """Should stop_sequences be included in penalty_exceptions."""

    best_of: Optional[int] = None
    """Returns the one with the "best of" results
    (highest log probability per token).
    """

    n: int = 1
    """How many completions to generate for each prompt."""

    logit_bias: Optional[Dict[int, float]] = None
    """The logit bias allows to influence the likelihood of generating tokens."""

    log_probs: Optional[int] = None
    """Number of top log probabilities to be returned for each generated token."""

    tokens: Optional[bool] = False
    """Return tokens of completion."""

    disable_optimizations: Optional[bool] = False
    """Forwarded as ``disable_optimizations`` on the completion request; see the
    Aleph Alpha client documentation linked in the class docstring."""

    minimum_tokens: Optional[int] = 0
    """Generate at least this number of tokens."""

    echo: bool = False
    """Echo the prompt in the completion."""

    use_multiplicative_frequency_penalty: bool = False
    """Flag deciding whether frequency penalty is applied
    multiplicatively (True) or additively (False)."""

    sequence_penalty: float = 0.0
    """Penalizes reproducing token sequences that already appeared."""

    sequence_penalty_min_length: int = 2
    """Minimal number of tokens to be considered as a sequence."""

    use_multiplicative_sequence_penalty: bool = False
    """Flag deciding whether sequence penalty is applied
    multiplicatively (True) or additively (False)."""

    completion_bias_inclusion: Optional[Sequence[str]] = None
    """Bias the completion towards the options in this list."""

    completion_bias_inclusion_first_token_only: bool = False
    """Only consider the first token for the completion_bias_inclusion."""

    completion_bias_exclusion: Optional[Sequence[str]] = None
    """Bias the completion away from the options in this list."""

    completion_bias_exclusion_first_token_only: bool = False
    """Only consider the first token for the completion_bias_exclusion."""

    contextual_control_threshold: Optional[float] = None
    """If set to None, attention control parameters only apply to those tokens that
    have explicitly been set in the request.
    If set to a non-None value, control parameters are also applied to similar
    tokens.
    """

    control_log_additive: Optional[bool] = True
    """True: apply control by adding the log(control_factor) to attention scores.
    False: (attention_scores - attention_scores.min(-1)) * control_factor
    """

    repetition_penalties_include_completion: bool = True
    """Flag deciding whether presence penalty or frequency penalty
    are updated from the completion."""

    raw_completion: bool = False
    """Force the raw completion of the model to be returned."""

    aleph_alpha_api_key: Optional[str] = None
    """API key for Aleph Alpha API."""

    stop_sequences: Optional[List[str]] = None
    """Stop sequences to use."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Args:
            values: The field values collected by pydantic for this model.

        Returns:
            The same mapping with ``client`` populated by an
            ``aleph_alpha_client.Client`` built from the resolved API key.

        Raises:
            ValueError: If the ``aleph_alpha_client`` package cannot be imported.
        """
        aleph_alpha_api_key = get_from_dict_or_env(
            values, "aleph_alpha_api_key", "ALEPH_ALPHA_API_KEY"
        )
        # Keep the try body limited to the import so that an error raised while
        # constructing the client is not misreported as a missing package.
        try:
            import aleph_alpha_client
        except ImportError as exc:
            raise ValueError(
                "Could not import aleph_alpha_client python package. "
                "Please install it with `pip install aleph_alpha_client`."
            ) from exc
        values["client"] = aleph_alpha_client.Client(token=aleph_alpha_api_key)
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Aleph Alpha API.

        A fresh dictionary is built on every access, so callers may mutate the
        result without affecting this instance.
        """
        return {
            "maximum_tokens": self.maximum_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
            "n": self.n,
            "repetition_penalties_include_prompt": self.repetition_penalties_include_prompt,  # noqa: E501
            "use_multiplicative_presence_penalty": self.use_multiplicative_presence_penalty,  # noqa: E501
            "penalty_bias": self.penalty_bias,
            "penalty_exceptions": self.penalty_exceptions,
            "penalty_exceptions_include_stop_sequences": self.penalty_exceptions_include_stop_sequences,  # noqa: E501
            "best_of": self.best_of,
            "logit_bias": self.logit_bias,
            "log_probs": self.log_probs,
            "tokens": self.tokens,
            "disable_optimizations": self.disable_optimizations,
            "minimum_tokens": self.minimum_tokens,
            "echo": self.echo,
            "use_multiplicative_frequency_penalty": self.use_multiplicative_frequency_penalty,  # noqa: E501
            "sequence_penalty": self.sequence_penalty,
            "sequence_penalty_min_length": self.sequence_penalty_min_length,
            "use_multiplicative_sequence_penalty": self.use_multiplicative_sequence_penalty,  # noqa: E501
            "completion_bias_inclusion": self.completion_bias_inclusion,
            "completion_bias_inclusion_first_token_only": self.completion_bias_inclusion_first_token_only,  # noqa: E501
            "completion_bias_exclusion": self.completion_bias_exclusion,
            "completion_bias_exclusion_first_token_only": self.completion_bias_exclusion_first_token_only,  # noqa: E501
            "contextual_control_threshold": self.contextual_control_threshold,
            "control_log_additive": self.control_log_additive,
            "repetition_penalties_include_completion": self.repetition_penalties_include_completion,  # noqa: E501
            "raw_completion": self.raw_completion,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters (model name plus default params)."""
        return {"model": self.model, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        # NOTE(review): the spelling "alpeh_alpha" is kept as-is because this
        # value is a runtime identifier; confirm nothing depends on it before
        # correcting it.
        return "alpeh_alpha"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to Aleph Alpha's completion endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model (the first completion returned).

        Raises:
            ValueError: If stop sequences are given both via ``stop`` and via
                the ``stop_sequences`` field.

        Example:
            .. code-block:: python

                response = aleph_alpha("Tell me a joke.")
        """
        from aleph_alpha_client import CompletionRequest, Prompt

        if self.stop_sequences is not None and stop is not None:
            raise ValueError(
                "stop sequences found in both the input and default params."
            )

        # At most one of the two sources is set at this point.
        if self.stop_sequences is not None:
            stop_sequences = self.stop_sequences
        else:
            stop_sequences = stop

        params = self._default_params
        params["stop_sequences"] = stop_sequences

        request = CompletionRequest(prompt=Prompt.from_text(prompt), **params)
        response = self.client.complete(model=self.model, request=request)
        text = response.completions[0].completion

        # If stop tokens are provided, Aleph Alpha's endpoint returns them.
        # In order to make this consistent with other endpoints, we strip them.
        # Only strip for a non-empty list: an empty list would presumably give
        # the helper an empty pattern and truncate the whole completion
        # (verify against enforce_stop_tokens).
        if stop_sequences:
            text = enforce_stop_tokens(text, stop_sequences)
        return text