You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/partners/ai21/langchain_ai21/chat_models.py

202 lines
6.1 KiB
Python

import asyncio
from functools import partial
from typing import Any, List, Mapping, Optional, Tuple, cast
from ai21.models import ChatMessage, RoleType
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_ai21.ai21_base import AI21Base
def _get_system_message_from_message(message: BaseMessage) -> str:
    """Return the textual content of a system message.

    Args:
        message: The message whose ``content`` should be used as the system
            prompt.

    Returns:
        The message content, unchanged.

    Raises:
        ValueError: If the content is not a plain ``str``.
    """
    text = message.content
    if isinstance(text, str):
        return text
    raise ValueError(f"System Message must be of type str. Got {type(text)}")
def _convert_messages_to_ai21_messages(
    messages: List[BaseMessage],
) -> Tuple[Optional[str], List[ChatMessage]]:
    """Split a message list into a system prompt and AI21 chat messages.

    Args:
        messages: Messages in conversation order. A message whose ``type`` is
            ``"system"`` is only accepted at index 0.

    Returns:
        A ``(system, chat_messages)`` tuple: ``system`` is the system message
        text (``None`` when there is none) and ``chat_messages`` holds every
        other message converted with ``_convert_message_to_ai21_message``.

    Raises:
        ValueError: If a system message appears anywhere but the first
            position, or if one of the per-message helpers rejects a message.
    """
    system_prompt = None
    chat_history: List[ChatMessage] = []
    for position, current in enumerate(messages):
        if current.type != "system":
            chat_history.append(_convert_message_to_ai21_message(current))
            continue
        # Only the very first message may carry the system prompt.
        if position != 0:
            raise ValueError("System message must be at beginning of message list.")
        system_prompt = _get_system_message_from_message(current)
    return system_prompt, chat_history
def _convert_message_to_ai21_message(
    message: BaseMessage,
) -> ChatMessage:
    """Convert one human or AI message into an AI21 ``ChatMessage``.

    Args:
        message: The message to convert. Its content is passed through as the
            chat message text (treated as ``str`` for typing purposes only;
            no runtime conversion is performed).

    Returns:
        A ``ChatMessage`` with the user role for ``HumanMessage`` and the
        assistant role for ``AIMessage``.

    Raises:
        ValueError: If the message is neither a ``HumanMessage`` nor an
            ``AIMessage``.
    """
    text = cast(str, message.content)
    # Checked in order: the first matching message class decides the role.
    role_by_message_class = (
        (HumanMessage, RoleType.USER),
        (AIMessage, RoleType.ASSISTANT),
    )
    resolved_role = next(
        (
            candidate
            for message_class, candidate in role_by_message_class
            if isinstance(message, message_class)
        ),
        None,
    )
    if not resolved_role:
        raise ValueError(
            f"Could not resolve role type from message {message}. "
            f"Only support {HumanMessage.__name__} and {AIMessage.__name__}."
        )
    return ChatMessage(role=resolved_role, text=text)
def _pop_system_messages(messages: List[BaseMessage]) -> List[SystemMessage]:
    """Remove every ``SystemMessage`` from ``messages`` and return them.

    The list is modified in place: after the call it contains only the
    non-system messages, in their original relative order.

    Args:
        messages: The message list to strip system messages from.

    Returns:
        The removed system messages, in the order they appeared.
    """
    system_messages: List[SystemMessage] = []
    remaining: List[BaseMessage] = []
    for message in messages:
        if isinstance(message, SystemMessage):
            system_messages.append(message)
        else:
            remaining.append(message)
    # Rewrite the caller's list in one step. Popping by precomputed index
    # shifts the later indexes after each removal, which removes the wrong
    # elements or raises IndexError.
    messages[:] = remaining
    return system_messages
class ChatAI21(BaseChatModel, AI21Base):
    """ChatAI21 chat model.

    Example:
        .. code-block:: python

            from langchain_ai21 import ChatAI21

            # ``model`` has no default and must be supplied.
            model = ChatAI21(model="j2-ultra")
    """

    model: str
    """Model type you wish to interact with.
    You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""

    num_results: int = 1
    """The number of responses to generate for a given prompt."""

    max_tokens: int = 16
    """The maximum number of tokens to generate for each response."""

    min_tokens: int = 0
    """The minimum number of tokens to generate for each response."""

    temperature: float = 0.7
    """A value controlling the "creativity" of the model's responses."""

    top_p: float = 1
    """A value controlling the diversity of the model's responses."""

    top_k_return: int = 0
    """The number of top-scoring tokens to consider for each generation step."""

    frequency_penalty: Optional[Any] = None
    """A penalty applied to tokens that are frequently generated."""

    presence_penalty: Optional[Any] = None
    """A penalty applied to tokens that are already present in the prompt."""

    count_penalty: Optional[Any] = None
    """A penalty applied to tokens based on their frequency
    in the generated responses."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "chat-ai21"

    @property
    def _default_params(self) -> Mapping[str, Any]:
        """Return the request parameters derived from this instance's fields.

        The three penalty settings are included only when set, serialized
        through their ``to_dict()`` method.
        """
        base_params = {
            "model": self.model,
            "num_results": self.num_results,
            "max_tokens": self.max_tokens,
            "min_tokens": self.min_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k_return": self.top_k_return,
        }
        for penalty_name in ("count_penalty", "frequency_penalty", "presence_penalty"):
            penalty = getattr(self, penalty_name)
            if penalty is not None:
                base_params[penalty_name] = penalty.to_dict()
        return base_params

    def _build_params_for_request(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Mapping[str, Any]:
        """Assemble the keyword arguments for a chat request.

        Args:
            messages: Conversation to send; a leading system message becomes
                the ``system`` entry and the rest become ``messages``.
            stop: Optional stop sequences, sent as ``stop_sequences``.
            **kwargs: Per-call overrides. They are merged last, so they take
                precedence over the instance defaults.

        Returns:
            The merged request parameters.

        Raises:
            ValueError: If ``stop`` is given both as an argument and inside
                ``kwargs``, or if the messages cannot be converted.
        """
        params = {}
        system, ai21_messages = _convert_messages_to_ai21_messages(messages)
        if stop is not None:
            if "stop" in kwargs:
                raise ValueError("stop is defined in both stop and kwargs")
            params["stop_sequences"] = stop
        return {
            "system": system or "",
            "messages": ai21_messages,
            **self._default_params,
            **params,
            **kwargs,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Send the conversation to the chat client and wrap its outputs.

        Args:
            messages: Conversation to send.
            stop: Optional stop sequences.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Extra request parameters forwarded to the client.

        Returns:
            A ``ChatResult`` holding one ``ChatGeneration`` per output in the
            response, in response order.
        """
        params = self._build_params_for_request(messages=messages, stop=stop, **kwargs)
        response = self.client.chat.create(**params)
        # Keep every output, not just the first, so that requests made with
        # num_results > 1 do not silently lose the additional completions.
        generations = [
            ChatGeneration(message=AIMessage(content=output.text))
            for output in response.outputs
        ]
        return ChatResult(generations=generations)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous variant of ``_generate``.

        The synchronous ``_generate`` is run in the running event loop's
        default executor, so the loop is not blocked while the request is in
        flight. Arguments and return value match ``_generate``.
        """
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self._generate, **kwargs), messages, stop, run_manager
        )