You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/partners/ai21/langchain_ai21/llms.py

175 lines
5.1 KiB
Python

import asyncio
from functools import partial
from typing import (
Any,
List,
Mapping,
Optional,
)
from ai21.models import CompletionsResponse
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_ai21.ai21_base import AI21Base
class AI21LLM(BaseLLM, AI21Base):
    """AI21LLM large language models.

    Sends prompts to AI21's completion endpoint via
    ``self.client.completion.create`` and wraps the results as LangChain
    ``Generation`` objects.

    Example:
        .. code-block:: python

            from langchain_ai21 import AI21LLM

            # ``model`` is required; see the model-types link on the field.
            model = AI21LLM(model="j2-ultra")
    """

    model: str
    """Model type you wish to interact with.
    You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""

    num_results: int = 1
    """The number of responses to generate for a given prompt."""

    max_tokens: int = 16
    """The maximum number of tokens to generate for each response."""

    min_tokens: int = 0
    """The minimum number of tokens to generate for each response."""

    temperature: float = 0.7
    """A value controlling the "creativity" of the model's responses."""

    top_p: float = 1
    """A value controlling the diversity of the model's responses."""

    top_k_return: int = 0
    """The number of top-scoring tokens to consider for each generation step."""

    frequency_penalty: Optional[Any] = None
    """A penalty applied to tokens that are frequently generated.

    Either an object exposing ``to_dict()`` or a plain mapping."""

    presence_penalty: Optional[Any] = None
    """A penalty applied to tokens that are already present in the prompt.

    Either an object exposing ``to_dict()`` or a plain mapping."""

    count_penalty: Optional[Any] = None
    """A penalty applied to tokens based on their frequency
    in the generated responses.

    Either an object exposing ``to_dict()`` or a plain mapping."""

    custom_model: Optional[str] = None
    """Optional custom model name; sent as ``custom_model`` only when set."""

    epoch: Optional[int] = None
    """Optional ``epoch`` value; sent with the request only when set."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "ai21-llm"

    @staticmethod
    def _penalty_to_dict(penalty: Any) -> Any:
        """Serialize a penalty setting for the request payload.

        Plain mappings are copied into a ``dict`` as-is; any other value is
        expected to expose ``to_dict()`` (as the SDK penalty objects do).
        """
        # The penalty fields are typed ``Optional[Any]``, so a dict passes
        # validation; calling ``.to_dict()`` on it would raise AttributeError.
        if isinstance(penalty, Mapping):
            return dict(penalty)
        return penalty.to_dict()

    @property
    def _default_params(self) -> Mapping[str, Any]:
        """Build the request parameters derived from the model's fields.

        Optional settings (penalties, ``custom_model``, ``epoch``) are only
        included when they are not ``None``.
        """
        base_params = {
            "model": self.model,
            "num_results": self.num_results,
            "max_tokens": self.max_tokens,
            "min_tokens": self.min_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k_return": self.top_k_return,
        }
        if self.count_penalty is not None:
            base_params["count_penalty"] = self._penalty_to_dict(self.count_penalty)
        if self.custom_model is not None:
            base_params["custom_model"] = self.custom_model
        if self.epoch is not None:
            base_params["epoch"] = self.epoch
        if self.frequency_penalty is not None:
            base_params["frequency_penalty"] = self._penalty_to_dict(
                self.frequency_penalty
            )
        if self.presence_penalty is not None:
            base_params["presence_penalty"] = self._penalty_to_dict(
                self.presence_penalty
            )
        return base_params

    def _build_params_for_request(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Mapping[str, Any]:
        """Merge default params, stop sequences and call-time overrides.

        Precedence (lowest to highest): ``_default_params``, ``stop`` (sent
        as ``stop_sequences``), then ``kwargs``.

        Raises:
            ValueError: If ``stop`` is given and ``kwargs`` also supplies
                ``stop`` or ``stop_sequences`` -- otherwise the kwarg would
                silently override the explicit argument.
        """
        params = {}
        if stop is not None:
            if "stop" in kwargs or "stop_sequences" in kwargs:
                raise ValueError("stop is defined in both stop and kwargs")
            params["stop_sequences"] = stop
        return {
            **self._default_params,
            **params,
            **kwargs,
        }

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run one completion request per prompt.

        ``llm_output["token_count"]`` sums ``client.count_tokens`` over the
        prompts only (generated tokens are not counted).
        """
        generations: List[List[Generation]] = []
        token_count = 0
        # Request params are identical for every prompt; build them once.
        params = self._build_params_for_request(stop=stop, **kwargs)
        for prompt in prompts:
            response = self._invoke_completion(prompt=prompt, **params)
            generations.append(self._response_to_generation(response))
            token_count += self.client.count_tokens(prompt)
        llm_output = {"token_count": token_count, "model_name": self.model}
        return LLMResult(generations=generations, llm_output=llm_output)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run ``_generate`` in the default executor.

        The async run manager is converted with ``get_sync()`` because
        ``_generate`` expects a synchronous ``CallbackManagerForLLMRun``.
        """
        # Change implementation if integration natively supports async generation.
        sync_run_manager = run_manager.get_sync() if run_manager else None
        return await asyncio.get_running_loop().run_in_executor(
            None,
            partial(self._generate, **kwargs),
            prompts,
            stop,
            sync_run_manager,
        )

    def _invoke_completion(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> CompletionsResponse:
        """Call ``client.completion.create`` for a single prompt."""
        return self.client.completion.create(
            prompt=prompt,
            **kwargs,
        )

    def _response_to_generation(
        self, response: CompletionsResponse
    ) -> List[Generation]:
        """Convert every completion in ``response`` into a ``Generation``.

        ``text`` comes from ``completion.data.text``; ``generation_info``
        holds the full ``completion.to_dict()`` payload.
        """
        return [
            Generation(
                text=completion.data.text,
                generation_info=completion.to_dict(),
            )
            for completion in response.completions
        ]