community: sambastudio llm refactor (#27215)

**Description:**
    - SambaStudio LLM refactor
    - SambaStudio OpenAI-compatible API support added
    - Docs updated
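
A minimal usage sketch of the refactored interface, adapted from the updated notebook in this PR (the endpoint URL, API key, and model name below are placeholders):

```python
import os

from langchain_community.llms.sambanova import SambaStudio

# The refactored integration needs only the endpoint URL and API key;
# separate base URI, project ID, and endpoint ID fields are no longer used.
os.environ["SAMBASTUDIO_URL"] = "<Your SambaStudio endpoint URL>"
os.environ["SAMBASTUDIO_API_KEY"] = "<Your SambaStudio endpoint API key>"

llm = SambaStudio(
    model_kwargs={
        "do_sample": True,
        "max_tokens": 1024,
        "temperature": 0.01,
        "process_prompt": True,  # set if using CoE endpoints
        "model": "Meta-Llama-3-70B-Instruct-4096",  # set if using CoE endpoints
    },
)

print(llm.invoke("Why should I use open source models?"))

# streaming
for chunk in llm.stream("Why should I use open source models?"):
    print(chunk, end="", flush=True)
```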
Jorge Piedrahita Ortiz 2024-10-27 10:08:15 -05:00 committed by GitHub
parent fe87e411f2
commit 8895d468cb
4 changed files with 681 additions and 649 deletions


@@ -1,190 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SambaNova\n",
"\n",
"**[SambaNova](https://sambanova.ai/)'s** [Sambastudio](https://sambanova.ai/technology/full-stack-ai-platform) is a platform for running your own open-source models\n",
"\n",
"This example goes over how to use LangChain to interact with SambaNova models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SambaStudio"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**SambaStudio** allows you to train, run batch inference jobs, and deploy online inference endpoints to run open source models that you fine tuned yourself."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A SambaStudio environment is required to deploy a model. Get more information at [sambanova.ai/products/enterprise-ai-platform-sambanova-suite](https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite)\n",
"\n",
"The [sseclient-py](https://pypi.org/project/sseclient-py/) package is required to run streaming predictions "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install --quiet sseclient-py==1.8.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Register your environment variables:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"sambastudio_base_url = \"<Your SambaStudio environment URL>\"\n",
"sambastudio_base_uri = \"<Your SambaStudio endpoint base URI>\" # optional, \"api/predict/generic\" set as default\n",
"sambastudio_project_id = \"<Your SambaStudio project id>\"\n",
"sambastudio_endpoint_id = \"<Your SambaStudio endpoint id>\"\n",
"sambastudio_api_key = \"<Your SambaStudio endpoint API key>\"\n",
"\n",
"# Set the environment variables\n",
"os.environ[\"SAMBASTUDIO_BASE_URL\"] = sambastudio_base_url\n",
"os.environ[\"SAMBASTUDIO_BASE_URI\"] = sambastudio_base_uri\n",
"os.environ[\"SAMBASTUDIO_PROJECT_ID\"] = sambastudio_project_id\n",
"os.environ[\"SAMBASTUDIO_ENDPOINT_ID\"] = sambastudio_endpoint_id\n",
"os.environ[\"SAMBASTUDIO_API_KEY\"] = sambastudio_api_key"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Call SambaStudio models directly from LangChain!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms.sambanova import SambaStudio\n",
"\n",
"llm = SambaStudio(\n",
" streaming=False,\n",
" model_kwargs={\n",
" \"do_sample\": True,\n",
" \"max_tokens_to_generate\": 1000,\n",
" \"temperature\": 0.01,\n",
" # \"repetition_penalty\": 1.0,\n",
" # \"top_k\": 50,\n",
" # \"top_logprobs\": 0,\n",
" # \"top_p\": 1.0\n",
" },\n",
")\n",
"\n",
"print(llm.invoke(\"Why should I use open source models?\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Streaming response\n",
"\n",
"from langchain_community.llms.sambanova import SambaStudio\n",
"\n",
"llm = SambaStudio(\n",
" streaming=True,\n",
" model_kwargs={\n",
" \"do_sample\": True,\n",
" \"max_tokens_to_generate\": 1000,\n",
" \"temperature\": 0.01,\n",
" # \"repetition_penalty\": 1.0,\n",
" # \"top_k\": 50,\n",
" # \"top_logprobs\": 0,\n",
" # \"top_p\": 1.0\n",
" },\n",
")\n",
"\n",
"for chunk in llm.stream(\"Why should I use open source models?\"):\n",
" print(chunk, end=\"\", flush=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also call a CoE endpoint expert model "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Using a CoE endpoint\n",
"\n",
"from langchain_community.llms.sambanova import SambaStudio\n",
"\n",
"llm = SambaStudio(\n",
" streaming=False,\n",
" model_kwargs={\n",
" \"do_sample\": True,\n",
" \"max_tokens_to_generate\": 1000,\n",
" \"temperature\": 0.01,\n",
" \"process_prompt\": False,\n",
" \"select_expert\": \"Meta-Llama-3-8B-Instruct\",\n",
" # \"repetition_penalty\": 1.0,\n",
" # \"top_k\": 50,\n",
" # \"top_logprobs\": 0,\n",
" # \"top_p\": 1.0\n",
" },\n",
")\n",
"\n",
"print(llm.invoke(\"Why should I use open source models?\"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}


@@ -0,0 +1,239 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SambaStudio\n",
"\n",
"**[SambaNova](https://sambanova.ai/)'s** [Sambastudio](https://sambanova.ai/technology/full-stack-ai-platform) is a platform that allows you to train, run batch inference jobs, and deploy online inference endpoints to run open source models that you fine tuned yourself.\n",
"\n",
":::caution\n",
"You are currently on a page documenting the use of SambaStudio models as [text completion models](/docs/concepts/#llms). We recommend you to use the [chat completion models](/docs/concepts/#chat-models).\n",
"\n",
"You may be looking for [SambaStudio Chat Models](/docs/integrations/chat/sambastudio/) .\n",
":::\n",
"\n",
"## Overview\n",
"### Integration details\n",
"\n",
"| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
"| [SambaStudio](https://python.langchain.com/api_reference/community/llms/langchain_community.llms.sambanova.SambaStudio.html) | [langchain_community](https://python.langchain.com/api_reference/community/index.html) | ❌ | beta | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain_community?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain_community?style=flat-square&label=%20) |\n",
"\n",
"This example goes over how to use LangChain to interact with SambaStudio models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"### Credentials\n",
"A SambaStudio environment is required to deploy a model. Get more information at [sambanova.ai/products/enterprise-ai-platform-sambanova-suite](https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite)\n",
"\n",
"you'll need to [deploy an endpoint](https://docs.sambanova.ai/sambastudio/latest/endpoints.html) and set the `SAMBASTUDIO_URL` and `SAMBASTUDIO_API_KEY` environment variables:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"if \"SAMBASTUDIO_URL\" not in os.environ:\n",
" os.environ[\"SAMBASTUDIO_URL\"] = getpass.getpass()\n",
"if \"SAMBASTUDIO_API_KEY\" not in os.environ:\n",
" os.environ[\"SAMBASTUDIO_API_KEY\"] = getpass.getpass()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Installation\n",
"\n",
"The integration lives in the `langchain-community` package. We also need to install the [sseclient-py](https://pypi.org/project/sseclient-py/) package this is required to run streaming predictions "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install --quiet -U langchain-community sseclient-py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Instantiation"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms.sambanova import SambaStudio\n",
"\n",
"llm = SambaStudio(\n",
" model_kwargs={\n",
" \"do_sample\": True,\n",
" \"max_tokens\": 1024,\n",
" \"temperature\": 0.01,\n",
" \"process_prompt\": True, # set if using CoE endpoints\n",
" \"model\": \"Meta-Llama-3-70B-Instruct-4096\", # set if using CoE endpoints\n",
" # \"repetition_penalty\": 1.0,\n",
" # \"top_k\": 50,\n",
" # \"top_logprobs\": 0,\n",
" # \"top_p\": 1.0\n",
" },\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Invocation\n",
"\n",
"Now we can instantiate our model object and generate chat completions:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"Using open source models can have numerous benefits. Here are some reasons why you should consider using open source models:\\n\\n1. **Cost-effective**: Open source models are often free to use, modify, and distribute, which can significantly reduce costs compared to proprietary models.\\n2. **Customizability**: Open source models can be modified to fit your specific needs, allowing you to tailor the model to your project's requirements.\\n3. **Transparency**: Open source models provide complete transparency into the model's architecture, training data, and algorithms, which can be essential for understanding how the model works and identifying potential biases.\\n4. **Community involvement**: Open source models are often maintained by a community of developers, researchers, and users, which can lead to faster bug fixes, new feature additions, and improved performance.\\n5. **Flexibility**: Open source models can be used in a variety of applications, from research to production, and can be easily integrated into different workflows and systems.\\n6. **Auditability**: With open source models, you can audit the model's performance, data, and algorithms, which is critical in regulated industries or when working with sensitive data.\\n7. **No vendor lock-in**: By using open source models, you're not tied to a specific vendor or proprietary technology, giving you more freedom to switch or modify your approach as needed.\\n8. **Improved security**: Open source models can be reviewed and audited by the community, which can help identify and fix security vulnerabilities more quickly.\\n9. **Access to cutting-edge research**: Open source models can provide access to the latest research and advancements in AI and machine learning, allowing you to leverage the work of experts in the field.\\n10. **Ethical considerations**: By using open source models, you can ensure that your AI systems are transparent, explainable, and fair, which is essential for building trust in AI applications.\\n11. **Reduced risk of bias**: Open source models can help reduce the risk of bias by providing transparency into the model's development, training data, and algorithms.\\n12. **Faster development**: Open source models can accelerate your development process by providing pre-trained models, datasets, and tools that can be easily integrated into your project.\\n13. **Improved collaboration**: Open source models can facilitate collaboration among researchers, developers, and organizations, leading to faster progress and innovation in AI and machine learning.\\n14. **Access to large datasets**: Open source models can provide access to large datasets, which can be essential for training and testing AI models.\\n15. **Compliance with regulations**: In some cases, using open source models can help ensure compliance with regulations, such as GDPR, HIPAA, or CCPA, which require transparency and explainability in AI systems.\\n\\nOverall, using open source models can provide numerous benefits, from cost savings to improved transparency and customizability. By leveraging open source models, you can accelerate your AI and machine learning projects while ensuring that your systems are transparent, explainable, and fair.\""
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input_text = \"Why should I use open source models?\"\n",
"\n",
"completion = llm.invoke(input_text)\n",
"completion"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using open source models can have numerous benefits. Here are some reasons why you should consider using open source models:\n",
"\n",
"1. **Cost-effective**: Open source models are often free to use, modify, and distribute, which can significantly reduce costs compared to proprietary models.\n",
"2. **Customizability**: Open source models can be modified to fit your specific needs, allowing you to tailor the model to your project's requirements.\n",
"3. **Transparency**: Open source models provide complete transparency into the model's architecture, training data, and algorithms, which can be essential for understanding how the model works and identifying potential biases.\n",
"4. **Community involvement**: Open source models are often maintained by a community of developers, researchers, and users, which can lead to faster bug fixes, new feature additions, and improved performance.\n",
"5. **Flexibility**: Open source models can be used in a variety of applications, from research to production, and can be easily integrated into different workflows and systems.\n",
"6. **Auditability**: With open source models, you can audit the model's performance, data, and algorithms, which is critical in regulated industries or when working with sensitive data.\n",
"7. **No vendor lock-in**: By using open source models, you're not tied to a specific vendor or proprietary technology, giving you more freedom to switch or modify your approach as needed.\n",
"8. **Improved security**: Open source models can be reviewed and audited by the community, which can help identify and fix security vulnerabilities more quickly.\n",
"9. **Access to cutting-edge research**: Open source models can provide access to the latest research and advancements in AI and machine learning, allowing you to leverage the work of experts in the field.\n",
"10. **Ethical considerations**: By using open source models, you can ensure that your AI systems are transparent, explainable, and fair, which is essential for building trust in AI applications.\n",
"11. **Reduced risk of bias**: Open source models can help reduce the risk of bias by providing transparency into the model's development, training data, and algorithms.\n",
"12. **Faster development**: Open source models can accelerate your development process by providing pre-trained models, datasets, and tools that can be easily integrated into your project.\n",
"13. **Improved collaboration**: Open source models can facilitate collaboration among researchers, developers, and organizations, leading to faster progress and innovation in AI and machine learning.\n",
"14. **Access to large datasets**: Open source models can provide access to large datasets, which can be essential for training and testing AI models.\n",
"15. **Compliance with regulations**: In some cases, using open source models can help ensure compliance with regulations, such as GDPR, HIPAA, or CCPA, which require transparency and explainability in AI systems.\n",
"\n",
"Overall, using open source models can provide numerous benefits, from cost savings to improved transparency and customizability. By leveraging open source models, you can accelerate your AI and machine learning projects while ensuring that your systems are transparent, explainable, and fair."
]
}
],
"source": [
"# Streaming response\n",
"for chunk in llm.stream(\"Why should I use open source models?\"):\n",
" print(chunk, end=\"\", flush=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chaining"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'In German, you can say:\\n\\n\"Ich liebe das Programmieren.\"\\n\\nHere\\'s a breakdown of the sentence:\\n\\n* \"Ich\" means \"I\"\\n* \"liebe\" is the verb \"to love\" in the present tense, first person singular (I love)\\n* \"das\" is the definite article \"the\"\\n* \"Programmieren\" is the noun \"programming\"\\n\\nSo, \"Ich liebe das Programmieren\" literally means \"I love the programming\".\\n\\nIf you want to make it sound more casual, you can say:\\n\\n\"Ich liebe\\'s Programmieren.\"\\n\\nThe apostrophe in \"liebe\\'s\" is a contraction of \"liebe es\", which is a more informal way of saying \"I love it\".\\n\\nAlternatively, you can also say:\\n\\n\"Programmieren ist meine Leidenschaft.\"\\n\\nThis sentence means \"Programming is my passion\". Here, \"Programmieren\" is the subject, \"ist\" is the verb \"to be\" in the present tense, and \"meine Leidenschaft\" means \"my passion\".'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import PromptTemplate\n",
"\n",
"prompt = PromptTemplate.from_template(\"How to say {input} in {output_language}:\\n\")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all `SambaStudio` llm features and configurations head to the API reference: https://python.langchain.com/api_reference/community/llms/langchain_community.llms.sambanova.SambaStudio.html"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "multimodalenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}


@@ -1,318 +1,195 @@
import json
from typing import Any, Dict, Generator, Iterator, List, Optional, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import requests
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.utils import get_from_dict_or_env, pre_init
from pydantic import ConfigDict
class SSEndpointHandler:
"""
SambaNova Systems Interface for SambaStudio model endpoints.
:param str host_url: Base URL of the DaaS API service
"""
def __init__(self, host_url: str, api_base_uri: str):
"""
Initialize the SSEndpointHandler.
:param str host_url: Base URL of the DaaS API service
:param str api_base_uri: Base URI of the DaaS API service
"""
self.host_url = host_url
self.api_base_uri = api_base_uri
self.http_session = requests.Session()
def _process_response(self, response: requests.Response) -> Dict:
"""
Processes the API response and returns the resulting dict.
All resulting dicts, regardless of success or failure, will contain the
`status_code` key with the API response status code.
If the API returned an error, the resulting dict will contain the key
`detail` with the error message.
If the API call was successful, the resulting dict will contain the key
`data` with the response data.
:param requests.Response response: the response object to process
:return: the response dict
:type: dict
"""
result: Dict[str, Any] = {}
try:
result = response.json()
except Exception as e:
result["detail"] = str(e)
if "status_code" not in result:
result["status_code"] = response.status_code
return result
def _process_streaming_response(
self,
response: requests.Response,
) -> Generator[Dict, None, None]:
"""Process the streaming response"""
if "api/predict/nlp" in self.api_base_uri:
try:
import sseclient
except ImportError:
raise ImportError(
"could not import sseclient library"
"Please install it with `pip install sseclient-py`."
)
client = sseclient.SSEClient(response)
close_conn = False
for event in client.events():
if event.event == "error_event":
close_conn = True
chunk = {
"event": event.event,
"data": event.data,
"status_code": response.status_code,
}
yield chunk
if close_conn:
client.close()
elif (
"api/v2/predict/generic" in self.api_base_uri
or "api/predict/generic" in self.api_base_uri
):
try:
for line in response.iter_lines():
chunk = json.loads(line)
if "status_code" not in chunk:
chunk["status_code"] = response.status_code
yield chunk
except Exception as e:
raise RuntimeError(f"Error processing streaming response: {e}")
else:
raise ValueError(
f"handling of endpoint uri: {self.api_base_uri} not implemented"
)
def _get_full_url(self, path: str) -> str:
"""
Return the full API URL for a given path.
:param str path: the sub-path
:returns: the full API URL for the sub-path
:type: str
"""
return f"{self.host_url}/{self.api_base_uri}/{path}"
def nlp_predict(
self,
project: str,
endpoint: str,
key: str,
input: Union[List[str], str],
params: Optional[str] = "",
stream: bool = False,
) -> Dict:
"""
NLP predict using inline input string.
:param str project: Project ID in which the endpoint exists
:param str endpoint: Endpoint ID
:param str key: API Key
:param str input_str: Input string
:param str params: Input params string
:returns: Prediction results
:type: dict
"""
if isinstance(input, str):
input = [input]
if "api/predict/nlp" in self.api_base_uri:
if params:
data = {"inputs": input, "params": json.loads(params)}
else:
data = {"inputs": input}
elif "api/v2/predict/generic" in self.api_base_uri:
items = [{"id": f"item{i}", "value": item} for i, item in enumerate(input)]
if params:
data = {"items": items, "params": json.loads(params)}
else:
data = {"items": items}
elif "api/predict/generic" in self.api_base_uri:
if params:
data = {"instances": input, "params": json.loads(params)}
else:
data = {"instances": input}
else:
raise ValueError(
f"handling of endpoint uri: {self.api_base_uri} not implemented"
)
response = self.http_session.post(
self._get_full_url(f"{project}/{endpoint}"),
headers={"key": key},
json=data,
)
return self._process_response(response)
def nlp_predict_stream(
self,
project: str,
endpoint: str,
key: str,
input: Union[List[str], str],
params: Optional[str] = "",
) -> Iterator[Dict]:
"""
NLP predict using inline input string.
:param str project: Project ID in which the endpoint exists
:param str endpoint: Endpoint ID
:param str key: API Key
:param str input_str: Input string
:param str params: Input params string
:returns: Prediction results
:type: dict
"""
if "api/predict/nlp" in self.api_base_uri:
if isinstance(input, str):
input = [input]
if params:
data = {"inputs": input, "params": json.loads(params)}
else:
data = {"inputs": input}
elif "api/v2/predict/generic" in self.api_base_uri:
if isinstance(input, str):
input = [input]
items = [{"id": f"item{i}", "value": item} for i, item in enumerate(input)]
if params:
data = {"items": items, "params": json.loads(params)}
else:
data = {"items": items}
elif "api/predict/generic" in self.api_base_uri:
if isinstance(input, list):
input = input[0]
if params:
data = {"instance": input, "params": json.loads(params)}
else:
data = {"instance": input}
else:
raise ValueError(
f"handling of endpoint uri: {self.api_base_uri} not implemented"
)
# Streaming output
response = self.http_session.post(
self._get_full_url(f"stream/{project}/{endpoint}"),
headers={"key": key},
json=data,
stream=True,
)
for chunk in self._process_streaming_response(response):
yield chunk
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from pydantic import Field, SecretStr
from requests import Response
class SambaStudio(LLM):
"""
SambaStudio large language models.
To use, you should have the environment variables
``SAMBASTUDIO_BASE_URL`` set with your SambaStudio environment URL.
``SAMBASTUDIO_BASE_URI`` set with your SambaStudio api base URI.
``SAMBASTUDIO_PROJECT_ID`` set with your SambaStudio project ID.
``SAMBASTUDIO_ENDPOINT_ID`` set with your SambaStudio endpoint ID.
``SAMBASTUDIO_API_KEY`` set with your SambaStudio endpoint API key.
Setup:
To use, you should have the environment variables
``SAMBASTUDIO_URL`` set with your SambaStudio environment URL.
``SAMBASTUDIO_API_KEY`` set with your SambaStudio endpoint API key.
https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite
read extra documentation in https://docs.sambanova.ai/sambastudio/latest/index.html
Example:
.. code-block:: python
from langchain_community.llms.sambanova import SambaStudio
SambaStudio(
sambastudio_url="your-SambaStudio-environment-URL",
sambastudio_api_key="your-SambaStudio-API-key,
model_kwargs={
"model" : model or expert name (set for CoE endpoints),
"max_tokens" : max number of tokens to generate,
"temperature" : model temperature,
"top_p" : model top p,
"top_k" : model top k,
"do_sample" : wether to do sample
"process_prompt": wether to process prompt
(set for CoE generic v1 and v2 endpoints)
},
)
Key init args completion params:
model: str
The name of the model to use, e.g., Meta-Llama-3-70B-Instruct-4096
(set for CoE endpoints).
streaming: bool
Whether to use streaming handler when using non streaming methods
model_kwargs: dict
Extra keyword arguments to pass to the model:
max_tokens: int
max tokens to generate
temperature: float
model temperature
top_p: float
model top p
top_k: int
model top k
do_sample: bool
whether to do sampling
process_prompt:
whether to process the prompt (set for CoE generic v1 and v2 endpoints)
Key init args client params:
sambastudio_url: str
SambaStudio endpoint Url
sambastudio_api_key: str
SambaStudio endpoint api key
https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite
Instantiate:
.. code-block:: python
read extra documentation in https://docs.sambanova.ai/sambastudio/latest/index.html
from langchain_community.llms import SambaStudio
Example:
.. code-block:: python
llm = SambaStudio(
sambastudio_url = set with your SambaStudio deployed endpoint URL,
sambastudio_api_key = set with your SambaStudio deployed endpoint Key,
model_kwargs = {
"model" : model or expert name (set for CoE endpoints),
"max_tokens" : max number of tokens to generate,
"temperature" : model temperature,
"top_p" : model top p,
"top_k" : model top k,
"do_sample" : wether to do sample
"process_prompt" : wether to process prompt
(set for CoE generic v1 and v2 endpoints)
}
)
Invoke:
.. code-block:: python
prompt = "tell me a joke"
response = llm.invoke(prompt)
Stream:
.. code-block:: python
for chunk in llm.stream(prompt):
print(chunk, end="", flush=True)
Async:
.. code-block:: python
response = await llm.ainvoke(prompt)
from langchain_community.llms.sambanova import SambaStudio
SambaStudio(
sambastudio_base_url="your-SambaStudio-environment-URL",
sambastudio_base_uri="your-SambaStudio-base-URI",
sambastudio_project_id="your-SambaStudio-project-ID",
sambastudio_endpoint_id="your-SambaStudio-endpoint-ID",
sambastudio_api_key="your-SambaStudio-endpoint-API-key,
streaming=False
model_kwargs={
"do_sample": False,
"max_tokens_to_generate": 1000,
"temperature": 0.7,
"top_p": 1.0,
"repetition_penalty": 1,
"top_k": 50,
#"process_prompt": False,
#"select_expert": "Meta-Llama-3-8B-Instruct"
},
)
"""
sambastudio_base_url: str = ""
"""Base url to use"""
sambastudio_url: str = Field(default="")
"""SambaStudio Url"""
sambastudio_base_uri: str = ""
"""endpoint base uri"""
sambastudio_api_key: SecretStr = Field(default="")
"""SambaStudio api key"""
sambastudio_project_id: str = ""
"""Project id on sambastudio for model"""
base_url: str = Field(default="", exclude=True)
"""SambaStudio non streaming URL"""
sambastudio_endpoint_id: str = ""
"""endpoint id on sambastudio for model"""
streaming_url: str = Field(default="", exclude=True)
"""SambaStudio streaming URL"""
sambastudio_api_key: str = ""
"""sambastudio api key"""
streaming: bool = Field(default=False)
"""Whether to use streaming handler when using non streaming methods"""
model_kwargs: Optional[dict] = None
model_kwargs: Optional[Dict[str, Any]] = None
"""Key word arguments to pass to the model."""
streaming: Optional[bool] = False
"""Streaming flag to get streamed response."""
model_config = ConfigDict(
extra="forbid",
)
class Config:
populate_by_name = True
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
@property
def lc_secrets(self) -> Dict[str, str]:
return {
"sambastudio_url": "sambastudio_url",
"sambastudio_api_key": "sambastudio_api_key",
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {**{"model_kwargs": self.model_kwargs}}
"""Return a dictionary of identifying parameters.
This information is used by the LangChain callback system, which
is used for tracing purposes to make it possible to monitor LLMs.
"""
return {"streaming": self.streaming, **{"model_kwargs": self.model_kwargs}}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "Sambastudio LLM"
return "sambastudio-llm"
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["sambastudio_base_url"] = get_from_dict_or_env(
values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL"
def __init__(self, **kwargs: Any) -> None:
"""init and validate environment variables"""
kwargs["sambastudio_url"] = get_from_dict_or_env(
kwargs, "sambastudio_url", "SAMBASTUDIO_URL"
)
values["sambastudio_base_uri"] = get_from_dict_or_env(
values,
"sambastudio_base_uri",
"SAMBASTUDIO_BASE_URI",
default="api/predict/generic",
)
values["sambastudio_project_id"] = get_from_dict_or_env(
values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID"
)
values["sambastudio_endpoint_id"] = get_from_dict_or_env(
values, "sambastudio_endpoint_id", "SAMBASTUDIO_ENDPOINT_ID"
)
values["sambastudio_api_key"] = get_from_dict_or_env(
values, "sambastudio_api_key", "SAMBASTUDIO_API_KEY"
)
return values
def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
kwargs["sambastudio_api_key"] = convert_to_secret_str(
get_from_dict_or_env(kwargs, "sambastudio_api_key", "SAMBASTUDIO_API_KEY")
)
kwargs["base_url"], kwargs["streaming_url"] = self._get_sambastudio_urls(
kwargs["sambastudio_url"]
)
super().__init__(**kwargs)
def _get_sambastudio_urls(self, url: str) -> Tuple[str, str]:
"""
Get streaming and non-streaming URLs from the given URL
Args:
url: string with sambastudio base or streaming endpoint url
Returns:
base_url: string with URL for non-streaming calls
streaming_url: string with URL for streaming calls
"""
if "openai" in url:
base_url = url
stream_url = url
else:
if "stream" in url:
base_url = url.replace("stream/", "")
stream_url = url
else:
base_url = url
if "generic" in url:
stream_url = "generic/stream".join(url.split("generic"))
else:
raise ValueError("Unsupported URL")
return base_url, stream_url
def _get_tuning_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
"""
Get the tuning parameters to use when calling the LLM.
@@ -321,151 +198,294 @@ class SambaStudio(LLM):
first occurrence of any of the stop substrings.
Returns:
The tuning parameters as a JSON string.
The tuning parameters in the format required by the API
"""
if stop is None:
stop = []
# get the parameters to use when calling the LLM.
_model_kwargs = self.model_kwargs or {}
_kwarg_stop_sequences = _model_kwargs.get("stop_sequences", [])
_stop_sequences = stop or _kwarg_stop_sequences
# if not _kwarg_stop_sequences:
# _model_kwargs["stop_sequences"] = ",".join(
# f'"{x}"' for x in _stop_sequences
# )
if "api/v2/predict/generic" in self.sambastudio_base_uri:
tuning_params_dict = _model_kwargs
else:
tuning_params_dict = {
# handle the case where stop sequences are sent in the invocation
# and stop sequences have also been set in the model parameters
_stop_sequences = _model_kwargs.get("stop_sequences", []) + stop
if len(_stop_sequences) > 0:
_model_kwargs["stop_sequences"] = _stop_sequences
# set the parameters structure depending on the API
if "openai" in self.sambastudio_url:
if "select_expert" in _model_kwargs.keys():
_model_kwargs["model"] = _model_kwargs.pop("select_expert")
if "max_tokens_to_generate" in _model_kwargs.keys():
_model_kwargs["max_tokens"] = _model_kwargs.pop(
"max_tokens_to_generate"
)
if "process_prompt" in _model_kwargs.keys():
_model_kwargs.pop("process_prompt")
tuning_params = _model_kwargs
elif "api/v2/predict/generic" in self.sambastudio_url:
if "model" in _model_kwargs.keys():
_model_kwargs["select_expert"] = _model_kwargs.pop("model")
if "max_tokens" in _model_kwargs.keys():
_model_kwargs["max_tokens_to_generate"] = _model_kwargs.pop(
"max_tokens"
)
tuning_params = _model_kwargs
elif "api/predict/generic" in self.sambastudio_url:
if "model" in _model_kwargs.keys():
_model_kwargs["select_expert"] = _model_kwargs.pop("model")
if "max_tokens" in _model_kwargs.keys():
_model_kwargs["max_tokens_to_generate"] = _model_kwargs.pop(
"max_tokens"
)
tuning_params = {
k: {"type": type(v).__name__, "value": str(v)}
for k, v in (_model_kwargs.items())
}
# _model_kwargs["stop_sequences"] = _kwarg_stop_sequences
tuning_params = json.dumps(tuning_params_dict)
return tuning_params
def _handle_nlp_predict(
self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
) -> str:
"""
Perform an NLP prediction using the SambaStudio endpoint handler.
Args:
sdk: The SSEndpointHandler to use for the prediction.
prompt: The prompt to use for the prediction.
tuning_params: The tuning parameters to use for the prediction.
Returns:
The prediction result.
Raises:
ValueError: If the prediction fails.
"""
response = sdk.nlp_predict(
self.sambastudio_project_id,
self.sambastudio_endpoint_id,
self.sambastudio_api_key,
prompt,
tuning_params,
)
if response["status_code"] != 200:
optional_detail = response.get("detail")
if optional_detail:
raise RuntimeError(
f"Sambanova /complete call failed with status code "
f"{response['status_code']}.\n Details: {optional_detail}"
)
else:
raise RuntimeError(
f"Sambanova /complete call failed with status code "
f"{response['status_code']}.\n response {response}"
)
if "api/predict/nlp" in self.sambastudio_base_uri:
return response["data"][0]["completion"]
elif "api/v2/predict/generic" in self.sambastudio_base_uri:
return response["items"][0]["value"]["completion"]
elif "api/predict/generic" in self.sambastudio_base_uri:
return response["predictions"][0]["completion"]
else:
raise ValueError(
f"handling of endpoint uri: {self.sambastudio_base_uri} not implemented"
f"Unsupported URL{self.sambastudio_url}"
"only openai, generic v1 and generic v2 APIs are supported"
)
def _handle_completion_requests(
self, prompt: Union[List[str], str], stop: Optional[List[str]]
) -> str:
return tuning_params
def _handle_request(
self,
prompt: Union[List[str], str],
stop: Optional[List[str]] = None,
streaming: Optional[bool] = False,
) -> Response:
"""
Perform a prediction using the SambaStudio endpoint handler.
Performs a post request to the LLM API.
Args:
prompt: The prompt to use for the prediction.
stop: stop sequences.
prompt: The prompt to pass into the model
stop: list of stop tokens
streaming: whether to do a streaming call
Returns:
The prediction result.
Raises:
ValueError: If the prediction fails.
A request Response object
"""
ss_endpoint = SSEndpointHandler(
self.sambastudio_base_url, self.sambastudio_base_uri
)
tuning_params = self._get_tuning_params(stop)
return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)
def _handle_nlp_predict_stream(
self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
) -> Iterator[GenerationChunk]:
"""
Perform a streaming request to the LLM.
if isinstance(prompt, str):
prompt = [prompt]
Args:
sdk: The SVEndpointHandler to use for the prediction.
prompt: The prompt to use for the prediction.
tuning_params: The tuning parameters to use for the prediction.
params = self._get_tuning_params(stop)
Returns:
An iterator of GenerationChunks.
"""
for chunk in sdk.nlp_predict_stream(
self.sambastudio_project_id,
self.sambastudio_endpoint_id,
self.sambastudio_api_key,
prompt,
tuning_params,
):
if chunk["status_code"] != 200:
error = chunk.get("error")
if error:
optional_code = error.get("code")
optional_details = error.get("details")
optional_message = error.get("message")
raise ValueError(
f"Sambanova /complete call failed with status code "
f"{chunk['status_code']}.\n"
f"Message: {optional_message}\n"
f"Details: {optional_details}\n"
f"Code: {optional_code}\n"
# create request payload for OpenAI compatible API
if "openai" in self.sambastudio_url:
messages_dict = [{"role": "user", "content": prompt[0]}]
data = {"messages": messages_dict, "stream": streaming, **params}
data = {key: value for key, value in data.items() if value is not None}
headers = {
"Authorization": f"Bearer "
f"{self.sambastudio_api_key.get_secret_value()}",
"Content-Type": "application/json",
}
# create request payload for generic v2 API
elif "api/v2/predict/generic" in self.sambastudio_url:
if params.get("process_prompt", False):
prompt = json.dumps(
{
"conversation_id": "sambaverse-conversation-id",
"messages": [
{"message_id": None, "role": "user", "content": prompt[0]}
],
}
)
else:
prompt = prompt[0]
items = [{"id": "item0", "value": prompt}]
params = {key: value for key, value in params.items() if value is not None}
data = {"items": items, "params": params}
headers = {"key": self.sambastudio_api_key.get_secret_value()}
# create request payload for generic v1 API
elif "api/predict/generic" in self.sambastudio_url:
if params.get("process_prompt", False):
if params["process_prompt"].get("value") == "True":
prompt = json.dumps(
{
"conversation_id": "sambaverse-conversation-id",
"messages": [
{
"message_id": None,
"role": "user",
"content": prompt[0],
}
],
}
)
else:
prompt = prompt[0]
else:
prompt = prompt[0]
if streaming:
data = {"instance": prompt, "params": params}
else:
data = {"instances": [prompt], "params": params}
headers = {"key": self.sambastudio_api_key.get_secret_value()}
else:
raise ValueError(
f"Unsupported URL{self.sambastudio_url}"
"only openai, generic v1 and generic v2 APIs are supported"
)
# make the request to SambaStudio API
http_session = requests.Session()
if streaming:
response = http_session.post(
self.streaming_url, headers=headers, json=data, stream=True
)
else:
response = http_session.post(
self.base_url, headers=headers, json=data, stream=False
)
if response.status_code != 200:
raise RuntimeError(
f"Sambanova / complete call failed with status code "
f"{response.status_code}."
f"{response.text}."
)
return response
def _process_response(self, response: Response) -> str:
"""
Process a non-streaming response from the API
Args:
response: A request Response object
Returns:
completion: a string with model generation
"""
# Extract JSON payload from the response
try:
response_dict = response.json()
except Exception as e:
raise RuntimeError(
f"Sambanova /complete call failed couldn't get JSON response {e}"
f"response: {response.text}"
)
# process response payload for openai compatible API
if "openai" in self.sambastudio_url:
completion = response_dict["choices"][0]["message"]["content"]
# process response payload for generic v2 API
elif "api/v2/predict/generic" in self.sambastudio_url:
completion = response_dict["items"][0]["value"]["completion"]
# process response payload for generic v1 API
elif "api/predict/generic" in self.sambastudio_url:
completion = response_dict["predictions"][0]["completion"]
else:
raise ValueError(
f"Unsupported URL{self.sambastudio_url}"
"only openai, generic v1 and generic v2 APIs are supported"
)
return completion
def _process_stream_response(self, response: Response) -> Iterator[GenerationChunk]:
"""
Process a streaming response from the API
Args:
response: An iterable request Response object
Yields:
GenerationChunk: a GenerationChunk with model partial generation
"""
try:
import sseclient
except ImportError:
raise ImportError(
"could not import sseclient library"
"Please install it with `pip install sseclient-py`."
)
# process response payload for openai compatible API
if "openai" in self.sambastudio_url:
client = sseclient.SSEClient(response)
for event in client.events():
if event.event == "error_event":
raise RuntimeError(
f"Sambanova /complete call failed with status code "
f"{chunk['status_code']}."
f"{chunk}."
f"{response.status_code}."
f"{event.data}."
)
if "api/predict/nlp" in self.sambastudio_base_uri:
text = json.loads(chunk["data"])["stream_token"]
elif "api/v2/predict/generic" in self.sambastudio_base_uri:
text = chunk["result"]["items"][0]["value"]["stream_token"]
elif "api/predict/generic" in self.sambastudio_base_uri:
if len(chunk["result"]["responses"]) > 0:
text = chunk["result"]["responses"][0]["stream_token"]
else:
text = ""
else:
raise ValueError(
f"handling of endpoint uri: {self.sambastudio_base_uri}"
f"not implemented"
)
generated_chunk = GenerationChunk(text=text)
yield generated_chunk
try:
# check if the response is not a final event ("[DONE]")
if event.data != "[DONE]":
if isinstance(event.data, str):
data = json.loads(event.data)
else:
raise RuntimeError(
f"Sambanova /complete call failed with status code "
f"{response.status_code}."
f"{event.data}."
)
if data.get("error"):
raise RuntimeError(
f"Sambanova /complete call failed with status code "
f"{response.status_code}."
f"{event.data}."
)
if len(data["choices"]) > 0:
content = data["choices"][0]["delta"]["content"]
else:
content = ""
generated_chunk = GenerationChunk(text=content)
yield generated_chunk
except Exception as e:
raise RuntimeError(
f"Error getting content chunk raw streamed response: {e}"
f"data: {event.data}"
)
# process response payload for generic v2 API
elif "api/v2/predict/generic" in self.sambastudio_url:
for line in response.iter_lines():
try:
data = json.loads(line)
content = data["result"]["items"][0]["value"]["stream_token"]
generated_chunk = GenerationChunk(text=content)
yield generated_chunk
except Exception as e:
raise RuntimeError(
f"Error getting content chunk raw streamed response: {e}"
f"line: {line}"
)
# process response payload for generic v1 API
elif "api/predict/generic" in self.sambastudio_url:
for line in response.iter_lines():
try:
data = json.loads(line)
content = data["result"]["responses"][0]["stream_token"]
generated_chunk = GenerationChunk(text=content)
yield generated_chunk
except Exception as e:
raise RuntimeError(
f"Error getting content chunk raw streamed response: {e}"
f"line: {line}"
)
else:
raise ValueError(
f"Unsupported URL{self.sambastudio_url}"
"only openai, generic v1 and generic v2 APIs are supported"
)
def _stream(
self,
@@ -478,56 +498,16 @@ class SambaStudio(LLM):
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
stop: a list of strings on which the model should stop generating.
run_manager: A run manager with callbacks for the LLM.
Yields:
chunk: GenerationChunk with model partial generation
"""
ss_endpoint = SSEndpointHandler(
self.sambastudio_base_url, self.sambastudio_base_uri
)
tuning_params = self._get_tuning_params(stop)
try:
if self.streaming:
for chunk in self._handle_nlp_predict_stream(
ss_endpoint, prompt, tuning_params
):
if run_manager:
run_manager.on_llm_new_token(chunk.text)
yield chunk
else:
return
except Exception as e:
# Handle any errors raised by the inference endpoint
raise ValueError(f"Error raised by the inference endpoint: {e}") from e
def _handle_stream_request(
self,
prompt: Union[List[str], str],
stop: Optional[List[str]],
run_manager: Optional[CallbackManagerForLLMRun],
kwargs: Dict[str, Any],
) -> str:
"""
Perform a streaming request to the LLM.
Args:
prompt: The prompt to generate from.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of the stop substrings.
run_manager: Callback manager for the run.
kwargs: Additional keyword arguments. directly passed
to the sambastudio model in API call.
Returns:
The model output as a string.
"""
completion = ""
for chunk in self._stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
completion += chunk.text
return completion
response = self._handle_request(prompt, stop, streaming=True)
for chunk in self._process_stream_response(response):
if run_manager:
run_manager.on_llm_new_token(chunk.text)
yield chunk
def _call(
self,
@@ -540,17 +520,20 @@ class SambaStudio(LLM):
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
stop: a list of strings on which the model should stop generating.
Returns:
The string generated by the model.
result: string with model generation
"""
if stop is not None:
raise Exception("stop not implemented")
try:
if self.streaming:
return self._handle_stream_request(prompt, stop, run_manager, kwargs)
return self._handle_completion_requests(prompt, stop)
except Exception as e:
# Handle any errors raised by the inference endpoint
raise ValueError(f"Error raised by the inference endpoint: {e}") from e
if self.streaming:
completion = ""
for chunk in self._stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
completion += chunk.text
return completion
response = self._handle_request(prompt, stop, streaming=False)
completion = self._process_response(response)
return completion
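
For reference, a small standalone sketch of how the refactored class derives its non-streaming and streaming call URLs from the single `sambastudio_url` setting, mirroring the `_get_sambastudio_urls` logic above; the example URL is a hypothetical placeholder, not a documented endpoint.

```python
from typing import Tuple


def get_sambastudio_urls(url: str) -> Tuple[str, str]:
    """Derive (base_url, streaming_url) from one configured endpoint URL,
    following the same branching as SambaStudio._get_sambastudio_urls."""
    if "openai" in url:
        # OpenAI compatible endpoints use the same URL for both call styles
        return url, url
    if "stream" in url:
        # a streaming generic URL was given; strip "stream/" for the base URL
        return url.replace("stream/", ""), url
    if "generic" in url:
        # a non-streaming generic URL was given; insert "stream" for streaming
        return url, "generic/stream".join(url.split("generic"))
    raise ValueError("Unsupported URL")


# Hypothetical generic v2 endpoint URL:
base, stream = get_sambastudio_urls(
    "https://example.com/api/v2/predict/generic/project-id/endpoint-id"
)
print(base)    # https://example.com/api/v2/predict/generic/project-id/endpoint-id
print(stream)  # https://example.com/api/v2/predict/generic/stream/project-id/endpoint-id
```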


@@ -20,7 +20,7 @@ count=$(git grep -E '(@root_validator)|(@validator)|(@field_validator)|(@pre_ini
# PRs that increase the current count will not be accepted.
# PRs that decrease update the code in the repository
# and allow decreasing the count of are welcome!
current_count=127
current_count=126
if [ "$count" -gt "$current_count" ]; then
echo "The PR seems to be introducing new usage of @root_validator and/or @field_validator."
@ -52,4 +52,4 @@ if [ "$count" -gt "$current_count" ]; then
elif [ "$count" -lt "$current_count" ]; then
echo "Please update the $current_count variable in ./scripts/check_pydantic.sh to $count"
exit 1
fi
fi