mirror of
https://github.com/hwchase17/langchain
synced 2024-11-11 19:11:02 +00:00
143 lines
4.3 KiB
Python
143 lines
4.3 KiB
Python
|
import asyncio
|
||
|
from functools import partial
|
||
|
from typing import (
|
||
|
Any,
|
||
|
List,
|
||
|
Optional,
|
||
|
)
|
||
|
|
||
|
from ai21.models import CompletionsResponse, Penalty
|
||
|
from langchain_core.callbacks import (
|
||
|
AsyncCallbackManagerForLLMRun,
|
||
|
CallbackManagerForLLMRun,
|
||
|
)
|
||
|
from langchain_core.language_models import BaseLLM
|
||
|
from langchain_core.outputs import Generation, LLMResult
|
||
|
|
||
|
from langchain_ai21.ai21_base import AI21Base
|
||
|
|
||
|
|
||
|
class AI21LLM(BaseLLM, AI21Base):
    """AI21LLM large language models.

    Example:
        .. code-block:: python

            from langchain_ai21 import AI21LLM

            model = AI21LLM(model="j2-ultra")
    """

    model: str
    """Model type you wish to interact with.
    You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""

    num_results: int = 1
    """The number of responses to generate for a given prompt."""

    max_tokens: int = 16
    """The maximum number of tokens to generate for each response."""

    min_tokens: int = 0
    """The minimum number of tokens to generate for each response."""

    temperature: float = 0.7
    """A value controlling the "creativity" of the model's responses."""

    top_p: float = 1
    """A value controlling the diversity of the model's responses."""

    top_k_returns: int = 0
    """Sent to the AI21 API as ``top_k_return``: the number of top-scoring
    alternative tokens to return for each generation step."""

    frequency_penalty: Optional[Penalty] = None
    """A penalty applied to tokens that are frequently generated."""

    presence_penalty: Optional[Penalty] = None
    """A penalty applied to tokens that are already present in the prompt."""

    count_penalty: Optional[Penalty] = None
    """A penalty applied to tokens based on their frequency
    in the generated responses."""

    custom_model: Optional[str] = None
    """Forwarded to the AI21 completion API as ``custom_model``."""

    epoch: Optional[int] = None
    """Forwarded to the AI21 completion API as ``epoch``."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "ai21-llm"

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Generate completions for each prompt, one AI21 request per prompt.

        Args:
            prompts: Prompts to complete.
            stop: Optional stop sequences, sent as ``stop_sequences``.
            run_manager: Callback manager for this run (not used here).
            **kwargs: Per-call completion parameters, using the AI21 client's
                parameter names (e.g. ``max_tokens``, ``temperature``,
                ``top_k_return``). They override the instance defaults for
                this call only. ``model`` overrides :attr:`model`.

        Returns:
            An ``LLMResult`` with one list of generations per prompt. Its
            ``llm_output`` holds the summed prompt token count and the model
            name that was used.
        """
        # ``model`` is passed explicitly to ``_invoke_completion`` below. Pop a
        # per-call override so it cannot collide with that keyword.
        model = kwargs.pop("model", self.model)
        generations: List[List[Generation]] = []
        token_count = 0

        for prompt in prompts:
            response = self._invoke_completion(
                prompt=prompt, model=model, stop_sequences=stop, **kwargs
            )
            generations.append(self._response_to_generation(response))
            # Only prompt tokens are counted; generated tokens are not included.
            token_count += self.client.count_tokens(prompt)

        llm_output = {"token_count": token_count, "model_name": model}
        return LLMResult(generations=generations, llm_output=llm_output)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Async variant of ``_generate`` that runs it in the default executor.

        Args:
            prompts: Prompts to complete.
            stop: Optional stop sequences.
            run_manager: Async callback manager. It is converted to its sync
                counterpart, since ``_generate`` expects a sync manager.
            **kwargs: Per-call completion parameters, as for ``_generate``.

        Returns:
            The ``LLMResult`` produced by ``_generate``.
        """
        # The AI21 client call in ``_invoke_completion`` is synchronous. Running
        # the sync path in a thread keeps it from blocking the event loop.
        # Replace this if the client gains native async support.
        sync_run_manager = run_manager.get_sync() if run_manager else None
        return await asyncio.get_running_loop().run_in_executor(
            None,
            partial(self._generate, prompts, stop, sync_run_manager, **kwargs),
        )

    def _invoke_completion(
        self,
        prompt: str,
        model: str,
        stop_sequences: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> CompletionsResponse:
        """Send one completion request through ``self.client``.

        Args:
            prompt: Prompt text to complete.
            model: Model name passed to the client.
            stop_sequences: Optional stop sequences.
            **kwargs: Completion parameters (AI21 client names) that override
                the instance-level defaults for this request only.

        Returns:
            The ``CompletionsResponse`` returned by the client.
        """
        params = {
            "max_tokens": self.max_tokens,
            "num_results": self.num_results,
            "min_tokens": self.min_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k_return": self.top_k_returns,
            "custom_model": self.custom_model,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "count_penalty": self.count_penalty,
            "epoch": self.epoch,
        }
        # Previously these kwargs were accepted but silently dropped. Now they
        # take precedence over the configured defaults.
        params.update(kwargs)
        return self.client.completion.create(
            prompt=prompt,
            model=model,
            stop_sequences=stop_sequences,
            **params,
        )

    def _response_to_generation(
        self, response: CompletionsResponse
    ) -> List[Generation]:
        """Convert each completion in ``response`` into a ``Generation``.

        The completion text becomes ``Generation.text``. The completion's
        ``to_dict()`` output is kept as ``generation_info``.
        """
        return [
            Generation(
                text=completion.data.text,
                generation_info=completion.to_dict(),
            )
            for completion in response.completions
        ]
|