@ -1,5 +1,6 @@
""" Wrapper around Together AI ' s Completion API. """
import logging
import warnings
from typing import Any , Dict , List , Optional
import requests
@ -34,13 +35,14 @@ class Together(LLM):
model = Together ( model_name = " mistralai/Mixtral-8x7B-Instruct-v0.1 " )
"""
base_url: str = "https://api.together.xyz/v1/completions"
"""Base completions API URL."""
together_api_key: SecretStr
"""Together AI API key. Get it here: https://api.together.xyz/settings/api-keys"""
model: str
"""Model name. Available models listed here:
https://docs.together.ai/docs/inference-models
Base Models: https://docs.together.ai/docs/inference-models#language-models
Chat Models: https://docs.together.ai/docs/inference-models#chat-models
"""
temperature: Optional[float] = None
"""Model temperature."""
@ -82,13 +84,28 @@ class Together(LLM):
)
return values
@root_validator ( )
def validate_max_tokens(cls, values: Dict) -> Dict:
    """Ensure ``max_tokens`` is set, filling in a default when it is missing.

    The v1 completions endpoint has ``max_tokens`` as a required parameter,
    so when the caller did not provide one a default is applied and a
    warning is emitted.

    Args:
        values: Field values being validated.

    Returns:
        The same ``values`` dict, with ``max_tokens`` populated.
    """
    # Single source of truth so the warning text and the applied value
    # cannot drift apart.
    default_max_tokens = 200
    if values.get("max_tokens") is None:
        warnings.warn(
            "The completions endpoint has 'max_tokens' as a required argument. "
            f"The default value is being set to {default_max_tokens}. "
            "Consider setting this value when initializing the LLM."
        )
        values["max_tokens"] = default_max_tokens
    return values
@property
def _llm_type ( self ) - > str :
""" Return the constant string that identifies this LLM type. """
return " together "
def _format_output ( self , output : dict ) - > str :
return output [ " output " ] [ " choices " ] [ 0 ] [ " text " ]
return output [ " choices" ] [ 0 ] [ " text " ]
@staticmethod
def get_user_agent ( ) - > str :
@ -148,9 +165,6 @@ class Together(LLM):
)
data = response . json ( )
if data . get ( " status " ) != " finished " :
err_msg = data . get ( " error " , " Undefined Error " )
raise Exception ( err_msg )
output = self . _format_output ( data )
@ -203,9 +217,5 @@ class Together(LLM):
response_json = await response . json ( )
if response_json . get ( " status " ) != " finished " :
err_msg = response_json . get ( " error " , " Undefined Error " )
raise Exception ( err_msg )
output = self . _format_output ( response_json )
return output