mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
e6207ad4f3
- **Description:** Added SambaStudio generic endpoint compatibility, improved error descriptions and handling, and added streaming examples.
1011 lines
35 KiB
Python
1011 lines
35 KiB
Python
import json
|
|
from typing import Any, Dict, Generator, Iterator, List, Optional, Union
|
|
|
|
import requests
|
|
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
|
from langchain_core.language_models.llms import LLM
|
|
from langchain_core.outputs import GenerationChunk
|
|
from langchain_core.pydantic_v1 import Extra, root_validator
|
|
from langchain_core.utils import get_from_dict_or_env
|
|
|
|
|
|
class SVEndpointHandler:
    """
    SambaNova Systems Interface for Sambaverse endpoint.

    Wraps a ``requests.Session`` and knows how to build the request payload
    for, and parse the (line-delimited JSON) responses of, the Sambaverse
    predict API.

    :param str host_url: Base URL of the DaaS API service
    """

    API_BASE_PATH = "/api/predict"

    def __init__(self, host_url: str):
        """
        Initialize the SVEndpointHandler.

        :param str host_url: Base URL of the DaaS API service
        """
        self.host_url = host_url
        self.http_session = requests.Session()

    @staticmethod
    def _process_response(response: requests.Response) -> Dict:
        """
        Processes the API response and returns the resulting dict.

        The response body is treated as newline-delimited JSON. Normally the
        last line is parsed and returned. When the call succeeded (HTTP 200)
        but the last line carries an ``error`` entry, the ``stream_token`` of
        every preceding line is concatenated and stored under
        ``result.responses[0].completion`` of the second-to-last line, which
        is then returned instead.

        All resulting dicts, regardless of success or failure, will contain the
        `status_code` key (taken from the HTTP response when the body does not
        provide one).

        If the body could not be processed, the resulting dict will contain the
        key `detail` with the error message.

        :param requests.Response response: the response object to process
        :return: the response dict
        :rtype: dict
        """
        result: Dict[str, Any] = {}
        try:
            lines_result = response.text.strip().split("\n")
            last_line = json.loads(lines_result[-1])
            if (
                response.status_code == 200
                and last_line.get("error")
                # A lone error line has no preceding lines to aggregate;
                # return it as-is instead of failing on lines_result[-2].
                and len(lines_result) > 1
            ):
                completion = "".join(
                    json.loads(line)["result"]["responses"][0]["stream_token"]
                    for line in lines_result[:-1]
                )
                result = json.loads(lines_result[-2])
                result["result"]["responses"][0]["completion"] = completion
            else:
                result = last_line
        except Exception as e:
            # Deliberately broad: any parsing problem is reported to the
            # caller through the `detail` key rather than raised.
            result["detail"] = str(e)
        if "status_code" not in result:
            result["status_code"] = response.status_code
        return result

    @staticmethod
    def _process_streaming_response(
        response: requests.Response,
    ) -> Generator[Dict, None, None]:
        """
        Process the streaming response.

        Each non-empty line of the response is parsed as JSON and yielded with
        a `status_code` key added when missing. A line that carries an
        ``error`` entry on an otherwise successful (200) stream is yielded
        with an empty ``stream_token`` and terminates the stream.

        :param requests.Response response: the streamed response to process
        :raises RuntimeError: if a line cannot be processed
        """
        try:
            for line in response.iter_lines():
                if not line:
                    # iter_lines() emits empty keep-alive lines; they are not
                    # JSON documents and must be skipped.
                    continue
                chunk = json.loads(line)
                if "status_code" not in chunk:
                    chunk["status_code"] = response.status_code
                if chunk["status_code"] == 200 and chunk.get("error"):
                    chunk["result"] = {"responses": [{"stream_token": ""}]}
                    # A `return chunk` inside a generator silently drops the
                    # value; yield it, then stop the stream.
                    yield chunk
                    return
                yield chunk
        except Exception as e:
            raise RuntimeError(f"Error processing streaming response: {e}") from e
        finally:
            # Release the connection even if the consumer stops early.
            response.close()

    def _get_full_url(self) -> str:
        """
        Return the full API URL of the predict endpoint.

        :returns: the host URL followed by ``API_BASE_PATH``
        :rtype: str
        """
        # Tolerate a trailing slash on the host to avoid ".../​/api/predict".
        return f"{self.host_url.rstrip('/')}{self.API_BASE_PATH}"

    @staticmethod
    def _build_request_data(
        input: Union[List[str], str], params: Optional[str]
    ) -> Dict[str, Any]:
        """
        Build the JSON body shared by the blocking and streaming calls.

        :param input: prompt placed in the single user message
        :param str params: JSON-encoded tuning parameters (may be empty)
        :returns: the request body
        :rtype: dict
        """
        conversation = {
            "conversation_id": "sambaverse-conversation-id",
            "messages": [
                {
                    "message_id": 0,
                    "role": "user",
                    "content": input,
                }
            ],
        }
        # The API expects the conversation itself as a JSON *string*.
        data: Dict[str, Any] = {"instance": json.dumps(conversation)}
        if params:
            data["params"] = json.loads(params)
        return data

    def nlp_predict(
        self,
        key: str,
        sambaverse_model_name: Optional[str],
        input: Union[List[str], str],
        params: Optional[str] = "",
        stream: bool = False,
    ) -> Dict:
        """
        NLP predict using inline input string.

        :param str key: API Key
        :param str sambaverse_model_name: model name sent in the
            ``modelName`` header
        :param input: Input prompt
        :param str params: Input params string (JSON encoded)
        :param bool stream: unused; kept for interface compatibility
        :returns: Prediction results
        :rtype: dict
        """
        response = self.http_session.post(
            self._get_full_url(),
            headers={
                "key": key,
                "Content-Type": "application/json",
                "modelName": sambaverse_model_name,
            },
            json=self._build_request_data(input, params),
        )
        return SVEndpointHandler._process_response(response)

    def nlp_predict_stream(
        self,
        key: str,
        sambaverse_model_name: Optional[str],
        input: Union[List[str], str],
        params: Optional[str] = "",
    ) -> Iterator[Dict]:
        """
        NLP predict using inline input string, streaming the result.

        :param str key: API Key
        :param str sambaverse_model_name: model name sent in the
            ``modelName`` header
        :param input: Input prompt
        :param str params: Input params string (JSON encoded)
        :returns: Iterator over the parsed response chunks
        :rtype: Iterator[dict]
        """
        # Streaming output
        response = self.http_session.post(
            self._get_full_url(),
            headers={
                "key": key,
                "Content-Type": "application/json",
                "modelName": sambaverse_model_name,
            },
            json=self._build_request_data(input, params),
            stream=True,
        )
        yield from SVEndpointHandler._process_streaming_response(response)
|
|
|
|
|
|
class Sambaverse(LLM):
    """
    Sambaverse large language models.

    To use, you should have the environment variable ``SAMBAVERSE_API_KEY``
    set with your API key.

    get one in https://sambaverse.sambanova.ai
    read extra documentation in https://docs.sambanova.ai/sambaverse/latest/index.html


    Example:
    .. code-block:: python

        from langchain_community.llms.sambanova  import Sambaverse
        Sambaverse(
            sambaverse_url="https://sambaverse.sambanova.ai",
            sambaverse_api_key="your-sambaverse-api-key",
            sambaverse_model_name="Meta/llama-2-7b-chat-hf",
            streaming=False,
            model_kwargs={
                "select_expert": "llama-2-7b-chat-hf",
                "do_sample": False,
                "max_tokens_to_generate": 100,
                "temperature": 0.7,
                "top_p": 1.0,
                "repetition_penalty": 1.0,
                "top_k": 50,
            },
        )
    """

    sambaverse_url: str = ""
    """Sambaverse url to use"""

    sambaverse_api_key: str = ""
    """sambaverse api key"""

    sambaverse_model_name: Optional[str] = None
    """sambaverse expert model to use"""

    model_kwargs: Optional[dict] = None
    """Key word arguments to pass to the model."""

    streaming: Optional[bool] = False
    """Streaming flag to get streamed response."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return True: this class declares itself LangChain-serializable."""
        return True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Fill url, api key and model name from the arguments or environment.

        ``sambaverse_url`` falls back to ``https://sambaverse.sambanova.ai``;
        the api key and model name have no default.
        """
        values["sambaverse_url"] = get_from_dict_or_env(
            values,
            "sambaverse_url",
            "SAMBAVERSE_URL",
            default="https://sambaverse.sambanova.ai",
        )
        values["sambaverse_api_key"] = get_from_dict_or_env(
            values, "sambaverse_api_key", "SAMBAVERSE_API_KEY"
        )
        values["sambaverse_model_name"] = get_from_dict_or_env(
            values, "sambaverse_model_name", "SAMBAVERSE_MODEL_NAME"
        )
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {"model_kwargs": self.model_kwargs}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "Sambaverse LLM"

    def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
        """
        Get the tuning parameters to use when calling the LLM.

        Every entry of ``model_kwargs`` is encoded as
        ``{"type": <python type name>, "value": <str(value)>}``. When
        ``model_kwargs`` carries no (truthy) ``stop_sequences``, one is built
        from ``stop`` as a comma-separated list of double-quoted strings.

        Args:
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.

        Returns:
            The tuning parameters as a JSON string.
        """
        # Work on a shallow copy: writing into ``self.model_kwargs`` would leak
        # a ``stop_sequences`` key into the user's dict and change
        # ``_identifying_params`` from one call to the next.
        _model_kwargs = dict(self.model_kwargs or {})
        _kwarg_stop_sequences = _model_kwargs.get("stop_sequences", [])
        _stop_sequences = stop or _kwarg_stop_sequences or []
        if not _kwarg_stop_sequences:
            _model_kwargs["stop_sequences"] = ",".join(
                f'"{x}"' for x in _stop_sequences
            )
        tuning_params_dict = {
            k: {"type": type(v).__name__, "value": str(v)}
            for k, v in _model_kwargs.items()
        }
        return json.dumps(tuning_params_dict)

    @staticmethod
    def _format_error(payload: Dict[str, Any]) -> str:
        """Build the failure message for a non-200 response or chunk.

        Uses the structured ``error`` entry (message/details/code) when the
        payload has one, otherwise dumps the whole payload.
        """
        error = payload.get("error")
        if error and isinstance(error, dict):
            optional_code = error.get("code")
            optional_details = error.get("details")
            optional_message = error.get("message")
            return (
                f"Sambanova /complete call failed with status code "
                f"{payload['status_code']}.\n"
                f"Message: {optional_message}\n"
                f"Details: {optional_details}\n"
                f"Code: {optional_code}\n"
            )
        return (
            f"Sambanova /complete call failed with status code "
            f"{payload['status_code']}."
            f"{payload}."
        )

    def _handle_nlp_predict(
        self,
        sdk: SVEndpointHandler,
        prompt: Union[List[str], str],
        tuning_params: str,
    ) -> str:
        """
        Perform an NLP prediction using the Sambaverse endpoint handler.

        Args:
            sdk: The SVEndpointHandler to use for the prediction.
            prompt: The prompt to use for the prediction.
            tuning_params: The tuning parameters to use for the prediction.

        Returns:
            The prediction result.

        Raises:
            RuntimeError: If the call does not return status code 200, or the
                response does not contain a completion.
        """
        response = sdk.nlp_predict(
            self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
        )
        if response["status_code"] != 200:
            raise RuntimeError(self._format_error(response))
        try:
            return response["result"]["responses"][0]["completion"]
        except (KeyError, IndexError, TypeError) as e:
            # e.g. the body could not be parsed and only `detail` is present;
            # surface the whole payload instead of a bare KeyError.
            raise RuntimeError(
                f"Sambanova /complete call returned an unexpected response: "
                f"{response}"
            ) from e

    def _handle_completion_requests(
        self, prompt: Union[List[str], str], stop: Optional[List[str]]
    ) -> str:
        """
        Perform a prediction using the Sambaverse endpoint handler.

        Args:
            prompt: The prompt to use for the prediction.
            stop: stop sequences.

        Returns:
            The prediction result.

        Raises:
            RuntimeError: If the prediction fails.
        """
        ss_endpoint = SVEndpointHandler(self.sambaverse_url)
        tuning_params = self._get_tuning_params(stop)
        return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)

    def _handle_nlp_predict_stream(
        self, sdk: SVEndpointHandler, prompt: Union[List[str], str], tuning_params: str
    ) -> Iterator[GenerationChunk]:
        """
        Perform a streaming request to the LLM.

        Args:
            sdk: The SVEndpointHandler to use for the prediction.
            prompt: The prompt to use for the prediction.
            tuning_params: The tuning parameters to use for the prediction.

        Returns:
            An iterator of GenerationChunks.

        Raises:
            ValueError: If a chunk reports a non-200 status with an ``error``
                entry.
            RuntimeError: If a chunk reports a non-200 status without one.
        """
        for chunk in sdk.nlp_predict_stream(
            self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
        ):
            if chunk["status_code"] != 200:
                message = self._format_error(chunk)
                if chunk.get("error"):
                    raise ValueError(message)
                raise RuntimeError(message)
            text = chunk["result"]["responses"][0]["stream_token"]
            yield GenerationChunk(text=text)

    def _stream(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream the Sambaverse's LLM on the given prompt.

        When ``streaming`` is disabled the non-streaming endpoint is called
        and the whole completion is yielded as a single chunk, so that the
        iterator is never empty.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Callback manager for the run.
            **kwargs: Additional keyword arguments. Currently not forwarded
                to the API call.

        Returns:
            An iterator of GenerationChunks.

        Raises:
            ValueError: If the inference endpoint call fails.
        """
        ss_endpoint = SVEndpointHandler(self.sambaverse_url)
        tuning_params = self._get_tuning_params(stop)
        try:
            if self.streaming:
                chunks: Iterator[GenerationChunk] = self._handle_nlp_predict_stream(
                    ss_endpoint, prompt, tuning_params
                )
            else:
                # Yielding nothing here makes callers that expect at least
                # one chunk fail; fall back to one chunk holding the full
                # completion.
                chunks = iter(
                    [
                        GenerationChunk(
                            text=self._handle_nlp_predict(
                                ss_endpoint, prompt, tuning_params
                            )
                        )
                    ]
                )
            for chunk in chunks:
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)
                yield chunk
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f"Error raised by the inference endpoint: {e}") from e

    def _handle_stream_request(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]],
        run_manager: Optional[CallbackManagerForLLMRun],
        kwargs: Dict[str, Any],
    ) -> str:
        """
        Perform a streaming request to the LLM.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
            run_manager: Callback manager for the run.
            kwargs: Additional keyword arguments, forwarded to ``_stream``.

        Returns:
            The model output as a string (all streamed chunks concatenated).
        """
        return "".join(
            chunk.text
            for chunk in self._stream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            )
        )

    def _call(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run the LLM on the given input.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
            run_manager: Callback manager for the run.
            **kwargs: Additional keyword arguments. Only used on the streaming
                path, where they are forwarded to ``_stream``.

        Returns:
            The model output as a string.

        Raises:
            ValueError: If the inference endpoint call fails.
        """
        try:
            if self.streaming:
                return self._handle_stream_request(prompt, stop, run_manager, kwargs)
            return self._handle_completion_requests(prompt, stop)
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f"Error raised by the inference endpoint: {e}") from e
|
|
|
|
|
|
class SSEndpointHandler:
    """
    SambaNova Systems Interface for SambaStudio model endpoints.

    Supports two API flavours, selected by a substring of ``api_base_uri``:
    ``"nlp"`` (server-sent events when streaming) and ``"generic"``
    (line-delimited JSON when streaming).

    :param str host_url: Base URL of the DaaS API service
    :param str api_base_uri: Base URI of the DaaS API service
    """

    def __init__(self, host_url: str, api_base_uri: str):
        """
        Initialize the SSEndpointHandler.

        :param str host_url: Base URL of the DaaS API service
        :param str api_base_uri: Base URI of the DaaS API service
        """
        self.host_url = host_url
        self.api_base_uri = api_base_uri
        self.http_session = requests.Session()

    def _process_response(self, response: requests.Response) -> Dict:
        """
        Processes the API response and returns the resulting dict.

        The body is parsed as JSON and returned as-is. All resulting dicts,
        regardless of success or failure, will contain the `status_code` key
        (taken from the HTTP response when the body does not provide one).

        If the body could not be parsed, the resulting dict will contain the
        key `detail` with the error message.

        :param requests.Response response: the response object to process
        :return: the response dict
        :rtype: dict
        """
        result: Dict[str, Any] = {}
        try:
            result = response.json()
        except Exception as e:
            # Deliberately broad: parsing problems are reported via `detail`.
            result["detail"] = str(e)
        if "status_code" not in result:
            result["status_code"] = response.status_code
        return result

    def _process_streaming_response(
        self,
        response: requests.Response,
    ) -> Generator[Dict, None, None]:
        """
        Process the streaming response.

        For ``nlp`` endpoints every server-sent event is yielded as a dict
        with the keys ``event``, ``data`` and ``status_code``. For ``generic``
        endpoints every non-empty line is parsed as JSON and yielded with a
        ``status_code`` key added when missing; a chunk that carries an
        ``error`` entry on a 200 stream gets an empty ``stream_token``.

        :param requests.Response response: the streamed response to process
        :raises ImportError: if ``sseclient`` is needed but not installed
        :raises RuntimeError: if a generic-endpoint line cannot be processed
        :raises ValueError: if the endpoint URI flavour is not supported
        """
        if "nlp" in self.api_base_uri:
            try:
                import sseclient
            except ImportError as e:
                raise ImportError(
                    "could not import sseclient library. "
                    "Please install it with `pip install sseclient-py`."
                ) from e
            client = sseclient.SSEClient(response)
            try:
                for event in client.events():
                    yield {
                        "event": event.event,
                        "data": event.data,
                        "status_code": response.status_code,
                    }
            finally:
                # Always release the connection, not only after an
                # "error_event" (also covers a consumer that stops early).
                client.close()
        elif "generic" in self.api_base_uri:
            try:
                for line in response.iter_lines():
                    if not line:
                        # iter_lines() emits empty keep-alive lines that are
                        # not JSON documents.
                        continue
                    chunk = json.loads(line)
                    if "status_code" not in chunk:
                        chunk["status_code"] = response.status_code
                    if chunk["status_code"] == 200 and chunk.get("error"):
                        chunk["result"] = {"responses": [{"stream_token": ""}]}
                    yield chunk
            except Exception as e:
                raise RuntimeError(
                    f"Error processing streaming response: {e}"
                ) from e
            finally:
                response.close()
        else:
            raise ValueError(
                f"handling of endpoint uri: {self.api_base_uri} not implemented"
            )

    def _get_full_url(self, path: str) -> str:
        """
        Return the full API URL for a given path.

        :param str path: the sub-path
        :returns: the full API URL for the sub-path
        :rtype: str
        """
        # Strip stray slashes so "https://host/" + "/api/..." does not end up
        # with doubled separators.
        host = self.host_url.rstrip("/")
        base_uri = self.api_base_uri.strip("/")
        return f"{host}/{base_uri}/{path}"

    def _build_request_data(
        self,
        input: Union[List[str], str],
        params: Optional[str],
        streaming: bool,
    ) -> Dict[str, Any]:
        """
        Build the JSON body for the configured endpoint flavour.

        ``nlp`` endpoints take a list under ``inputs``. ``generic`` endpoints
        take a list under ``instances`` for blocking calls and a single
        string under ``instance`` for streaming calls.

        :param input: prompt or list of prompts
        :param str params: JSON-encoded tuning parameters (may be empty)
        :param bool streaming: whether the body is for the streaming route
        :returns: the request body
        :rtype: dict
        :raises ValueError: if the endpoint URI flavour is not supported
        """
        data: Dict[str, Any]
        if "nlp" in self.api_base_uri:
            data = {"inputs": [input] if isinstance(input, str) else input}
        elif "generic" in self.api_base_uri:
            if streaming:
                # The streaming route accepts exactly one prompt.
                data = {"instance": input[0] if isinstance(input, list) else input}
            else:
                data = {"instances": [input] if isinstance(input, str) else input}
        else:
            raise ValueError(
                f"handling of endpoint uri: {self.api_base_uri} not implemented"
            )
        if params:
            data["params"] = json.loads(params)
        return data

    def nlp_predict(
        self,
        project: str,
        endpoint: str,
        key: str,
        input: Union[List[str], str],
        params: Optional[str] = "",
        stream: bool = False,
    ) -> Dict:
        """
        NLP predict using inline input string.

        :param str project: Project ID in which the endpoint exists
        :param str endpoint: Endpoint ID
        :param str key: API Key
        :param input: Input prompt or list of prompts
        :param str params: Input params string (JSON encoded)
        :param bool stream: unused; kept for interface compatibility
        :returns: Prediction results
        :rtype: dict
        :raises ValueError: if the endpoint URI flavour is not supported
        """
        data = self._build_request_data(input, params, streaming=False)
        response = self.http_session.post(
            self._get_full_url(f"{project}/{endpoint}"),
            headers={"key": key},
            json=data,
        )
        return self._process_response(response)

    def nlp_predict_stream(
        self,
        project: str,
        endpoint: str,
        key: str,
        input: Union[List[str], str],
        params: Optional[str] = "",
    ) -> Iterator[Dict]:
        """
        NLP predict using inline input string, streaming the result.

        :param str project: Project ID in which the endpoint exists
        :param str endpoint: Endpoint ID
        :param str key: API Key
        :param input: Input prompt or list of prompts
        :param str params: Input params string (JSON encoded)
        :returns: Iterator over the response chunks
        :rtype: Iterator[dict]
        :raises ValueError: if the endpoint URI flavour is not supported
        """
        data = self._build_request_data(input, params, streaming=True)
        # Streaming output
        response = self.http_session.post(
            self._get_full_url(f"stream/{project}/{endpoint}"),
            headers={"key": key},
            json=data,
            stream=True,
        )
        yield from self._process_streaming_response(response)
|
|
|
|
|
|
class SambaStudio(LLM):
    """
    SambaStudio large language models.

    To use, you should have the environment variables
    ``SAMBASTUDIO_BASE_URL`` set with your SambaStudio environment URL.
    ``SAMBASTUDIO_BASE_URI`` set with your SambaStudio api base URI.
    ``SAMBASTUDIO_PROJECT_ID`` set with your SambaStudio project ID.
    ``SAMBASTUDIO_ENDPOINT_ID`` set with your SambaStudio endpoint ID.
    ``SAMBASTUDIO_API_KEY``  set with your SambaStudio endpoint API key.

    https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite

    read extra documentation in https://docs.sambanova.ai/sambastudio/latest/index.html

    Example:
    .. code-block:: python

        from langchain_community.llms.sambanova  import SambaStudio
        SambaStudio(
            sambastudio_base_url="your-SambaStudio-environment-URL",
            sambastudio_base_uri="your-SambaStudio-base-URI",
            sambastudio_project_id="your-SambaStudio-project-ID",
            sambastudio_endpoint_id="your-SambaStudio-endpoint-ID",
            sambastudio_api_key="your-SambaStudio-endpoint-API-key",
            streaming=False,
            model_kwargs={
                "do_sample": False,
                "max_tokens_to_generate": 1000,
                "temperature": 0.7,
                "top_p": 1.0,
                "repetition_penalty": 1,
                "top_k": 50,
            },
        )
    """

    sambastudio_base_url: str = ""
    """Base url to use"""

    sambastudio_base_uri: str = ""
    """endpoint base uri"""

    sambastudio_project_id: str = ""
    """Project id on sambastudio for model"""

    sambastudio_endpoint_id: str = ""
    """endpoint id on sambastudio for model"""

    sambastudio_api_key: str = ""
    """sambastudio api key"""

    model_kwargs: Optional[dict] = None
    """Key word arguments to pass to the model."""

    streaming: Optional[bool] = False
    """Streaming flag to get streamed response."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return True: this class declares itself LangChain-serializable."""
        return True

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {"model_kwargs": self.model_kwargs}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "Sambastudio LLM"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Fill the connection settings from the arguments or environment.

        ``sambastudio_base_uri`` falls back to ``api/predict/nlp``; the other
        settings have no default.
        """
        values["sambastudio_base_url"] = get_from_dict_or_env(
            values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL"
        )
        values["sambastudio_base_uri"] = get_from_dict_or_env(
            values,
            "sambastudio_base_uri",
            "SAMBASTUDIO_BASE_URI",
            default="api/predict/nlp",
        )
        values["sambastudio_project_id"] = get_from_dict_or_env(
            values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID"
        )
        values["sambastudio_endpoint_id"] = get_from_dict_or_env(
            values, "sambastudio_endpoint_id", "SAMBASTUDIO_ENDPOINT_ID"
        )
        values["sambastudio_api_key"] = get_from_dict_or_env(
            values, "sambastudio_api_key", "SAMBASTUDIO_API_KEY"
        )
        return values

    def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
        """
        Get the tuning parameters to use when calling the LLM.

        Every entry of ``model_kwargs`` is encoded as
        ``{"type": <python type name>, "value": <str(value)>}``.

        Args:
            stop: Stop words to use when generating. Currently not applied:
                only the contents of ``model_kwargs`` are encoded.

        Returns:
            The tuning parameters as a JSON string.
        """
        _model_kwargs = self.model_kwargs or {}
        tuning_params_dict = {
            k: {"type": type(v).__name__, "value": str(v)}
            for k, v in _model_kwargs.items()
        }
        return json.dumps(tuning_params_dict)

    def _handle_nlp_predict(
        self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
    ) -> str:
        """
        Perform an NLP prediction using the SambaStudio endpoint handler.

        Args:
            sdk: The SSEndpointHandler to use for the prediction.
            prompt: The prompt to use for the prediction.
            tuning_params: The tuning parameters to use for the prediction.

        Returns:
            The prediction result.

        Raises:
            RuntimeError: If the call does not return status code 200, or the
                response does not contain a completion.
            ValueError: If the endpoint URI flavour is not supported.
        """
        response = sdk.nlp_predict(
            self.sambastudio_project_id,
            self.sambastudio_endpoint_id,
            self.sambastudio_api_key,
            prompt,
            tuning_params,
        )
        if response["status_code"] != 200:
            optional_detail = response.get("detail")
            if optional_detail:
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code "
                    f"{response['status_code']}.\n Details: {optional_detail}"
                )
            raise RuntimeError(
                f"Sambanova /complete call failed with status code "
                f"{response['status_code']}.\n response {response}"
            )
        # The two API flavours return the completions under different keys.
        if "nlp" in self.sambastudio_base_uri:
            results_key = "data"
        elif "generic" in self.sambastudio_base_uri:
            results_key = "predictions"
        else:
            raise ValueError(
                f"handling of endpoint uri: {self.sambastudio_base_uri} not implemented"
            )
        try:
            return response[results_key][0]["completion"]
        except (KeyError, IndexError, TypeError) as e:
            # e.g. the body could not be parsed and only `detail` is present;
            # surface the whole payload instead of a bare KeyError.
            raise RuntimeError(
                f"Sambanova /complete call returned an unexpected response: "
                f"{response}"
            ) from e

    def _handle_completion_requests(
        self, prompt: Union[List[str], str], stop: Optional[List[str]]
    ) -> str:
        """
        Perform a prediction using the SambaStudio endpoint handler.

        Args:
            prompt: The prompt to use for the prediction.
            stop: stop sequences.

        Returns:
            The prediction result.

        Raises:
            RuntimeError: If the prediction fails.
        """
        ss_endpoint = SSEndpointHandler(
            self.sambastudio_base_url, self.sambastudio_base_uri
        )
        tuning_params = self._get_tuning_params(stop)
        return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)

    def _handle_nlp_predict_stream(
        self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
    ) -> Iterator[GenerationChunk]:
        """
        Perform a streaming request to the LLM.

        Args:
            sdk: The SSEndpointHandler to use for the prediction.
            prompt: The prompt to use for the prediction.
            tuning_params: The tuning parameters to use for the prediction.

        Returns:
            An iterator of GenerationChunks.

        Raises:
            ValueError: If a chunk reports a non-200 status with an ``error``
                entry, or the endpoint URI flavour is not supported.
            RuntimeError: If a chunk reports a non-200 status without an
                ``error`` entry.
        """
        for chunk in sdk.nlp_predict_stream(
            self.sambastudio_project_id,
            self.sambastudio_endpoint_id,
            self.sambastudio_api_key,
            prompt,
            tuning_params,
        ):
            if chunk["status_code"] != 200:
                error = chunk.get("error")
                if error:
                    optional_code = error.get("code")
                    optional_details = error.get("details")
                    optional_message = error.get("message")
                    raise ValueError(
                        f"Sambanova /complete call failed with status code "
                        f"{chunk['status_code']}.\n"
                        f"Message: {optional_message}\n"
                        f"Details: {optional_details}\n"
                        f"Code: {optional_code}\n"
                    )
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code "
                    f"{chunk['status_code']}."
                    f"{chunk}."
                )
            if "nlp" in self.sambastudio_base_uri:
                # SSE events carry the token inside a JSON-encoded data field.
                text = json.loads(chunk["data"])["stream_token"]
            elif "generic" in self.sambastudio_base_uri:
                text = chunk["result"]["responses"][0]["stream_token"]
            else:
                raise ValueError(
                    f"handling of endpoint uri: {self.sambastudio_base_uri} "
                    f"not implemented"
                )
            yield GenerationChunk(text=text)

    def _stream(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream the SambaStudio LLM on the given prompt.

        When ``streaming`` is disabled the non-streaming endpoint is called
        and the whole completion is yielded as a single chunk, so that the
        iterator is never empty.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
                Currently not applied to the request.
            run_manager: Callback manager for the run.
            **kwargs: Additional keyword arguments. Currently not forwarded
                to the API call.

        Returns:
            An iterator of GenerationChunks.

        Raises:
            ValueError: If the inference endpoint call fails.
        """
        ss_endpoint = SSEndpointHandler(
            self.sambastudio_base_url, self.sambastudio_base_uri
        )
        tuning_params = self._get_tuning_params(stop)
        try:
            if self.streaming:
                chunks: Iterator[GenerationChunk] = self._handle_nlp_predict_stream(
                    ss_endpoint, prompt, tuning_params
                )
            else:
                # Yielding nothing here makes callers that expect at least
                # one chunk fail; fall back to one chunk holding the full
                # completion.
                chunks = iter(
                    [
                        GenerationChunk(
                            text=self._handle_nlp_predict(
                                ss_endpoint, prompt, tuning_params
                            )
                        )
                    ]
                )
            for chunk in chunks:
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)
                yield chunk
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f"Error raised by the inference endpoint: {e}") from e

    def _handle_stream_request(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]],
        run_manager: Optional[CallbackManagerForLLMRun],
        kwargs: Dict[str, Any],
    ) -> str:
        """
        Perform a streaming request to the LLM.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
            run_manager: Callback manager for the run.
            kwargs: Additional keyword arguments, forwarded to ``_stream``.

        Returns:
            The model output as a string (all streamed chunks concatenated).
        """
        return "".join(
            chunk.text
            for chunk in self._stream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            )
        )

    def _call(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Sambanova's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
                Not supported: must be ``None``.
            run_manager: Callback manager for the run.
            **kwargs: Additional keyword arguments. Only used on the streaming
                path, where they are forwarded to ``_stream``.

        Returns:
            The string generated by the model.

        Raises:
            Exception: If ``stop`` is provided.
            ValueError: If the inference endpoint call fails.
        """
        if stop is not None:
            raise Exception("stop not implemented")
        try:
            if self.streaming:
                return self._handle_stream_request(prompt, stop, run_manager, kwargs)
            return self._handle_completion_requests(prompt, stop)
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f"Error raised by the inference endpoint: {e}") from e
|