diff --git a/libs/partners/ai21/langchain_ai21/ai21_base.py b/libs/partners/ai21/langchain_ai21/ai21_base.py index fa9f30ed80..0c7c79f64c 100644 --- a/libs/partners/ai21/langchain_ai21/ai21_base.py +++ b/libs/partners/ai21/langchain_ai21/ai21_base.py @@ -9,6 +9,8 @@ _DEFAULT_TIMEOUT_SEC = 300 class AI21Base(BaseModel): + """Base class for AI21 models.""" + class Config: arbitrary_types_allowed = True diff --git a/libs/partners/ai21/langchain_ai21/chat/chat_adapter.py b/libs/partners/ai21/langchain_ai21/chat/chat_adapter.py index 89fb488b1a..67f70ca381 100644 --- a/libs/partners/ai21/langchain_ai21/chat/chat_adapter.py +++ b/libs/partners/ai21/langchain_ai21/chat/chat_adapter.py @@ -14,8 +14,8 @@ _ROLE_TYPE = Union[str, RoleType] class ChatAdapter(ABC): - """ - Provides a common interface for the different Chat models available in AI21. + """Common interface for the different Chat models available in AI21. + It converts LangChain messages to AI21 messages. Calls the appropriate AI21 model API with the converted messages. """ @@ -77,6 +77,8 @@ class ChatAdapter(ABC): class J2ChatAdapter(ChatAdapter): + """Adapter for J2Chat models.""" + def convert_messages(self, messages: List[BaseMessage]) -> Dict[str, Any]: system_message = "" converted_messages = [] # type: ignore @@ -107,6 +109,8 @@ class J2ChatAdapter(ChatAdapter): class JambaChatCompletionsAdapter(ChatAdapter): + """Adapter for Jamba Chat Completions.""" + def convert_messages(self, messages: List[BaseMessage]) -> Dict[str, Any]: return { "messages": [ diff --git a/libs/partners/ai21/langchain_ai21/chat/chat_factory.py b/libs/partners/ai21/langchain_ai21/chat/chat_factory.py index 8f47a85059..1fdeb5436f 100644 --- a/libs/partners/ai21/langchain_ai21/chat/chat_factory.py +++ b/libs/partners/ai21/langchain_ai21/chat/chat_factory.py @@ -6,6 +6,14 @@ from langchain_ai21.chat.chat_adapter import ( def create_chat_adapter(model: str) -> ChatAdapter: + """Create a chat adapter based on the model. + + Args: + model: The model to create the chat adapter for. + + Returns: + The chat adapter. + """ if "j2" in model: return J2ChatAdapter() diff --git a/libs/partners/ai21/langchain_ai21/contextual_answers.py b/libs/partners/ai21/langchain_ai21/contextual_answers.py index 264f45460a..79b091adac 100644 --- a/libs/partners/ai21/langchain_ai21/contextual_answers.py +++ b/libs/partners/ai21/langchain_ai21/contextual_answers.py @@ -19,11 +19,15 @@ ContextType = Union[str, List[Union[Document, str]]] class ContextualAnswerInput(TypedDict): + """Input for the ContextualAnswers runnable.""" + context: ContextType question: str class AI21ContextualAnswers(RunnableSerializable[ContextualAnswerInput, str], AI21Base): + """Runnable for the AI21 Contextual Answers API.""" + class Config: """Configuration for this pydantic object.""" diff --git a/libs/partners/ai21/langchain_ai21/embeddings.py b/libs/partners/ai21/langchain_ai21/embeddings.py index 97b7d68242..87ef389470 100644 --- a/libs/partners/ai21/langchain_ai21/embeddings.py +++ b/libs/partners/ai21/langchain_ai21/embeddings.py @@ -15,7 +15,8 @@ def _split_texts_into_batches(texts: List[str], batch_size: int) -> Iterator[Lis class AI21Embeddings(Embeddings, AI21Base): - """AI21 Embeddings embedding model. + """AI21 embedding model. + To use, you should have the 'AI21_API_KEY' environment variable set or pass as a named parameter to the constructor. diff --git a/libs/partners/ai21/langchain_ai21/llms.py b/libs/partners/ai21/langchain_ai21/llms.py index 0cba917bd5..0c0ae1822e 100644 --- a/libs/partners/ai21/langchain_ai21/llms.py +++ b/libs/partners/ai21/langchain_ai21/llms.py @@ -19,7 +19,7 @@ from langchain_ai21.ai21_base import AI21Base class AI21LLM(BaseLLM, AI21Base): - """AI21LLM large language models. + """AI21 large language models. Example: .. code-block:: python diff --git a/libs/partners/ai21/langchain_ai21/semantic_text_splitter.py b/libs/partners/ai21/langchain_ai21/semantic_text_splitter.py index cda0bba1b1..974ba6bc23 100644 --- a/libs/partners/ai21/langchain_ai21/semantic_text_splitter.py +++ b/libs/partners/ai21/langchain_ai21/semantic_text_splitter.py @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) class AI21SemanticTextSplitter(TextSplitter): """Splitting text into coherent and readable units, - based on distinct topics and lines + based on distinct topics and lines. """ def __init__(