update azure embedding docs (#13091)

pull/13093/head
Bagatur 8 months ago committed by GitHub
parent 9fdfac22c2
commit 1703f132c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,7 +7,95 @@
"source": [ "source": [
"# AzureOpenAI\n", "# AzureOpenAI\n",
"\n", "\n",
"Let's load the OpenAI Embedding class with environment variables set to indicate to use Azure endpoints." "Let's load the Azure OpenAI Embedding class with environment variables set to indicate to use Azure endpoints."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8a6ed30d-806f-4800-b5fd-d04126be9060",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"AZURE_OPENAI_API_KEY\"] = \"...\"\n",
"os.environ[\"AZURE_OPENAI_ENDPOINT\"] = \"https://<yout-endpoint>.openai.azure.com/\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "20179bc7-3f71-4909-be12-d38bce009b18",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import AzureOpenAIEmbeddings\n",
"\n",
"embeddings = AzureOpenAIEmbeddings(azure_deployment=\"<your-embeddings-deployment-name>\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f8cb9dca-738b-450f-9986-5c3efd3c6eb3",
"metadata": {},
"outputs": [],
"source": [
"text = \"this is a test document\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0fae0295-b117-4a5a-8b98-500c79306551",
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(text)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "65a01ddd-0bbf-444f-a87f-93af25ef902c",
"metadata": {},
"outputs": [],
"source": [
"doc_result = embeddings.embed_documents([text])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "45771052-68ca-4e03-9c4f-a0c7796d9442",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[-0.012222584727053133,\n",
" 0.0072103982392216145,\n",
" -0.014818063280923775,\n",
" -0.026444746872933557,\n",
" -0.0034330499700826883]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"doc_result[0][:5]"
]
},
{
"cell_type": "markdown",
"id": "e66ec1f2-6768-4ee5-84bf-a2d76adc20c8",
"metadata": {},
"source": [
"## [Legacy] When using `openai<1`"
] ]
}, },
{ {
@ -79,9 +167,9 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "poetry-venv",
"language": "python", "language": "python",
"name": "python3" "name": "poetry-venv"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {

@ -21,7 +21,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
Example: `https://example-resource.azure.openai.com/` Example: `https://example-resource.azure.openai.com/`
""" """
azure_deployment: Optional[str] = None deployment: Optional[str] = Field(default=None, alias="azure_deployment")
"""A model deployment. """A model deployment.
If given sets the base client URL to include `/deployments/{azure_deployment}`. If given sets the base client URL to include `/deployments/{azure_deployment}`.
@ -104,15 +104,15 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
f"(or alias `base_url`). Updating `openai_api_base` from " f"(or alias `base_url`). Updating `openai_api_base` from "
f"{openai_api_base} to {values['openai_api_base']}." f"{openai_api_base} to {values['openai_api_base']}."
) )
if values["azure_deployment"]: if values["deployment"]:
warnings.warn( warnings.warn(
"As of openai>=1.0.0, if `azure_deployment` (or alias " "As of openai>=1.0.0, if `deployment` (or alias "
"`azure_deployment`) is specified then " "`azure_deployment`) is specified then "
"`openai_api_base` (or alias `base_url`) should not be. " "`openai_api_base` (or alias `base_url`) should not be. "
"Instead use `azure_deployment` (or alias `azure_deployment`) " "Instead use `deployment` (or alias `azure_deployment`) "
"and `azure_endpoint`." "and `azure_endpoint`."
) )
if values["azure_deployment"] not in values["openai_api_base"]: if values["deployment"] not in values["openai_api_base"]:
warnings.warn( warnings.warn(
"As of openai>=1.0.0, if `openai_api_base` " "As of openai>=1.0.0, if `openai_api_base` "
"(or alias `base_url`) is specified it is expected to be " "(or alias `base_url`) is specified it is expected to be "
@ -122,13 +122,13 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
f"{values['openai_api_base']}." f"{values['openai_api_base']}."
) )
values["openai_api_base"] += ( values["openai_api_base"] += (
"/deployments/" + values["azure_deployment"] "/deployments/" + values["deployment"]
) )
values["azure_deployment"] = None values["deployment"] = None
client_params = { client_params = {
"api_version": values["openai_api_version"], "api_version": values["openai_api_version"],
"azure_endpoint": values["azure_endpoint"], "azure_endpoint": values["azure_endpoint"],
"azure_deployment": values["azure_deployment"], "azure_deployment": values["deployment"],
"api_key": values["openai_api_key"], "api_key": values["openai_api_key"],
"azure_ad_token": values["azure_ad_token"], "azure_ad_token": values["azure_ad_token"],
"azure_ad_token_provider": values["azure_ad_token_provider"], "azure_ad_token_provider": values["azure_ad_token_provider"],

@ -17,6 +17,7 @@ from typing import (
Set, Set,
Tuple, Tuple,
Union, Union,
cast,
) )
import numpy as np import numpy as np
@ -182,7 +183,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
async_client: Any = None #: :meta private: async_client: Any = None #: :meta private:
model: str = "text-embedding-ada-002" model: str = "text-embedding-ada-002"
# to support Azure OpenAI Service custom deployment names # to support Azure OpenAI Service custom deployment names
deployment: str = model deployment: Optional[str] = model
# TODO: Move to AzureOpenAIEmbeddings. # TODO: Move to AzureOpenAIEmbeddings.
openai_api_version: Optional[str] = Field(default=None, alias="api_version") openai_api_version: Optional[str] = Field(default=None, alias="api_version")
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
@ -546,7 +547,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
""" """
# NOTE: to keep things simple, we assume the list may contain texts longer # NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function. # than the maximum context and use length-safe embedding function.
return self._get_len_safe_embeddings(texts, engine=self.deployment) engine = cast(str, self.deployment)
return self._get_len_safe_embeddings(texts, engine=engine)
async def aembed_documents( async def aembed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0 self, texts: List[str], chunk_size: Optional[int] = 0
@ -563,7 +565,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
""" """
# NOTE: to keep things simple, we assume the list may contain texts longer # NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function. # than the maximum context and use length-safe embedding function.
return await self._aget_len_safe_embeddings(texts, engine=self.deployment) engine = cast(str, self.deployment)
return await self._aget_len_safe_embeddings(texts, engine=engine)
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint for embedding query text. """Call out to OpenAI's embedding endpoint for embedding query text.

Loading…
Cancel
Save