|
|
|
@ -22,9 +22,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
client: Any #: :meta private:
|
|
|
|
|
model_name: str = "babbage"
|
|
|
|
|
"""Model name to use."""
|
|
|
|
|
|
|
|
|
|
document_model_name: str = "text-embedding-ada-002"
|
|
|
|
|
query_model_name: str = "text-embedding-ada-002"
|
|
|
|
|
openai_api_key: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
class Config:
|
|
|
|
@ -32,6 +31,26 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
|
|
|
|
|
extra = Extra.forbid
|
|
|
|
|
|
|
|
|
|
# TODO: deprecate this
|
|
|
|
|
@root_validator(pre=True)
def get_model_names(cls, values: Dict) -> Dict:
    """Translate the legacy `model_name` argument into the specific model names.

    When `model_name` is absent, `values` is returned untouched. Otherwise
    `model_name` is removed from `values` and replaced by
    `document_model_name` and `query_model_name`, both built from it.

    Args:
        values: Mapping of the arguments supplied for the model.

    Returns:
        The same mapping, updated in place.

    Raises:
        ValueError: If `model_name` is given together with
            `document_model_name` or `query_model_name`.
    """
    # Nothing to translate unless the legacy argument was supplied.
    if "model_name" not in values:
        return values

    # The legacy name and the explicit names are mutually exclusive; the
    # document name is checked first, then the query name.
    for specific_key in ("document_model_name", "query_model_name"):
        if specific_key in values:
            raise ValueError(
                f"Both `model_name` and `{specific_key}` were provided, "
                "but only one should be."
            )

    legacy_name = values.pop("model_name")
    values["document_model_name"] = f"text-search-{legacy_name}-doc-001"
    values["query_model_name"] = f"text-search-{legacy_name}-query-001"
    return values
|
|
|
|
|
|
|
|
|
|
@root_validator()
|
|
|
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
|
|
|
"""Validate that api key and python package exists in environment."""
|
|
|
|
@ -66,7 +85,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
List of embeddings, one for each text.
|
|
|
|
|
"""
|
|
|
|
|
responses = [
|
|
|
|
|
self._embedding_func(text, engine=f"text-search-{self.model_name}-doc-001")
|
|
|
|
|
self._embedding_func(text, engine=self.document_model_name)
|
|
|
|
|
for text in texts
|
|
|
|
|
]
|
|
|
|
|
return responses
|
|
|
|
@ -80,7 +99,5 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
Returns:
|
|
|
|
|
Embeddings for the text.
|
|
|
|
|
"""
|
|
|
|
|
embedding = self._embedding_func(
|
|
|
|
|
text, engine=f"text-search-{self.model_name}-query-001"
|
|
|
|
|
)
|
|
|
|
|
embedding = self._embedding_func(text, engine=self.query_model_name)
|
|
|
|
|
return embedding
|
|
|
|
|