fix: separate model and deployment for OpenAIEmbeddings (#3076)

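Separated the `deployment` name from the `model` name so that Azure OpenAI
embeddings work properly: `deployment` is now what is passed as the API
`engine`, while `model` is used to select the tiktoken encoding.
Also removed the deprecated `document_model_name` and `query_model_name`
attributes, along with the `get_model_names` validator that populated them.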
fix_agent_callbacks
Tunay Okumus 1 year ago committed by GitHub
parent 4adfd790f0
commit 6e48107734
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -92,7 +92,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
os.environ["OPENAI_API_KEY"] = "your AzureOpenAI key"
from langchain.embeddings.openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings(model="your-embeddings-deployment-name")
embeddings = OpenAIEmbeddings(
deployment="your-embeddings-deployment-name",
model="your-embeddings-model-name"
)
text = "This is a test query."
query_result = embeddings.embed_query(text)
@ -100,12 +103,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
client: Any #: :meta private:
model: str = "text-embedding-ada-002"
# TODO: deprecate these two in favor of model
# https://community.openai.com/t/api-update-engines-models/18597
# https://github.com/openai/openai-python/issues/132
document_model_name: str = "text-embedding-ada-002"
query_model_name: str = "text-embedding-ada-002"
deployment: str = model # to support Azure OpenAI Service custom deployment names
embedding_ctx_length: int = 8191
openai_api_key: Optional[str] = None
openai_organization: Optional[str] = None
@ -121,51 +119,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
extra = Extra.forbid
# TODO: deprecate this
@root_validator(pre=True)
def get_model_names(cls, values: Dict) -> Dict:
    """Get document/query model names from ``model_name`` or ``model``.

    ``model_name`` is the first-generation parameter and ``model`` is the
    second-generation one. Exactly one of them may be supplied, and neither
    may be combined with an explicit ``document_model_name`` or
    ``query_model_name``.

    Args:
        values: The raw values dict handed to this root validator
            (registered with ``pre=True``).

    Returns:
        The same dict. When ``model_name`` was given it is removed and
        expanded into ``text-search-<model_name>-doc-001`` and
        ``text-search-<model_name>-query-001``; when ``model`` was given,
        both derived names are set to it. Otherwise the dict is returned
        unchanged.

    Raises:
        ValueError: If ``model_name`` and ``model`` are both present, or if
            either is present together with ``document_model_name`` or
            ``query_model_name``.
    """
    # model_name is for first generation, and model is for second generation.
    # Both are not allowed together.
    if "model_name" in values and "model" in values:
        raise ValueError(
            "Both `model_name` and `model` were provided, "
            "but only one should be."
        )

    # At most one of the two source keys can be present past this point.
    if "model_name" in values:
        source_key = "model_name"
    elif "model" in values:
        source_key = "model"
    else:
        return values

    # The derived names must not be supplied alongside the key they are
    # computed from; check the document name first, then the query name.
    for derived_key in ("document_model_name", "query_model_name"):
        if derived_key in values:
            raise ValueError(
                f"Both `{source_key}` and `{derived_key}` were provided, "
                "but only one should be."
            )

    if source_key == "model_name":
        # Get model names from just the old model name; the legacy key is
        # consumed so it does not remain in the values dict.
        legacy_name = values.pop("model_name")
        values["document_model_name"] = f"text-search-{legacy_name}-doc-001"
        values["query_model_name"] = f"text-search-{legacy_name}-query-001"
    else:
        # Set document/query model names from the model parameter.
        model = values["model"]
        values["document_model_name"] = model
        values["query_model_name"] = model
    return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
@ -203,7 +156,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
tokens = []
indices = []
encoding = tiktoken.model.encoding_for_model(self.document_model_name)
encoding = tiktoken.model.encoding_for_model(self.model)
for i, text in enumerate(texts):
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
@ -222,7 +175,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
response = embed_with_retry(
self,
input=tokens[i : i + _chunk_size],
engine=self.document_model_name,
engine=self.deployment,
)
batched_embeddings += [r["embedding"] for r in response["data"]]
@ -272,7 +225,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""
# handle batches of large input text
if self.embedding_ctx_length > 0:
return self._get_len_safe_embeddings(texts, engine=self.document_model_name)
return self._get_len_safe_embeddings(texts, engine=self.deployment)
else:
results = []
_chunk_size = chunk_size or self.chunk_size
@ -280,7 +233,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
response = embed_with_retry(
self,
input=texts[i : i + _chunk_size],
engine=self.document_model_name,
engine=self.deployment,
)
results += [r["embedding"] for r in response["data"]]
return results
@ -294,5 +247,5 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
Returns:
Embedding for the text.
"""
embedding = self._embedding_func(text, engine=self.query_model_name)
embedding = self._embedding_func(text, engine=self.deployment)
return embedding

Loading…
Cancel
Save