From 2366e71bed118cc40f6b1714baa05fcba6b8dd73 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 1 May 2023 21:34:16 -0700 Subject: [PATCH] Harrison/azure openai (#3942) Co-authored-by: Saverio Proto --- langchain/embeddings/openai.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py index c9a5065d..52fbf94b 100644 --- a/langchain/embeddings/openai.py +++ b/langchain/embeddings/openai.py @@ -94,7 +94,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): from langchain.embeddings.openai import OpenAIEmbeddings embeddings = OpenAIEmbeddings( deployment="your-embeddings-deployment-name", - model="your-embeddings-model-name" + model="your-embeddings-model-name", + api_base="https://your-endpoint.openai.azure.com/", + api_type="azure", ) text = "This is a test query." query_result = embeddings.embed_query(text) @@ -104,6 +106,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings): client: Any #: :meta private: model: str = "text-embedding-ada-002" deployment: str = model # to support Azure OpenAI Service custom deployment names + openai_api_version: str = "2022-12-01" + # to support Azure OpenAI Service custom endpoints + openai_api_base: Optional[str] = None + # to support Azure OpenAI Service custom endpoints + openai_api_type: Optional[str] = None embedding_ctx_length: int = 8191 openai_api_key: Optional[str] = None openai_organization: Optional[str] = None @@ -125,6 +132,23 @@ class OpenAIEmbeddings(BaseModel, Embeddings): openai_api_key = get_from_dict_or_env( values, "openai_api_key", "OPENAI_API_KEY" ) + openai_api_base = get_from_dict_or_env( + values, + "openai_api_base", + "OPENAI_API_BASE", + default="", + ) + openai_api_type = get_from_dict_or_env( + values, + "openai_api_type", + "OPENAI_API_TYPE", + default="", + ) + openai_api_version = get_from_dict_or_env( + values, + "openai_api_version", + "OPENAI_API_VERSION", + ) openai_organization = get_from_dict_or_env( values, "openai_organization", @@ -137,6 +161,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings): openai.api_key = openai_api_key if openai_organization: openai.organization = openai_organization + if openai_api_base: + openai.api_base = openai_api_base + openai.api_version = openai_api_version + if openai_api_type: + openai.api_type = openai_api_type values["client"] = openai.Embedding except ImportError: raise ValueError(