mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
d6ef5fe86a
**Description:** Adding chat completions to the Together AI package, which is our most popular API. Also staying backwards compatible with the old API so folks can continue to use the completions API as well. Also moved the embedding API to use the OpenAI library to standardize it further. **Twitter handle:** @nutlope - [x] **Add tests and docs**: If you're adding a new integration, please include - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
104 lines
3.3 KiB
Python
104 lines
3.3 KiB
Python
"""Wrapper around Together AI's Chat Completions API."""
|
|
|
|
import os
|
|
from typing import (
|
|
Any,
|
|
Dict,
|
|
List,
|
|
Optional,
|
|
)
|
|
|
|
import openai
|
|
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
|
from langchain_core.utils import (
|
|
convert_to_secret_str,
|
|
get_from_dict_or_env,
|
|
)
|
|
from langchain_openai.chat_models.base import BaseChatOpenAI
|
|
|
|
|
|
class ChatTogether(BaseChatOpenAI):
    """ChatTogether chat model.

    Thin specialization of ``BaseChatOpenAI`` that points the OpenAI SDK
    clients at Together AI's OpenAI-compatible API.

    To use, you should have the environment variable `TOGETHER_API_KEY`
    set with your API key or pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_together import ChatTogether

            model = ChatTogether()
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map secret constructor arguments to their environment variables."""
        return {"together_api_key": "TOGETHER_API_KEY"}

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "together"]

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        """Return extra attributes to include when serializing.

        The API base is only included when it is set to a truthy value.
        """
        attributes: Dict[str, Any] = {}

        if self.together_api_base:
            attributes["together_api_base"] = self.together_api_base

        return attributes

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "together-chat"

    model_name: str = Field(default="meta-llama/Llama-3-8b-chat-hf", alias="model")
    """Model name to use."""
    together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    """Automatically inferred from env var `TOGETHER_API_KEY` if not provided."""
    # The OpenAI SDK appends the endpoint path (``chat/completions``) to
    # ``base_url`` itself, so this must be the API *root*. Including the
    # endpoint path here would produce ``.../chat/completions/chat/completions``.
    together_api_base: Optional[str] = Field(
        default="https://api.together.ai/v1/", alias="base_url"
    )
    """Base URL of the Together AI API (the API root, not a specific endpoint)."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Checks the ``n`` / ``streaming`` combination, resolves the API key and
        API base, and builds the sync and async OpenAI chat-completions
        clients unless the caller already supplied them.

        Args:
            values: Field values collected for the model.

        Returns:
            The same mapping with ``together_api_key``, ``together_api_base``,
            ``client`` and ``async_client`` populated.

        Raises:
            ValueError: If ``n`` is less than 1, or if ``n`` is greater than 1
                while ``streaming`` is enabled.
        """
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")

        # Take the key from the explicit argument first, then from the
        # TOGETHER_API_KEY environment variable, and wrap it as a secret.
        values["together_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY")
        )
        # NOTE: the field has a non-empty default, so TOGETHER_API_BASE is only
        # consulted when the base was explicitly set to a falsy value.
        values["together_api_base"] = values["together_api_base"] or os.getenv(
            "TOGETHER_API_BASE"
        )

        # Parameters shared by the sync and async OpenAI clients.
        client_params = {
            "api_key": (
                values["together_api_key"].get_secret_value()
                if values["together_api_key"]
                else None
            ),
            "base_url": values["together_api_base"],
            "timeout": values["request_timeout"],
            "max_retries": values["max_retries"],
            "default_headers": values["default_headers"],
            "default_query": values["default_query"],
        }

        # Respect clients injected by the caller; only build what is missing.
        if not values.get("client"):
            sync_specific = {"http_client": values["http_client"]}
            values["client"] = openai.OpenAI(
                **client_params, **sync_specific
            ).chat.completions
        if not values.get("async_client"):
            async_specific = {"http_client": values["http_async_client"]}
            values["async_client"] = openai.AsyncOpenAI(
                **client_params, **async_specific
            ).chat.completions
        return values
|