|
|
|
@ -51,7 +51,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
client: Any = Field(default=None, exclude=True) #: :meta private:
|
|
|
|
|
async_client: Any = Field(default=None, exclude=True) #: :meta private:
|
|
|
|
|
model: str = "togethercomputer/m2-bert-80M-8k-retrieval"
|
|
|
|
|
"""Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
|
|
|
|
|
"""Embeddings model name to use.
|
|
|
|
|
Instead, use 'togethercomputer/m2-bert-80M-8k-retrieval' for example.
|
|
|
|
|
"""
|
|
|
|
|
dimensions: Optional[int] = None
|
|
|
|
@ -62,7 +62,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
|
|
|
|
"""API Key for Solar API."""
|
|
|
|
|
together_api_base: str = Field(
|
|
|
|
|
default="https://api.together.ai/v1/embeddings", alias="base_url"
|
|
|
|
|
default="https://api.together.ai/v1/", alias="base_url"
|
|
|
|
|
)
|
|
|
|
|
"""Endpoint URL to use."""
|
|
|
|
|
embedding_ctx_length: int = 4096
|
|
|
|
@ -166,12 +166,18 @@ class TogetherEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
"default_query": values["default_query"],
|
|
|
|
|
}
|
|
|
|
|
if not values.get("client"):
|
|
|
|
|
sync_specific = {"http_client": values["http_client"]}
|
|
|
|
|
sync_specific = (
|
|
|
|
|
{"http_client": values["http_client"]} if values["http_client"] else {}
|
|
|
|
|
)
|
|
|
|
|
values["client"] = openai.OpenAI(
|
|
|
|
|
**client_params, **sync_specific
|
|
|
|
|
).embeddings
|
|
|
|
|
if not values.get("async_client"):
|
|
|
|
|
async_specific = {"http_client": values["http_async_client"]}
|
|
|
|
|
async_specific = (
|
|
|
|
|
{"http_client": values["http_async_client"]}
|
|
|
|
|
if values["http_async_client"]
|
|
|
|
|
else {}
|
|
|
|
|
)
|
|
|
|
|
values["async_client"] = openai.AsyncOpenAI(
|
|
|
|
|
**client_params, **async_specific
|
|
|
|
|
).embeddings
|
|
|
|
@ -179,8 +185,6 @@ class TogetherEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _invocation_params(self) -> Dict[str, Any]:
|
|
|
|
|
self.model = self.model.replace("-query", "").replace("-passage", "")
|
|
|
|
|
|
|
|
|
|
params: Dict = {"model": self.model, **self.model_kwargs}
|
|
|
|
|
if self.dimensions is not None:
|
|
|
|
|
params["dimensions"] = self.dimensions
|
|
|
|
@ -197,7 +201,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
"""
|
|
|
|
|
embeddings = []
|
|
|
|
|
params = self._invocation_params
|
|
|
|
|
params["model"] = params["model"] + "-passage"
|
|
|
|
|
params["model"] = params["model"]
|
|
|
|
|
|
|
|
|
|
for text in texts:
|
|
|
|
|
response = self.client.create(input=text, **params)
|
|
|
|
@ -217,7 +221,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
Embedding for the text.
|
|
|
|
|
"""
|
|
|
|
|
params = self._invocation_params
|
|
|
|
|
params["model"] = params["model"] + "-query"
|
|
|
|
|
params["model"] = params["model"]
|
|
|
|
|
|
|
|
|
|
response = self.client.create(input=text, **params)
|
|
|
|
|
|
|
|
|
@ -236,7 +240,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
"""
|
|
|
|
|
embeddings = []
|
|
|
|
|
params = self._invocation_params
|
|
|
|
|
params["model"] = params["model"] + "-passage"
|
|
|
|
|
params["model"] = params["model"]
|
|
|
|
|
|
|
|
|
|
for text in texts:
|
|
|
|
|
response = await self.async_client.create(input=text, **params)
|
|
|
|
@ -256,7 +260,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
Embedding for the text.
|
|
|
|
|
"""
|
|
|
|
|
params = self._invocation_params
|
|
|
|
|
params["model"] = params["model"] + "-query"
|
|
|
|
|
params["model"] = params["model"]
|
|
|
|
|
|
|
|
|
|
response = await self.async_client.create(input=text, **params)
|
|
|
|
|
|
|
|
|
|