|
|
|
@ -1,4 +1,5 @@
|
|
|
|
|
"""Azure OpenAI embeddings wrapper."""
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
@ -57,6 +58,8 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
|
|
|
|
|
openai_api_version: Optional[str] = Field(default=None, alias="api_version")
|
|
|
|
|
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
|
|
|
|
|
validate_base_url: bool = True
|
|
|
|
|
chunk_size: int = 2048
|
|
|
|
|
"""Maximum number of texts to embed in each batch"""
|
|
|
|
|
|
|
|
|
|
@root_validator()
|
|
|
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
|
|
@ -102,7 +105,11 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
|
|
|
|
|
# Azure OpenAI embedding models allow a maximum of 2048 texts
|
|
|
|
|
# at a time in each batch
|
|
|
|
|
# See: https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/embeddings?tabs=console#best-practices
|
|
|
|
|
values["chunk_size"] = min(values["chunk_size"], 2048)
|
|
|
|
|
if values["chunk_size"] > 2048:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Azure OpenAI embeddings only allow a maximum of 2048 texts at a time "
|
|
|
|
|
"in each batch."
|
|
|
|
|
)
|
|
|
|
|
# For backwards compatibility. Before openai v1, no distinction was made
|
|
|
|
|
# between azure_endpoint and base_url (openai_api_base).
|
|
|
|
|
openai_api_base = values["openai_api_base"]
|
|
|
|
@ -126,12 +133,16 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
|
|
|
|
|
"api_version": values["openai_api_version"],
|
|
|
|
|
"azure_endpoint": values["azure_endpoint"],
|
|
|
|
|
"azure_deployment": values["deployment"],
|
|
|
|
|
"api_key": values["openai_api_key"].get_secret_value()
|
|
|
|
|
if values["openai_api_key"]
|
|
|
|
|
else None,
|
|
|
|
|
"azure_ad_token": values["azure_ad_token"].get_secret_value()
|
|
|
|
|
if values["azure_ad_token"]
|
|
|
|
|
else None,
|
|
|
|
|
"api_key": (
|
|
|
|
|
values["openai_api_key"].get_secret_value()
|
|
|
|
|
if values["openai_api_key"]
|
|
|
|
|
else None
|
|
|
|
|
),
|
|
|
|
|
"azure_ad_token": (
|
|
|
|
|
values["azure_ad_token"].get_secret_value()
|
|
|
|
|
if values["azure_ad_token"]
|
|
|
|
|
else None
|
|
|
|
|
),
|
|
|
|
|
"azure_ad_token_provider": values["azure_ad_token_provider"],
|
|
|
|
|
"organization": values["openai_organization"],
|
|
|
|
|
"base_url": values["openai_api_base"],
|
|
|
|
|