@ -92,7 +92,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
os . environ [ " OPENAI_API_KEY " ] = " your AzureOpenAI key "
from langchain . embeddings . openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings ( model = " your-embeddings-deployment-name " )
embeddings = OpenAIEmbeddings (
deployment = " your-embeddings-deployment-name " ,
model = " your-embeddings-model-name "
)
text = " This is a test query. "
query_result = embeddings . embed_query ( text )
@ -100,12 +103,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
client : Any #: :meta private:
model : str = " text-embedding-ada-002 "
# TODO: deprecate these two in favor of model
# https://community.openai.com/t/api-update-engines-models/18597
# https://github.com/openai/openai-python/issues/132
document_model_name : str = " text-embedding-ada-002 "
query_model_name : str = " text-embedding-ada-002 "
deployment : str = model # to support Azure OpenAI Service custom deployment names
embedding_ctx_length : int = 8191
openai_api_key : Optional [ str ] = None
openai_organization : Optional [ str ] = None
@ -121,51 +119,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
extra = Extra . forbid
# TODO: deprecate this
@root_validator ( pre = True )
def get_model_names ( cls , values : Dict ) - > Dict :
# model_name is for first generation, and model is for second generation.
# Both are not allowed together.
if " model_name " in values and " model " in values :
raise ValueError (
" Both `model_name` and `model` were provided, "
" but only one should be. "
)
""" Get model names from just old model name. """
if " model_name " in values :
if " document_model_name " in values :
raise ValueError (
" Both `model_name` and `document_model_name` were provided, "
" but only one should be. "
)
if " query_model_name " in values :
raise ValueError (
" Both `model_name` and `query_model_name` were provided, "
" but only one should be. "
)
model_name = values . pop ( " model_name " )
values [ " document_model_name " ] = f " text-search- { model_name } -doc-001 "
values [ " query_model_name " ] = f " text-search- { model_name } -query-001 "
# Set document/query model names from model parameter.
if " model " in values :
if " document_model_name " in values :
raise ValueError (
" Both `model` and `document_model_name` were provided, "
" but only one should be. "
)
if " query_model_name " in values :
raise ValueError (
" Both `model` and `query_model_name` were provided, "
" but only one should be. "
)
model = values . get ( " model " )
values [ " document_model_name " ] = model
values [ " query_model_name " ] = model
return values
@root_validator ( )
def validate_environment ( cls , values : Dict ) - > Dict :
""" Validate that api key and python package exists in environment. """
@ -203,7 +156,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
tokens = [ ]
indices = [ ]
encoding = tiktoken . model . encoding_for_model ( self . document_ model_name )
encoding = tiktoken . model . encoding_for_model ( self . model)
for i , text in enumerate ( texts ) :
# replace newlines, which can negatively affect performance.
text = text . replace ( " \n " , " " )
@ -222,7 +175,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
response = embed_with_retry (
self ,
input = tokens [ i : i + _chunk_size ] ,
engine = self . d ocument_model_name ,
engine = self . d eployment ,
)
batched_embeddings + = [ r [ " embedding " ] for r in response [ " data " ] ]
@ -272,7 +225,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""
# handle batches of large input text
if self . embedding_ctx_length > 0 :
return self . _get_len_safe_embeddings ( texts , engine = self . d ocument_model_name )
return self . _get_len_safe_embeddings ( texts , engine = self . d eployment )
else :
results = [ ]
_chunk_size = chunk_size or self . chunk_size
@ -280,7 +233,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
response = embed_with_retry (
self ,
input = texts [ i : i + _chunk_size ] ,
engine = self . d ocument_model_name ,
engine = self . d eployment ,
)
results + = [ r [ " embedding " ] for r in response [ " data " ] ]
return results
@ -294,5 +247,5 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
Returns :
Embedding for the text .
"""
embedding = self . _embedding_func ( text , engine = self . query_model_name )
embedding = self . _embedding_func ( text , engine = self . deployment )
return embedding