community: minor changes sambanova integration (#21231)

- **Description:** fix: renamed the variable names used in the root validator,
which previously prevented passing credentials as named parameters when
instantiating the LLMs; also added SambaNova's Sambaverse and SambaStudio
LLMs to __init__.py so they can be imported from the module
pull/21347/head
Jorge Piedrahita Ortiz 3 weeks ago committed by GitHub
parent d9a61c0fa9
commit df1c10260c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -510,6 +510,18 @@ def _import_sagemaker_endpoint() -> Type[BaseLLM]:
return SagemakerEndpoint
def _import_sambaverse() -> Type[BaseLLM]:
    """Lazily import and return the ``Sambaverse`` LLM class.

    Deferring the import keeps ``langchain_community.llms`` cheap to load;
    the class is only imported when first requested via ``__getattr__`` or
    ``get_type_to_cls_dict``.
    """
    from langchain_community.llms.sambanova import Sambaverse
    return Sambaverse
def _import_sambastudio() -> Type[BaseLLM]:
    """Lazily import and return the ``SambaStudio`` LLM class.

    Deferring the import keeps ``langchain_community.llms`` cheap to load;
    the class is only imported when first requested via ``__getattr__`` or
    ``get_type_to_cls_dict``.
    """
    from langchain_community.llms.sambanova import SambaStudio
    return SambaStudio
def _import_self_hosted() -> Type[BaseLLM]:
from langchain_community.llms.self_hosted import SelfHostedPipeline
@ -793,6 +805,10 @@ def __getattr__(name: str) -> Any:
return _import_rwkv()
elif name == "SagemakerEndpoint":
return _import_sagemaker_endpoint()
elif name == "Sambaverse":
return _import_sambaverse()
elif name == "SambaStudio":
return _import_sambastudio()
elif name == "SelfHostedPipeline":
return _import_self_hosted()
elif name == "SelfHostedHuggingFaceLLM":
@ -922,6 +938,8 @@ __all__ = [
"RWKV",
"Replicate",
"SagemakerEndpoint",
"Sambaverse",
"SambaStudio",
"SelfHostedHuggingFaceLLM",
"SelfHostedPipeline",
"SparkLLM",
@ -1015,6 +1033,8 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"replicate": _import_replicate,
"rwkv": _import_rwkv,
"sagemaker_endpoint": _import_sagemaker_endpoint,
"sambaverse": _import_sambaverse,
"sambastudio": _import_sambastudio,
"self_hosted": _import_self_hosted,
"self_hosted_hugging_face": _import_self_hosted_hugging_face,
"stochasticai": _import_stochasticai,

@ -618,10 +618,10 @@ class SambaStudio(LLM):
from langchain_community.llms.sambanova import Sambaverse
SambaStudio(
base_url="your SambaStudio environment URL",
project_id=set with your SambaStudio project ID.,
endpoint_id=set with your SambaStudio endpoint ID.,
api_token= set with your SambaStudio endpoint API key.,
sambastudio_base_url="your SambaStudio environment URL",
sambastudio_project_id=set with your SambaStudio project ID.,
sambastudio_endpoint_id=set with your SambaStudio endpoint ID.,
sambastudio_api_key= set with your SambaStudio endpoint API key.,
streaming=false
model_kwargs={
"do_sample": False,
@ -634,16 +634,16 @@ class SambaStudio(LLM):
)
"""
base_url: str = ""
sambastudio_base_url: str = ""
"""Base url to use"""
project_id: str = ""
sambastudio_project_id: str = ""
"""Project id on sambastudio for model"""
endpoint_id: str = ""
sambastudio_endpoint_id: str = ""
"""endpoint id on sambastudio for model"""
api_key: str = ""
sambastudio_api_key: str = ""
"""sambastudio api key"""
model_kwargs: Optional[dict] = None
@ -674,16 +674,16 @@ class SambaStudio(LLM):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["base_url"] = get_from_dict_or_env(
values["sambastudio_base_url"] = get_from_dict_or_env(
values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL"
)
values["project_id"] = get_from_dict_or_env(
values["sambastudio_project_id"] = get_from_dict_or_env(
values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID"
)
values["endpoint_id"] = get_from_dict_or_env(
values["sambastudio_endpoint_id"] = get_from_dict_or_env(
values, "sambastudio_endpoint_id", "SAMBASTUDIO_ENDPOINT_ID"
)
values["api_key"] = get_from_dict_or_env(
values["sambastudio_api_key"] = get_from_dict_or_env(
values, "sambastudio_api_key", "SAMBASTUDIO_API_KEY"
)
return values
@ -729,7 +729,11 @@ class SambaStudio(LLM):
ValueError: If the prediction fails.
"""
response = sdk.nlp_predict(
self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params
self.sambastudio_project_id,
self.sambastudio_endpoint_id,
self.sambastudio_api_key,
prompt,
tuning_params,
)
if response["status_code"] != 200:
optional_detail = response["detail"]
@ -755,7 +759,7 @@ class SambaStudio(LLM):
Raises:
ValueError: If the prediction fails.
"""
ss_endpoint = SSEndpointHandler(self.base_url)
ss_endpoint = SSEndpointHandler(self.sambastudio_base_url)
tuning_params = self._get_tuning_params(stop)
return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)
@ -774,7 +778,11 @@ class SambaStudio(LLM):
An iterator of GenerationChunks.
"""
for chunk in sdk.nlp_predict_stream(
self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params
self.sambastudio_project_id,
self.sambastudio_endpoint_id,
self.sambastudio_api_key,
prompt,
tuning_params,
):
yield chunk
@ -794,7 +802,7 @@ class SambaStudio(LLM):
Returns:
The string generated by the model.
"""
ss_endpoint = SSEndpointHandler(self.base_url)
ss_endpoint = SSEndpointHandler(self.sambastudio_base_url)
tuning_params = self._get_tuning_params(stop)
try:
if self.streaming:

@ -77,6 +77,8 @@ EXPECT_ALL = [
"RWKV",
"Replicate",
"SagemakerEndpoint",
"Sambaverse",
"SambaStudio",
"SelfHostedHuggingFaceLLM",
"SelfHostedPipeline",
"StochasticAI",

Loading…
Cancel
Save