@@ -618,10 +618,10 @@ class SambaStudio(LLM):
from langchain_community.llms.sambanova import SambaStudio
SambaStudio(
base_url="your SambaStudio environment URL",
project_id=set with your SambaStudio project ID.,
endpoint_id=set with your SambaStudio endpoint ID.,
api_token=set with your SambaStudio endpoint API key.,
sambastudio_base_url="your SambaStudio environment URL",
sambastudio_project_id=set with your SambaStudio project ID.,
sambastudio_endpoint_id=set with your SambaStudio endpoint ID.,
sambastudio_api_key=set with your SambaStudio endpoint API key.,
streaming=False
model_kwargs={
"do_sample": False,
@@ -634,16 +634,16 @@ class SambaStudio(LLM):
)
"""
base_url: str = ""
sambastudio_base_url: str = ""
"""Base url to use"""
project_id: str = ""
sambastudio_project_id: str = ""
"""Project id on sambastudio for model"""
endpoint_id: str = ""
sambastudio_endpoint_id: str = ""
"""endpoint id on sambastudio for model"""
api_key: str = ""
sambastudio_api_key: str = ""
"""sambastudio api key"""
model_kwargs: Optional[dict] = None
@@ -674,16 +674,16 @@ class SambaStudio(LLM):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["base_url"] = get_from_dict_or_env(
values["sambastudio_base_url"] = get_from_dict_or_env(
values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL"
)
values["project_id"] = get_from_dict_or_env(
values["sambastudio_project_id"] = get_from_dict_or_env(
values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID"
)
values["endpoint_id"] = get_from_dict_or_env(
values["sambastudio_endpoint_id"] = get_from_dict_or_env(
values, "sambastudio_endpoint_id", "SAMBASTUDIO_ENDPOINT_ID"
)
values["api_key"] = get_from_dict_or_env(
values["sambastudio_api_key"] = get_from_dict_or_env(
values, "sambastudio_api_key", "SAMBASTUDIO_API_KEY"
)
return values
@@ -729,7 +729,11 @@ class SambaStudio(LLM):
ValueError: If the prediction fails.
"""
response = sdk.nlp_predict(
self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params
self.sambastudio_project_id,
self.sambastudio_endpoint_id,
self.sambastudio_api_key,
prompt,
tuning_params,
)
if response["status_code"] != 200:
optional_detail = response["detail"]
@@ -755,7 +759,7 @@ class SambaStudio(LLM):
Raises:
ValueError: If the prediction fails.
"""
ss_endpoint = SSEndpointHandler(self.base_url)
ss_endpoint = SSEndpointHandler(self.sambastudio_base_url)
tuning_params = self._get_tuning_params(stop)
return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)
@@ -774,7 +778,11 @@ class SambaStudio(LLM):
An iterator of GenerationChunks.
"""
for chunk in sdk.nlp_predict_stream(
self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params
self.sambastudio_project_id,
self.sambastudio_endpoint_id,
self.sambastudio_api_key,
prompt,
tuning_params,
):
yield chunk
@@ -794,7 +802,7 @@ class SambaStudio(LLM):
Returns:
The string generated by the model.
"""
ss_endpoint = SSEndpointHandler(self.base_url)
ss_endpoint = SSEndpointHandler(self.sambastudio_base_url)
tuning_params = self._get_tuning_params(stop)
try:
if self.streaming: