improve openai embeddings (#351)

add more formal support for explicitly specifying each model, but in a
backwards compatible way
harrison/agent_multi_inputs^2
Harrison Chase 2 years ago committed by GitHub
parent 428508bd75
commit ed143b598f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -22,9 +22,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""
client: Any #: :meta private:
model_name: str = "babbage"
"""Model name to use."""
document_model_name: str = "text-embedding-ada-002"
query_model_name: str = "text-embedding-ada-002"
openai_api_key: Optional[str] = None
class Config:
@ -32,6 +31,26 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
extra = Extra.forbid
# TODO: deprecate this
@root_validator(pre=True)
def get_model_names(cls, values: Dict) -> Dict:
"""Get model names from just old model name."""
if "model_name" in values:
if "document_model_name" in values:
raise ValueError(
"Both `model_name` and `document_model_name` were provided, "
"but only one should be."
)
if "query_model_name" in values:
raise ValueError(
"Both `model_name` and `query_model_name` were provided, "
"but only one should be."
)
model_name = values.pop("model_name")
values["document_model_name"] = f"text-search-{model_name}-doc-001"
values["query_model_name"] = f"text-search-{model_name}-query-001"
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
@ -66,7 +85,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
List of embeddings, one for each text.
"""
responses = [
self._embedding_func(text, engine=f"text-search-{self.model_name}-doc-001")
self._embedding_func(text, engine=self.document_model_name)
for text in texts
]
return responses
@ -80,7 +99,5 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
embedding = self._embedding_func(
text, engine=f"text-search-{self.model_name}-query-001"
)
embedding = self._embedding_func(text, engine=self.query_model_name)
return embedding

Loading…
Cancel
Save