together: fix chat model and embedding classes (#21353)

pull/21355/head langchain-together==0.1.1
Erick Friis 3 weeks ago committed by GitHub
parent d6ef5fe86a
commit bb81ae5c8c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -59,7 +59,7 @@ class ChatTogether(BaseChatOpenAI):
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env are `TOGETHER_API_KEY` if not provided."""
together_api_base: Optional[str] = Field(
default="https://api.together.ai/v1/chat/completions", alias="base_url"
default="https://api.together.ai/v1/", alias="base_url"
)
@root_validator()

@ -51,7 +51,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
model: str = "togethercomputer/m2-bert-80M-8k-retrieval"
"""Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
"""Embeddings model name to use.
Instead, use 'togethercomputer/m2-bert-80M-8k-retrieval' for example.
"""
dimensions: Optional[int] = None
@ -62,7 +62,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""API Key for Solar API."""
together_api_base: str = Field(
default="https://api.together.ai/v1/embeddings", alias="base_url"
default="https://api.together.ai/v1/", alias="base_url"
)
"""Endpoint URL to use."""
embedding_ctx_length: int = 4096
@ -166,12 +166,18 @@ class TogetherEmbeddings(BaseModel, Embeddings):
"default_query": values["default_query"],
}
if not values.get("client"):
sync_specific = {"http_client": values["http_client"]}
sync_specific = (
{"http_client": values["http_client"]} if values["http_client"] else {}
)
values["client"] = openai.OpenAI(
**client_params, **sync_specific
).embeddings
if not values.get("async_client"):
async_specific = {"http_client": values["http_async_client"]}
async_specific = (
{"http_client": values["http_async_client"]}
if values["http_async_client"]
else {}
)
values["async_client"] = openai.AsyncOpenAI(
**client_params, **async_specific
).embeddings
@ -179,8 +185,6 @@ class TogetherEmbeddings(BaseModel, Embeddings):
@property
def _invocation_params(self) -> Dict[str, Any]:
self.model = self.model.replace("-query", "").replace("-passage", "")
params: Dict = {"model": self.model, **self.model_kwargs}
if self.dimensions is not None:
params["dimensions"] = self.dimensions
@ -197,7 +201,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
"""
embeddings = []
params = self._invocation_params
params["model"] = params["model"] + "-passage"
params["model"] = params["model"]
for text in texts:
response = self.client.create(input=text, **params)
@ -217,7 +221,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
Embedding for the text.
"""
params = self._invocation_params
params["model"] = params["model"] + "-query"
params["model"] = params["model"]
response = self.client.create(input=text, **params)
@ -236,7 +240,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
"""
embeddings = []
params = self._invocation_params
params["model"] = params["model"] + "-passage"
params["model"] = params["model"]
for text in texts:
response = await self.async_client.create(input=text, **params)
@ -256,7 +260,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
Embedding for the text.
"""
params = self._invocation_params
params["model"] = params["model"] + "-query"
params["model"] = params["model"]
response = await self.async_client.create(input=text, **params)

@ -17,5 +17,5 @@ class TestTogethertandard(ChatModelIntegrationTests):
@pytest.fixture
def chat_model_params(self) -> dict:
return {
"model": "meta-llama/Llama-3-8b-chat-hf",
"model": "mistralai/Mistral-7B-Instruct-v0.1",
}

Loading…
Cancel
Save