diff --git a/docs/docs/integrations/llms/sambanova.ipynb b/docs/docs/integrations/llms/sambanova.ipynb
new file mode 100644
index 0000000000..80e2f40012
--- /dev/null
+++ b/docs/docs/integrations/llms/sambanova.ipynb
@@ -0,0 +1,212 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# SambaNova\n",
+    "\n",
+    "**[SambaNova](https://sambanova.ai/)'s** [Sambaverse](https://sambaverse.sambanova.ai/) and [SambaStudio](https://sambanova.ai/technology/full-stack-ai-platform) are platforms for running your own open-source models.\n",
+    "\n",
+    "This example goes over how to use LangChain to interact with SambaNova models."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Sambaverse"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Sambaverse** allows you to interact with multiple open-source models. You can see the list of available models and interact with them in the [playground](https://sambaverse.sambanova.ai/playground)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "An API key is required to access Sambaverse models. Get one by creating an account at [sambaverse.sambanova.ai](https://sambaverse.sambanova.ai/)\n",
+    "\n",
+    "The [sseclient-py](https://pypi.org/project/sseclient-py/) package is required to run streaming predictions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install --quiet sseclient-py==1.8.0"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Register your API key as an environment variable:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "sambaverse_api_key = \"\"\n",
+    "\n",
+    "# Set the environment variables\n",
+    "os.environ[\"SAMBAVERSE_API_KEY\"] = sambaverse_api_key"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Call Sambaverse models directly from LangChain!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.llms.sambanova import Sambaverse\n",
+    "\n",
+    "llm = Sambaverse(\n",
+    "    sambaverse_model_name=\"Meta/llama-2-7b-chat-hf\",\n",
+    "    streaming=False,\n",
+    "    model_kwargs={\n",
+    "        \"do_sample\": True,\n",
+    "        \"max_tokens_to_generate\": 1000,\n",
+    "        \"temperature\": 0.01,\n",
+    "        \"process_prompt\": True,\n",
+    "        \"select_expert\": \"llama-2-7b-chat-hf\",\n",
+    "        # \"repetition_penalty\": {\"type\": \"float\", \"value\": \"1\"},\n",
+    "        # \"top_k\": {\"type\": \"int\", \"value\": \"50\"},\n",
+    "        # \"top_p\": {\"type\": \"float\", \"value\": \"1\"}\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "print(llm.invoke(\"Why should I use open source models?\"))"
+   ]
+  },
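+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can also stream responses. The next cell is a minimal sketch, assuming the same model, API key, and `model_kwargs` as above: construct the LLM with `streaming=True` and iterate over `llm.stream()`. Note that each chunk carries the raw streaming event returned by the endpoint."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.llms.sambanova import Sambaverse\n",
+    "\n",
+    "# Minimal streaming sketch: same setup as above, but with streaming=True\n",
+    "llm = Sambaverse(\n",
+    "    sambaverse_model_name=\"Meta/llama-2-7b-chat-hf\",\n",
+    "    streaming=True,\n",
+    "    model_kwargs={\n",
+    "        \"do_sample\": True,\n",
+    "        \"max_tokens_to_generate\": 1000,\n",
+    "        \"temperature\": 0.01,\n",
+    "        \"process_prompt\": True,\n",
+    "        \"select_expert\": \"llama-2-7b-chat-hf\",\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "for chunk in llm.stream(\"Why should I use open source models?\"):\n",
+    "    print(chunk, end=\"\", flush=True)"
+   ]
+  },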
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## SambaStudio"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**SambaStudio** allows you to train, run batch inference jobs, and deploy online inference endpoints to run your own fine-tuned open-source models."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A SambaStudio environment is required to deploy a model. Get more information at [sambanova.ai/products/enterprise-ai-platform-sambanova-suite](https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite)\n",
+    "\n",
+    "The [sseclient-py](https://pypi.org/project/sseclient-py/) package is required to run streaming predictions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install --quiet sseclient-py==1.8.0"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Register your environment variables:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "sambastudio_base_url = \"\"\n",
+    "sambastudio_project_id = \"\"\n",
+    "sambastudio_endpoint_id = \"\"\n",
+    "sambastudio_api_key = \"\"\n",
+    "\n",
+    "# Set the environment variables\n",
+    "os.environ[\"SAMBASTUDIO_BASE_URL\"] = sambastudio_base_url\n",
+    "os.environ[\"SAMBASTUDIO_PROJECT_ID\"] = sambastudio_project_id\n",
+    "os.environ[\"SAMBASTUDIO_ENDPOINT_ID\"] = sambastudio_endpoint_id\n",
+    "os.environ[\"SAMBASTUDIO_API_KEY\"] = sambastudio_api_key"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Call SambaStudio models directly from LangChain!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.llms.sambanova import SambaStudio\n",
+    "\n",
+    "llm = SambaStudio(\n",
+    "    streaming=False,\n",
+    "    model_kwargs={\n",
+    "        \"do_sample\": True,\n",
+    "        \"max_tokens_to_generate\": 1000,\n",
+    "        \"temperature\": 0.01,\n",
+    "        # \"repetition_penalty\": {\"type\": \"float\", \"value\": \"1\"},\n",
+    "        # \"top_k\": {\"type\": \"int\", \"value\": \"50\"},\n",
+    "        # \"top_logprobs\": {\"type\": \"int\", \"value\": \"0\"},\n",
+    "        # \"top_p\": {\"type\": \"float\", \"value\": \"1\"}\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "print(llm.invoke(\"Why should I use open source models?\"))"
+   ]
+  },
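+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Streaming works the same way here. The cell below is a minimal sketch, assuming the same environment variables as above: construct the LLM with `streaming=True` and iterate over `llm.stream()`; each chunk carries the raw streaming event returned by the endpoint."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.llms.sambanova import SambaStudio\n",
+    "\n",
+    "# Minimal streaming sketch: same setup as above, but with streaming=True\n",
+    "llm = SambaStudio(\n",
+    "    streaming=True,\n",
+    "    model_kwargs={\n",
+    "        \"do_sample\": True,\n",
+    "        \"max_tokens_to_generate\": 1000,\n",
+    "        \"temperature\": 0.01,\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "for chunk in llm.stream(\"Why should I use open source models?\"):\n",
+    "    print(chunk, end=\"\", flush=True)"
+   ]
+  }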
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.1"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/libs/community/langchain_community/llms/sambanova.py b/libs/community/langchain_community/llms/sambanova.py
new file mode 100644
index 0000000000..6e32fc78b8
--- /dev/null
+++ b/libs/community/langchain_community/llms/sambanova.py
@@ -0,0 +1,865 @@
+import json
+from typing import Any, Dict, Generator, Iterator, List, Optional, Union
+
+import requests
+from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
+from langchain_core.pydantic_v1 import Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+
+class SVEndpointHandler:
+    """
+    SambaNova Systems interface for the Sambaverse endpoint.
+
+    :param str host_url: Base URL of the DaaS API service
+    """
+
+    API_BASE_PATH = "/api/predict"
+
+    def __init__(self, host_url: str):
+        """
+        Initialize the SVEndpointHandler.
+
+        :param str host_url: Base URL of the DaaS API service
+        """
+        self.host_url = host_url
+        self.http_session = requests.Session()
+
+    @staticmethod
+    def _process_response(response: requests.Response) -> Dict:
+        """
+        Processes the API response and returns the resulting dict.
+
+        All resulting dicts, regardless of success or failure, will contain the
+        `status_code` key with the API response status code.
+
+        If the API returned an error, the resulting dict will contain the key
+        `detail` with the error message.
+
+        If the API call was successful, the resulting dict will contain the key
+        `data` with the response data.
+
+        :param requests.Response response: the response object to process
+        :return: the response dict
+        :rtype: dict
+        """
+        result: Dict[str, Any] = {}
+        try:
+            text_result = response.text.strip().split("\n")[-1]
+            result = {"data": json.loads("".join(text_result.split("data: ")[1:]))}
+        except Exception as e:
+            result["detail"] = str(e)
+        if "status_code" not in result:
+            result["status_code"] = response.status_code
+        return result
+
+    @staticmethod
+    def _process_streaming_response(
+        response: requests.Response,
+    ) -> Generator[GenerationChunk, None, None]:
+        """Process the streaming response"""
+        try:
+            import sseclient
+        except ImportError:
+            raise ImportError(
+                "Could not import sseclient library. "
+                "Please install it with `pip install sseclient-py`."
+            )
+        client = sseclient.SSEClient(response)
+        close_conn = False
+        for event in client.events():
+            if event.event == "error_event":
+                close_conn = True
+            text = json.dumps({"event": event.event, "data": event.data})
+            chunk = GenerationChunk(text=text)
+            yield chunk
+        if close_conn:
+            client.close()
+
+    def _get_full_url(self) -> str:
+        """
+        Return the full API URL.
+
+        :returns: the full API URL
+        :rtype: str
+        """
+        return f"{self.host_url}{self.API_BASE_PATH}"
+
+    def nlp_predict(
+        self,
+        key: str,
+        sambaverse_model_name: Optional[str],
+        input: Union[List[str], str],
+        params: Optional[str] = "",
+        stream: bool = False,
+    ) -> Dict:
+        """
+        NLP predict using inline input string.
+
+        :param str key: API Key
+        :param str sambaverse_model_name: Sambaverse model name
+        :param input: Input string or list of strings
+        :param str params: Input params string
+        :returns: Prediction results
+        :rtype: dict
+        """
+        if isinstance(input, str):
+            input = [input]
+        parsed_input = []
+        for element in input:
+            parsed_element = {
+                "conversation_id": "sambaverse-conversation-id",
+                "messages": [
+                    {
+                        "message_id": 0,
+                        "role": "user",
+                        "content": element,
+                    }
+                ],
+            }
+            parsed_input.append(json.dumps(parsed_element))
+        if params:
+            data = {"inputs": parsed_input, "params": json.loads(params)}
+        else:
+            data = {"inputs": parsed_input}
+        response = self.http_session.post(
+            self._get_full_url(),
+            headers={
+                "key": key,
+                "Content-Type": "application/json",
+                "modelName": sambaverse_model_name,
+            },
+            json=data,
+        )
+        return SVEndpointHandler._process_response(response)
+
+    def nlp_predict_stream(
+        self,
+        key: str,
+        sambaverse_model_name: Optional[str],
+        input: Union[List[str], str],
+        params: Optional[str] = "",
+    ) -> Iterator[GenerationChunk]:
+        """
+        Streaming NLP predict using inline input string.
+
+        :param str key: API Key
+        :param str sambaverse_model_name: Sambaverse model name
+        :param input: Input string or list of strings
+        :param str params: Input params string
+        :returns: Prediction results
+        :rtype: Iterator[GenerationChunk]
+        """
+        if isinstance(input, str):
+            input = [input]
+        parsed_input = []
+        for element in input:
+            parsed_element = {
+                "conversation_id": "sambaverse-conversation-id",
+                "messages": [
+                    {
+                        "message_id": 0,
+                        "role": "user",
+                        "content": element,
+                    }
+                ],
+            }
+            parsed_input.append(json.dumps(parsed_element))
+        if params:
+            data = {"inputs": parsed_input, "params": json.loads(params)}
+        else:
+            data = {"inputs": parsed_input}
+        # Streaming output
+        response = self.http_session.post(
+            self._get_full_url(),
+            headers={
+                "key": key,
+                "Content-Type": "application/json",
+                "modelName": sambaverse_model_name,
+            },
+            json=data,
+            stream=True,
+        )
+        for chunk in SVEndpointHandler._process_streaming_response(response):
+            yield chunk
+
+
+class Sambaverse(LLM):
+    """
+    Sambaverse large language models.
+
+    To use, you should have the environment variable ``SAMBAVERSE_API_KEY``
+    set with your API key and ``SAMBAVERSE_MODEL_NAME`` set with the model
+    name to use (or pass ``sambaverse_model_name`` directly).
+
+    Get an API key at https://sambaverse.sambanova.ai and read extra
+    documentation at https://docs.sambanova.ai/sambaverse/latest/index.html
+
+    Example:
+        .. code-block:: python
+
+            from langchain_community.llms.sambanova import Sambaverse
+
+            Sambaverse(
+                sambaverse_url="https://sambaverse.sambanova.ai",
+                sambaverse_api_key="your sambaverse api key",
+                sambaverse_model_name="Meta/llama-2-7b-chat-hf",
+                streaming=False,
+                model_kwargs={
+                    "do_sample": False,
+                    "max_tokens_to_generate": 100,
+                    "temperature": 0.7,
+                    "top_p": 1.0,
+                    "repetition_penalty": 1,
+                    "top_k": 50,
+                },
+            )
+    """
+
+    sambaverse_url: str = "https://sambaverse.sambanova.ai"
+    """Sambaverse url to use"""
+
+    sambaverse_api_key: str = ""
+    """sambaverse api key"""
+
+    sambaverse_model_name: Optional[str] = None
+    """sambaverse expert model to use"""
+
+    model_kwargs: Optional[dict] = None
+    """Key word arguments to pass to the model."""
+
+    streaming: Optional[bool] = False
+    """Streaming flag to get streamed response."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        return True
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that the api key exists in the environment."""
+        values["sambaverse_url"] = get_from_dict_or_env(
+            values, "sambaverse_url", "SAMBAVERSE_URL"
+        )
+        values["sambaverse_api_key"] = get_from_dict_or_env(
+            values, "sambaverse_api_key", "SAMBAVERSE_API_KEY"
+        )
+        values["sambaverse_model_name"] = get_from_dict_or_env(
+            values, "sambaverse_model_name", "SAMBAVERSE_MODEL_NAME"
+        )
+        return values
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {"model_kwargs": self.model_kwargs}
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "Sambaverse LLM"
+
+    def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
+        """
+        Get the tuning parameters to use when calling the LLM.
+
+        Args:
+            stop: Stop words to use when generating. Model output is cut off at the
+                first occurrence of any of the stop substrings.
+
+        Returns:
+            The tuning parameters as a JSON string.
+ """ + _model_kwargs = self.model_kwargs or {} + _stop_sequences = _model_kwargs.get("stop_sequences", []) + _stop_sequences = stop or _stop_sequences + _model_kwargs["stop_sequences"] = ",".join(f'"{x}"' for x in _stop_sequences) + tuning_params_dict = { + k: {"type": type(v).__name__, "value": str(v)} + for k, v in (_model_kwargs.items()) + } + tuning_params = json.dumps(tuning_params_dict) + return tuning_params + + def _handle_nlp_predict( + self, + sdk: SVEndpointHandler, + prompt: Union[List[str], str], + tuning_params: str, + ) -> str: + """ + Perform an NLP prediction using the Sambaverse endpoint handler. + + Args: + sdk: The SVEndpointHandler to use for the prediction. + prompt: The prompt to use for the prediction. + tuning_params: The tuning parameters to use for the prediction. + + Returns: + The prediction result. + + Raises: + ValueError: If the prediction fails. + """ + response = sdk.nlp_predict( + self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params + ) + if response["status_code"] != 200: + optional_details = response["details"] + optional_message = response["message"] + raise ValueError( + f"Sambanova /complete call failed with status code " + f"{response['status_code']}. Details: {optional_details}" + f"{response['status_code']}. Message: {optional_message}" + ) + return response["data"]["completion"] + + def _handle_completion_requests( + self, prompt: Union[List[str], str], stop: Optional[List[str]] + ) -> str: + """ + Perform a prediction using the Sambaverse endpoint handler. + + Args: + prompt: The prompt to use for the prediction. + stop: stop sequences. + + Returns: + The prediction result. + + Raises: + ValueError: If the prediction fails. + """ + ss_endpoint = SVEndpointHandler(self.sambaverse_url) + tuning_params = self._get_tuning_params(stop) + return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params) + + def _handle_nlp_predict_stream( + self, sdk: SVEndpointHandler, prompt: Union[List[str], str], tuning_params: str + ) -> Iterator[GenerationChunk]: + """ + Perform a streaming request to the LLM. + + Args: + sdk: The SVEndpointHandler to use for the prediction. + prompt: The prompt to use for the prediction. + tuning_params: The tuning parameters to use for the prediction. + + Returns: + An iterator of GenerationChunks. + """ + for chunk in sdk.nlp_predict_stream( + self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params + ): + yield chunk + + def _stream( + self, + prompt: Union[List[str], str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + """Stream the Sambaverse's LLM on the given prompt. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + run_manager: Callback manager for the run. + **kwargs: Additional keyword arguments. directly passed + to the sambaverse model in API call. + + Returns: + An iterator of GenerationChunks. 
+ """ + ss_endpoint = SVEndpointHandler(self.sambaverse_url) + tuning_params = self._get_tuning_params(stop) + try: + if self.streaming: + for chunk in self._handle_nlp_predict_stream( + ss_endpoint, prompt, tuning_params + ): + if run_manager: + run_manager.on_llm_new_token(chunk.text) + yield chunk + else: + return + except Exception as e: + # Handle any errors raised by the inference endpoint + raise ValueError(f"Error raised by the inference endpoint: {e}") from e + + def _handle_stream_request( + self, + prompt: Union[List[str], str], + stop: Optional[List[str]], + run_manager: Optional[CallbackManagerForLLMRun], + kwargs: Dict[str, Any], + ) -> str: + """ + Perform a streaming request to the LLM. + + Args: + prompt: The prompt to generate from. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of the stop substrings. + run_manager: Callback manager for the run. + **kwargs: Additional keyword arguments. directly passed + to the sambaverse model in API call. + + Returns: + The model output as a string. + """ + completion = "" + for chunk in self._stream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ): + completion += chunk.text + return completion + + def _call( + self, + prompt: Union[List[str], str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Run the LLM on the given input. + + Args: + prompt: The prompt to generate from. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of the stop substrings. + run_manager: Callback manager for the run. + **kwargs: Additional keyword arguments. directly passed + to the sambaverse model in API call. + + Returns: + The model output as a string. + """ + try: + if self.streaming: + return self._handle_stream_request(prompt, stop, run_manager, kwargs) + return self._handle_completion_requests(prompt, stop) + except Exception as e: + # Handle any errors raised by the inference endpoint + raise ValueError(f"Error raised by the inference endpoint: {e}") from e + + +class SSEndpointHandler: + """ + SambaNova Systems Interface for SambaStudio model endpoints. + + :param str host_url: Base URL of the DaaS API service + """ + + API_BASE_PATH = "/api" + + def __init__(self, host_url: str): + """ + Initialize the SSEndpointHandler. + + :param str host_url: Base URL of the DaaS API service + """ + self.host_url = host_url + self.http_session = requests.Session() + + @staticmethod + def _process_response(response: requests.Response) -> Dict: + """ + Processes the API response and returns the resulting dict. + + All resulting dicts, regardless of success or failure, will contain the + `status_code` key with the API response status code. + + If the API returned an error, the resulting dict will contain the key + `detail` with the error message. + + If the API call was successful, the resulting dict will contain the key + `data` with the response data. 
+
+        :param requests.Response response: the response object to process
+        :return: the response dict
+        :rtype: dict
+        """
+        result: Dict[str, Any] = {}
+        try:
+            result = response.json()
+        except Exception as e:
+            result["detail"] = str(e)
+        if "status_code" not in result:
+            result["status_code"] = response.status_code
+        return result
+
+    @staticmethod
+    def _process_streaming_response(
+        response: requests.Response,
+    ) -> Generator[GenerationChunk, None, None]:
+        """Process the streaming response"""
+        try:
+            import sseclient
+        except ImportError:
+            raise ImportError(
+                "Could not import sseclient library. "
+                "Please install it with `pip install sseclient-py`."
+            )
+        client = sseclient.SSEClient(response)
+        close_conn = False
+        for event in client.events():
+            if event.event == "error_event":
+                close_conn = True
+            text = json.dumps({"event": event.event, "data": event.data})
+            chunk = GenerationChunk(text=text)
+            yield chunk
+        if close_conn:
+            client.close()
+
+    def _get_full_url(self, path: str) -> str:
+        """
+        Return the full API URL for a given path.
+
+        :param str path: the sub-path
+        :returns: the full API URL for the sub-path
+        :rtype: str
+        """
+        return f"{self.host_url}{self.API_BASE_PATH}{path}"
+
+    def nlp_predict(
+        self,
+        project: str,
+        endpoint: str,
+        key: str,
+        input: Union[List[str], str],
+        params: Optional[str] = "",
+        stream: bool = False,
+    ) -> Dict:
+        """
+        NLP predict using inline input string.
+
+        :param str project: Project ID in which the endpoint exists
+        :param str endpoint: Endpoint ID
+        :param str key: API Key
+        :param input: Input string or list of strings
+        :param str params: Input params string
+        :returns: Prediction results
+        :rtype: dict
+        """
+        if isinstance(input, str):
+            input = [input]
+        if params:
+            data = {"inputs": input, "params": json.loads(params)}
+        else:
+            data = {"inputs": input}
+        response = self.http_session.post(
+            self._get_full_url(f"/predict/nlp/{project}/{endpoint}"),
+            headers={"key": key},
+            json=data,
+        )
+        return SSEndpointHandler._process_response(response)
+
+    def nlp_predict_stream(
+        self,
+        project: str,
+        endpoint: str,
+        key: str,
+        input: Union[List[str], str],
+        params: Optional[str] = "",
+    ) -> Iterator[GenerationChunk]:
+        """
+        Streaming NLP predict using inline input string.
+
+        :param str project: Project ID in which the endpoint exists
+        :param str endpoint: Endpoint ID
+        :param str key: API Key
+        :param input: Input string or list of strings
+        :param str params: Input params string
+        :returns: Prediction results
+        :rtype: Iterator[GenerationChunk]
+        """
+        if isinstance(input, str):
+            input = [input]
+        if params:
+            data = {"inputs": input, "params": json.loads(params)}
+        else:
+            data = {"inputs": input}
+        # Streaming output
+        response = self.http_session.post(
+            self._get_full_url(f"/predict/nlp/stream/{project}/{endpoint}"),
+            headers={"key": key},
+            json=data,
+            stream=True,
+        )
+        for chunk in SSEndpointHandler._process_streaming_response(response):
+            yield chunk
+
+
+class SambaStudio(LLM):
+    """
+    SambaStudio large language models.
+
+    To use, you should have the environment variables
+    ``SAMBASTUDIO_BASE_URL`` set with your SambaStudio environment URL,
+    ``SAMBASTUDIO_PROJECT_ID`` set with your SambaStudio project ID,
+    ``SAMBASTUDIO_ENDPOINT_ID`` set with your SambaStudio endpoint ID, and
+    ``SAMBASTUDIO_API_KEY`` set with your SambaStudio endpoint API key.
+
+    https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite
+
+    Read extra documentation at https://docs.sambanova.ai/sambastudio/latest/index.html
+
+    Example:
+        .. code-block:: python
+
+            from langchain_community.llms.sambanova import SambaStudio
+
+            SambaStudio(
+                base_url="your SambaStudio environment URL",
+                project_id="your SambaStudio project ID",
+                endpoint_id="your SambaStudio endpoint ID",
+                api_key="your SambaStudio endpoint API key",
+                streaming=False,
+                model_kwargs={
+                    "do_sample": False,
+                    "max_tokens_to_generate": 1000,
+                    "temperature": 0.7,
+                    "top_p": 1.0,
+                    "repetition_penalty": 1,
+                    "top_k": 50,
+                },
+            )
+    """
+
+    base_url: str = ""
+    """Base url to use"""
+
+    project_id: str = ""
+    """Project id on sambastudio for model"""
+
+    endpoint_id: str = ""
+    """endpoint id on sambastudio for model"""
+
+    api_key: str = ""
+    """sambastudio api key"""
+
+    model_kwargs: Optional[dict] = None
+    """Key word arguments to pass to the model."""
+
+    streaming: Optional[bool] = False
+    """Streaming flag to get streamed response."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        return True
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {"model_kwargs": self.model_kwargs}
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "Sambastudio LLM"
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that the api key and python package exist in the environment."""
+        values["base_url"] = get_from_dict_or_env(
+            values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL"
+        )
+        values["project_id"] = get_from_dict_or_env(
+            values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID"
+        )
+        values["endpoint_id"] = get_from_dict_or_env(
+            values, "sambastudio_endpoint_id", "SAMBASTUDIO_ENDPOINT_ID"
+        )
+        values["api_key"] = get_from_dict_or_env(
+            values, "sambastudio_api_key", "SAMBASTUDIO_API_KEY"
+        )
+        return values
+
+    def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
+        """
+        Get the tuning parameters to use when calling the LLM.
+
+        Args:
+            stop: Stop words to use when generating. Model output is cut off at the
+                first occurrence of any of the stop substrings.
+
+        Returns:
+            The tuning parameters as a JSON string.
+        """
+        _model_kwargs = self.model_kwargs or {}
+        _stop_sequences = _model_kwargs.get("stop_sequences", [])
+        _stop_sequences = stop or _stop_sequences
+        # _model_kwargs['stop_sequences'] = ','.join(
+        #     f"'{x}'" for x in _stop_sequences)
+        tuning_params_dict = {
+            k: {"type": type(v).__name__, "value": str(v)}
+            for k, v in (_model_kwargs.items())
+        }
+        tuning_params = json.dumps(tuning_params_dict)
+        return tuning_params
+
+    def _handle_nlp_predict(
+        self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
+    ) -> str:
+        """
+        Perform an NLP prediction using the SambaStudio endpoint handler.
+
+        Args:
+            sdk: The SSEndpointHandler to use for the prediction.
+            prompt: The prompt to use for the prediction.
+            tuning_params: The tuning parameters to use for the prediction.
+
+        Returns:
+            The prediction result.
+
+        Raises:
+            ValueError: If the prediction fails.
+        """
+        response = sdk.nlp_predict(
+            self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params
+        )
+        if response["status_code"] != 200:
+            optional_detail = response.get("detail")
+            raise ValueError(
+                f"Sambanova /complete call failed with status code "
+                f"{response['status_code']}. Details: {optional_detail}"
+            )
+        return response["data"][0]["completion"]
+
+    def _handle_completion_requests(
+        self, prompt: Union[List[str], str], stop: Optional[List[str]]
+    ) -> str:
+        """
+        Perform a prediction using the SambaStudio endpoint handler.
+
+        Args:
+            prompt: The prompt to use for the prediction.
+            stop: stop sequences.
+
+        Returns:
+            The prediction result.
+
+        Raises:
+            ValueError: If the prediction fails.
+        """
+        ss_endpoint = SSEndpointHandler(self.base_url)
+        tuning_params = self._get_tuning_params(stop)
+        return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)
+
+    def _handle_nlp_predict_stream(
+        self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
+    ) -> Iterator[GenerationChunk]:
+        """
+        Perform a streaming request to the LLM.
+
+        Args:
+            sdk: The SSEndpointHandler to use for the prediction.
+            prompt: The prompt to use for the prediction.
+            tuning_params: The tuning parameters to use for the prediction.
+
+        Returns:
+            An iterator of GenerationChunks.
+        """
+        for chunk in sdk.nlp_predict_stream(
+            self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params
+        ):
+            yield chunk
+
+    def _stream(
+        self,
+        prompt: Union[List[str], str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        """Stream the SambaStudio LLM on the given prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+            run_manager: Callback manager for the run.
+
+        Returns:
+            An iterator of GenerationChunks.
+        """
+        ss_endpoint = SSEndpointHandler(self.base_url)
+        tuning_params = self._get_tuning_params(stop)
+        try:
+            if self.streaming:
+                for chunk in self._handle_nlp_predict_stream(
+                    ss_endpoint, prompt, tuning_params
+                ):
+                    if run_manager:
+                        run_manager.on_llm_new_token(chunk.text)
+                    yield chunk
+            else:
+                # Nothing to stream when streaming is disabled
+                return
+        except Exception as e:
+            # Handle any errors raised by the inference endpoint
+            raise ValueError(f"Error raised by the inference endpoint: {e}") from e
+
+    def _handle_stream_request(
+        self,
+        prompt: Union[List[str], str],
+        stop: Optional[List[str]],
+        run_manager: Optional[CallbackManagerForLLMRun],
+        kwargs: Dict[str, Any],
+    ) -> str:
+        """
+        Perform a streaming request to the LLM.
+
+        Args:
+            prompt: The prompt to generate from.
+            stop: Stop words to use when generating. Model output is cut off at the
+                first occurrence of any of the stop substrings.
+            run_manager: Callback manager for the run.
+            **kwargs: Additional keyword arguments, passed directly to the
+                SambaStudio model in the API call.
+
+        Returns:
+            The model output as a string.
+        """
+        completion = ""
+        for chunk in self._stream(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        ):
+            completion += chunk.text
+        return completion
+
+    def _call(
+        self,
+        prompt: Union[List[str], str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call out to SambaStudio's complete endpoint.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Stop words. Not supported by this endpoint; a ValueError is
+                raised if any are provided.
+
+        Returns:
+            The string generated by the model.
+ """ + if stop is not None: + raise Exception("stop not implemented") + try: + if self.streaming: + return self._handle_stream_request(prompt, stop, run_manager, kwargs) + return self._handle_completion_requests(prompt, stop) + except Exception as e: + # Handle any errors raised by the inference endpoint + raise ValueError(f"Error raised by the inference endpoint: {e}") from e diff --git a/libs/community/tests/integration_tests/llms/test_sambanova.py b/libs/community/tests/integration_tests/llms/test_sambanova.py new file mode 100644 index 0000000000..b5928eb5b6 --- /dev/null +++ b/libs/community/tests/integration_tests/llms/test_sambanova.py @@ -0,0 +1,28 @@ +"""Test sambanova API wrapper. + +In order to run this test, you need to have an sambaverse api key, +and a sambaverse base url, project id, endpoint id, and api key. +You'll then need to set SAMBAVERSE_API_KEY, SAMBASTUDIO_BASE_URL, +SAMBASTUDIO_PROJECT_ID, SAMBASTUDIO_ENDPOINT_ID, and SAMBASTUDIO_API_KEY +environment variables. +""" +from langchain_community.llms.sambanova import SambaStudio, Sambaverse + + +def test_sambaverse_call() -> None: + """Test simple non-streaming call to sambaverse.""" + llm = Sambaverse( + sambaverse_model_name="Meta/llama-2-7b-chat-hf", + model_kwargs={"select_expert": "llama-2-7b-chat-hf"}, + ) + output = llm.invoke("What is LangChain") + assert output + assert isinstance(output, str) + + +def test_sambastudio_call() -> None: + """Test simple non-streaming call to sambaverse.""" + llm = SambaStudio() + output = llm.invoke("What is LangChain") + assert output + assert isinstance(output, str)