diff --git a/libs/partners/ai21/langchain_ai21/ai21_base.py b/libs/partners/ai21/langchain_ai21/ai21_base.py index 0c7c79f64c..5b8fcca0f1 100644 --- a/libs/partners/ai21/langchain_ai21/ai21_base.py +++ b/libs/partners/ai21/langchain_ai21/ai21_base.py @@ -16,9 +16,17 @@ class AI21Base(BaseModel): client: Any = Field(default=None, exclude=True) #: :meta private: api_key: Optional[SecretStr] = None + """API key for AI21 API.""" api_host: Optional[str] = None + """Host URL""" timeout_sec: Optional[float] = None + """Timeout in seconds. + + If not set, it will default to the value of the environment + variable `AI21_TIMEOUT_SEC` or 300 seconds. + """ num_retries: Optional[int] = None + """Maximum number of retries for API requests before giving up.""" @root_validator() def validate_environment(cls, values: Dict) -> Dict: diff --git a/libs/partners/ai21/langchain_ai21/embeddings.py b/libs/partners/ai21/langchain_ai21/embeddings.py index 87ef389470..9d18c55e4b 100644 --- a/libs/partners/ai21/langchain_ai21/embeddings.py +++ b/libs/partners/ai21/langchain_ai21/embeddings.py @@ -15,18 +15,64 @@ def _split_texts_into_batches(texts: List[str], batch_size: int) -> Iterator[Lis class AI21Embeddings(Embeddings, AI21Base): - """AI21 embedding model. + """AI21 embedding model integration. - To use, you should have the 'AI21_API_KEY' environment variable set - or pass as a named parameter to the constructor. + Install ``langchain_ai21`` and set environment variable ``AI21_API_KEY``. - Example: + .. code-block:: bash + + pip install -U langchain_ai21 + export AI21_API_KEY="your-api-key" + + Key init args — client params: + api_key: Optional[SecretStr] + batch_size: int + The number of texts that will be sent to the API in each batch. + Use larger batch sizes if working with many short texts. This will reduce + the number of API calls made, and can improve the time it takes to embed + a large number of texts. + num_retries: Optional[int] + Maximum number of retries for API requests before giving up. + timeout_sec: Optional[float] + Timeout in seconds for API requests. If not set, it will default to the + value of the environment variable `AI21_TIMEOUT_SEC` or 300 seconds. + + See full list of supported init args and their descriptions in the params section. + + Instantiate: .. code-block:: python from langchain_ai21 import AI21Embeddings - embeddings = AI21Embeddings() - query_result = embeddings.embed_query("Hello embeddings world!") + embed = AI21Embeddings( + # api_key="...", + # batch_size=128, + ) + + Embed single text: + .. code-block:: python + + input_text = "The meaning of life is 42" + vector = embed.embed_query(input_text) + print(vector[:3]) + + .. code-block:: python + + [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915] + + Embed multiple texts: + .. code-block:: python + + input_texts = ["Document 1...", "Document 2..."] + vectors = embed.embed_documents(input_texts) + print(len(vectors)) + # The first 3 coordinates for the first vector + print(vectors[0][:3]) + + .. code-block:: python + + 2 + [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915] """ batch_size: int = _DEFAULT_BATCH_SIZE