update azure embedding docs (#13091)

pull/13093/head
Bagatur 8 months ago committed by GitHub
parent 9fdfac22c2
commit 1703f132c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,7 +7,95 @@
"source": [
"# AzureOpenAI\n",
"\n",
"Let's load the OpenAI Embedding class with environment variables set to indicate to use Azure endpoints."
"Let's load the Azure OpenAI Embedding class with environment variables set to indicate to use Azure endpoints."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8a6ed30d-806f-4800-b5fd-d04126be9060",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"AZURE_OPENAI_API_KEY\"] = \"...\"\n",
"os.environ[\"AZURE_OPENAI_ENDPOINT\"] = \"https://<yout-endpoint>.openai.azure.com/\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "20179bc7-3f71-4909-be12-d38bce009b18",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import AzureOpenAIEmbeddings\n",
"\n",
"embeddings = AzureOpenAIEmbeddings(azure_deployment=\"<your-embeddings-deployment-name>\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f8cb9dca-738b-450f-9986-5c3efd3c6eb3",
"metadata": {},
"outputs": [],
"source": [
"text = \"this is a test document\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0fae0295-b117-4a5a-8b98-500c79306551",
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(text)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "65a01ddd-0bbf-444f-a87f-93af25ef902c",
"metadata": {},
"outputs": [],
"source": [
"doc_result = embeddings.embed_documents([text])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "45771052-68ca-4e03-9c4f-a0c7796d9442",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[-0.012222584727053133,\n",
" 0.0072103982392216145,\n",
" -0.014818063280923775,\n",
" -0.026444746872933557,\n",
" -0.0034330499700826883]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"doc_result[0][:5]"
]
},
{
"cell_type": "markdown",
"id": "e66ec1f2-6768-4ee5-84bf-a2d76adc20c8",
"metadata": {},
"source": [
"## [Legacy] When using `openai<1`"
]
},
{
@ -79,9 +167,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "poetry-venv",
"language": "python",
"name": "python3"
"name": "poetry-venv"
},
"language_info": {
"codemirror_mode": {

@ -21,7 +21,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
Example: `https://example-resource.azure.openai.com/`
"""
azure_deployment: Optional[str] = None
deployment: Optional[str] = Field(default=None, alias="azure_deployment")
"""A model deployment.
If given sets the base client URL to include `/deployments/{azure_deployment}`.
@ -104,15 +104,15 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
f"(or alias `base_url`). Updating `openai_api_base` from "
f"{openai_api_base} to {values['openai_api_base']}."
)
if values["azure_deployment"]:
if values["deployment"]:
warnings.warn(
"As of openai>=1.0.0, if `azure_deployment` (or alias "
"As of openai>=1.0.0, if `deployment` (or alias "
"`azure_deployment`) is specified then "
"`openai_api_base` (or alias `base_url`) should not be. "
"Instead use `azure_deployment` (or alias `azure_deployment`) "
"Instead use `deployment` (or alias `azure_deployment`) "
"and `azure_endpoint`."
)
if values["azure_deployment"] not in values["openai_api_base"]:
if values["deployment"] not in values["openai_api_base"]:
warnings.warn(
"As of openai>=1.0.0, if `openai_api_base` "
"(or alias `base_url`) is specified it is expected to be "
@ -122,13 +122,13 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
f"{values['openai_api_base']}."
)
values["openai_api_base"] += (
"/deployments/" + values["azure_deployment"]
"/deployments/" + values["deployment"]
)
values["azure_deployment"] = None
values["deployment"] = None
client_params = {
"api_version": values["openai_api_version"],
"azure_endpoint": values["azure_endpoint"],
"azure_deployment": values["azure_deployment"],
"azure_deployment": values["deployment"],
"api_key": values["openai_api_key"],
"azure_ad_token": values["azure_ad_token"],
"azure_ad_token_provider": values["azure_ad_token_provider"],

@ -17,6 +17,7 @@ from typing import (
Set,
Tuple,
Union,
cast,
)
import numpy as np
@ -182,7 +183,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
async_client: Any = None #: :meta private:
model: str = "text-embedding-ada-002"
# to support Azure OpenAI Service custom deployment names
deployment: str = model
deployment: Optional[str] = model
# TODO: Move to AzureOpenAIEmbeddings.
openai_api_version: Optional[str] = Field(default=None, alias="api_version")
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
@ -546,7 +547,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""
# NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function.
return self._get_len_safe_embeddings(texts, engine=self.deployment)
engine = cast(str, self.deployment)
return self._get_len_safe_embeddings(texts, engine=engine)
async def aembed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
@ -563,7 +565,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""
# NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function.
return await self._aget_len_safe_embeddings(texts, engine=self.deployment)
engine = cast(str, self.deployment)
return await self._aget_len_safe_embeddings(texts, engine=engine)
def embed_query(self, text: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint for embedding query text.

Loading…
Cancel
Save