mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
ai21: docstrings (#23142)
Added missed docstrings. Format docstrings to the consistent format (used in the API Reference)
This commit is contained in:
parent
0c2ebe5f47
commit
a70b7a688e
@ -9,6 +9,8 @@ _DEFAULT_TIMEOUT_SEC = 300
|
||||
|
||||
|
||||
class AI21Base(BaseModel):
|
||||
"""Base class for AI21 models."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
@ -14,8 +14,8 @@ _ROLE_TYPE = Union[str, RoleType]
|
||||
|
||||
|
||||
class ChatAdapter(ABC):
|
||||
"""
|
||||
Provides a common interface for the different Chat models available in AI21.
|
||||
"""Common interface for the different Chat models available in AI21.
|
||||
|
||||
It converts LangChain messages to AI21 messages.
|
||||
Calls the appropriate AI21 model API with the converted messages.
|
||||
"""
|
||||
@ -77,6 +77,8 @@ class ChatAdapter(ABC):
|
||||
|
||||
|
||||
class J2ChatAdapter(ChatAdapter):
|
||||
"""Adapter for J2Chat models."""
|
||||
|
||||
def convert_messages(self, messages: List[BaseMessage]) -> Dict[str, Any]:
|
||||
system_message = ""
|
||||
converted_messages = [] # type: ignore
|
||||
@ -107,6 +109,8 @@ class J2ChatAdapter(ChatAdapter):
|
||||
|
||||
|
||||
class JambaChatCompletionsAdapter(ChatAdapter):
|
||||
"""Adapter for Jamba Chat Completions."""
|
||||
|
||||
def convert_messages(self, messages: List[BaseMessage]) -> Dict[str, Any]:
|
||||
return {
|
||||
"messages": [
|
||||
|
@ -6,6 +6,14 @@ from langchain_ai21.chat.chat_adapter import (
|
||||
|
||||
|
||||
def create_chat_adapter(model: str) -> ChatAdapter:
|
||||
"""Create a chat adapter based on the model.
|
||||
|
||||
Args:
|
||||
model: The model to create the chat adapter for.
|
||||
|
||||
Returns:
|
||||
The chat adapter.
|
||||
"""
|
||||
if "j2" in model:
|
||||
return J2ChatAdapter()
|
||||
|
||||
|
@ -19,11 +19,15 @@ ContextType = Union[str, List[Union[Document, str]]]
|
||||
|
||||
|
||||
class ContextualAnswerInput(TypedDict):
|
||||
"""Input for the ContextualAnswers runnable."""
|
||||
|
||||
context: ContextType
|
||||
question: str
|
||||
|
||||
|
||||
class AI21ContextualAnswers(RunnableSerializable[ContextualAnswerInput, str], AI21Base):
|
||||
"""Runnable for the AI21 Contextual Answers API."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
|
@ -15,7 +15,8 @@ def _split_texts_into_batches(texts: List[str], batch_size: int) -> Iterator[Lis
|
||||
|
||||
|
||||
class AI21Embeddings(Embeddings, AI21Base):
|
||||
"""AI21 Embeddings embedding model.
|
||||
"""AI21 embedding model.
|
||||
|
||||
To use, you should have the 'AI21_API_KEY' environment variable set
|
||||
or pass as a named parameter to the constructor.
|
||||
|
||||
|
@ -19,7 +19,7 @@ from langchain_ai21.ai21_base import AI21Base
|
||||
|
||||
|
||||
class AI21LLM(BaseLLM, AI21Base):
|
||||
"""AI21LLM large language models.
|
||||
"""AI21 large language models.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class AI21SemanticTextSplitter(TextSplitter):
|
||||
"""Splitting text into coherent and readable units,
|
||||
based on distinct topics and lines
|
||||
based on distinct topics and lines.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
Loading…
Reference in New Issue
Block a user