mirror of
https://github.com/hwchase17/langchain
synced 2024-11-13 19:10:52 +00:00
community: sambastudio llm refactor (#27215)
**Description:** - Sambastudio LLM refactor - Sambastudio openai compatible API support added - docs updated
This commit is contained in:
parent
fe87e411f2
commit
8895d468cb
@ -1,190 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# SambaNova\n",
|
||||
"\n",
|
||||
"**[SambaNova](https://sambanova.ai/)'s** [Sambastudio](https://sambanova.ai/technology/full-stack-ai-platform) is a platform for running your own open-source models\n",
|
||||
"\n",
|
||||
"This example goes over how to use LangChain to interact with SambaNova models"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## SambaStudio"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**SambaStudio** allows you to train, run batch inference jobs, and deploy online inference endpoints to run open source models that you fine tuned yourself."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"A SambaStudio environment is required to deploy a model. Get more information at [sambanova.ai/products/enterprise-ai-platform-sambanova-suite](https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite)\n",
|
||||
"\n",
|
||||
"The [sseclient-py](https://pypi.org/project/sseclient-py/) package is required to run streaming predictions "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install --quiet sseclient-py==1.8.0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Register your environment variables:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"sambastudio_base_url = \"<Your SambaStudio environment URL>\"\n",
|
||||
"sambastudio_base_uri = \"<Your SambaStudio endpoint base URI>\" # optional, \"api/predict/generic\" set as default\n",
|
||||
"sambastudio_project_id = \"<Your SambaStudio project id>\"\n",
|
||||
"sambastudio_endpoint_id = \"<Your SambaStudio endpoint id>\"\n",
|
||||
"sambastudio_api_key = \"<Your SambaStudio endpoint API key>\"\n",
|
||||
"\n",
|
||||
"# Set the environment variables\n",
|
||||
"os.environ[\"SAMBASTUDIO_BASE_URL\"] = sambastudio_base_url\n",
|
||||
"os.environ[\"SAMBASTUDIO_BASE_URI\"] = sambastudio_base_uri\n",
|
||||
"os.environ[\"SAMBASTUDIO_PROJECT_ID\"] = sambastudio_project_id\n",
|
||||
"os.environ[\"SAMBASTUDIO_ENDPOINT_ID\"] = sambastudio_endpoint_id\n",
|
||||
"os.environ[\"SAMBASTUDIO_API_KEY\"] = sambastudio_api_key"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Call SambaStudio models directly from LangChain!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.llms.sambanova import SambaStudio\n",
|
||||
"\n",
|
||||
"llm = SambaStudio(\n",
|
||||
" streaming=False,\n",
|
||||
" model_kwargs={\n",
|
||||
" \"do_sample\": True,\n",
|
||||
" \"max_tokens_to_generate\": 1000,\n",
|
||||
" \"temperature\": 0.01,\n",
|
||||
" # \"repetition_penalty\": 1.0,\n",
|
||||
" # \"top_k\": 50,\n",
|
||||
" # \"top_logprobs\": 0,\n",
|
||||
" # \"top_p\": 1.0\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(llm.invoke(\"Why should I use open source models?\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Streaming response\n",
|
||||
"\n",
|
||||
"from langchain_community.llms.sambanova import SambaStudio\n",
|
||||
"\n",
|
||||
"llm = SambaStudio(\n",
|
||||
" streaming=True,\n",
|
||||
" model_kwargs={\n",
|
||||
" \"do_sample\": True,\n",
|
||||
" \"max_tokens_to_generate\": 1000,\n",
|
||||
" \"temperature\": 0.01,\n",
|
||||
" # \"repetition_penalty\": 1.0,\n",
|
||||
" # \"top_k\": 50,\n",
|
||||
" # \"top_logprobs\": 0,\n",
|
||||
" # \"top_p\": 1.0\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"for chunk in llm.stream(\"Why should I use open source models?\"):\n",
|
||||
" print(chunk, end=\"\", flush=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can also call a CoE endpoint expert model "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Using a CoE endpoint\n",
|
||||
"\n",
|
||||
"from langchain_community.llms.sambanova import SambaStudio\n",
|
||||
"\n",
|
||||
"llm = SambaStudio(\n",
|
||||
" streaming=False,\n",
|
||||
" model_kwargs={\n",
|
||||
" \"do_sample\": True,\n",
|
||||
" \"max_tokens_to_generate\": 1000,\n",
|
||||
" \"temperature\": 0.01,\n",
|
||||
" \"process_prompt\": False,\n",
|
||||
" \"select_expert\": \"Meta-Llama-3-8B-Instruct\",\n",
|
||||
" # \"repetition_penalty\": 1.0,\n",
|
||||
" # \"top_k\": 50,\n",
|
||||
" # \"top_logprobs\": 0,\n",
|
||||
" # \"top_p\": 1.0\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(llm.invoke(\"Why should I use open source models?\"))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
239
docs/docs/integrations/llms/sambastudio.ipynb
Normal file
239
docs/docs/integrations/llms/sambastudio.ipynb
Normal file
@ -0,0 +1,239 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# SambaStudio\n",
|
||||
"\n",
|
||||
"**[SambaNova](https://sambanova.ai/)'s** [Sambastudio](https://sambanova.ai/technology/full-stack-ai-platform) is a platform that allows you to train, run batch inference jobs, and deploy online inference endpoints to run open source models that you fine tuned yourself.\n",
|
||||
"\n",
|
||||
":::caution\n",
|
||||
"You are currently on a page documenting the use of SambaStudio models as [text completion models](/docs/concepts/#llms). We recommend you to use the [chat completion models](/docs/concepts/#chat-models).\n",
|
||||
"\n",
|
||||
"You may be looking for [SambaStudio Chat Models](/docs/integrations/chat/sambastudio/) .\n",
|
||||
":::\n",
|
||||
"\n",
|
||||
"## Overview\n",
|
||||
"### Integration details\n",
|
||||
"\n",
|
||||
"| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n",
|
||||
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
|
||||
"| [SambaStudio](https://python.langchain.com/api_reference/community/llms/langchain_community.llms.sambanova.SambaStudio.html) | [langchain_community](https://python.langchain.com/api_reference/community/index.html) | ❌ | beta | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain_community?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain_community?style=flat-square&label=%20) |\n",
|
||||
"\n",
|
||||
"This example goes over how to use LangChain to interact with SambaStudio models"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Setup\n",
|
||||
"\n",
|
||||
"### Credentials\n",
|
||||
"A SambaStudio environment is required to deploy a model. Get more information at [sambanova.ai/products/enterprise-ai-platform-sambanova-suite](https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite)\n",
|
||||
"\n",
|
||||
"you'll need to [deploy an endpoint](https://docs.sambanova.ai/sambastudio/latest/endpoints.html) and set the `SAMBASTUDIO_URL` and `SAMBASTUDIO_API_KEY` environment variables:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"if \"SAMBASTUDIO_URL\" not in os.environ:\n",
|
||||
" os.environ[\"SAMBASTUDIO_URL\"] = getpass.getpass()\n",
|
||||
"if \"SAMBASTUDIO_API_KEY\" not in os.environ:\n",
|
||||
" os.environ[\"SAMBASTUDIO_API_KEY\"] = getpass.getpass()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Installation\n",
|
||||
"\n",
|
||||
"The integration lives in the `langchain-community` package. We also need to install the [sseclient-py](https://pypi.org/project/sseclient-py/) package this is required to run streaming predictions "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install --quiet -U langchain-community sseclient-py"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Instantiation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.llms.sambanova import SambaStudio\n",
|
||||
"\n",
|
||||
"llm = SambaStudio(\n",
|
||||
" model_kwargs={\n",
|
||||
" \"do_sample\": True,\n",
|
||||
" \"max_tokens\": 1024,\n",
|
||||
" \"temperature\": 0.01,\n",
|
||||
" \"process_prompt\": True, # set if using CoE endpoints\n",
|
||||
" \"model\": \"Meta-Llama-3-70B-Instruct-4096\", # set if using CoE endpoints\n",
|
||||
" # \"repetition_penalty\": 1.0,\n",
|
||||
" # \"top_k\": 50,\n",
|
||||
" # \"top_logprobs\": 0,\n",
|
||||
" # \"top_p\": 1.0\n",
|
||||
" },\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Invocation\n",
|
||||
"\n",
|
||||
"Now we can instantiate our model object and generate chat completions:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"Using open source models can have numerous benefits. Here are some reasons why you should consider using open source models:\\n\\n1. **Cost-effective**: Open source models are often free to use, modify, and distribute, which can significantly reduce costs compared to proprietary models.\\n2. **Customizability**: Open source models can be modified to fit your specific needs, allowing you to tailor the model to your project's requirements.\\n3. **Transparency**: Open source models provide complete transparency into the model's architecture, training data, and algorithms, which can be essential for understanding how the model works and identifying potential biases.\\n4. **Community involvement**: Open source models are often maintained by a community of developers, researchers, and users, which can lead to faster bug fixes, new feature additions, and improved performance.\\n5. **Flexibility**: Open source models can be used in a variety of applications, from research to production, and can be easily integrated into different workflows and systems.\\n6. **Auditability**: With open source models, you can audit the model's performance, data, and algorithms, which is critical in regulated industries or when working with sensitive data.\\n7. **No vendor lock-in**: By using open source models, you're not tied to a specific vendor or proprietary technology, giving you more freedom to switch or modify your approach as needed.\\n8. **Improved security**: Open source models can be reviewed and audited by the community, which can help identify and fix security vulnerabilities more quickly.\\n9. **Access to cutting-edge research**: Open source models can provide access to the latest research and advancements in AI and machine learning, allowing you to leverage the work of experts in the field.\\n10. **Ethical considerations**: By using open source models, you can ensure that your AI systems are transparent, explainable, and fair, which is essential for building trust in AI applications.\\n11. **Reduced risk of bias**: Open source models can help reduce the risk of bias by providing transparency into the model's development, training data, and algorithms.\\n12. **Faster development**: Open source models can accelerate your development process by providing pre-trained models, datasets, and tools that can be easily integrated into your project.\\n13. **Improved collaboration**: Open source models can facilitate collaboration among researchers, developers, and organizations, leading to faster progress and innovation in AI and machine learning.\\n14. **Access to large datasets**: Open source models can provide access to large datasets, which can be essential for training and testing AI models.\\n15. **Compliance with regulations**: In some cases, using open source models can help ensure compliance with regulations, such as GDPR, HIPAA, or CCPA, which require transparency and explainability in AI systems.\\n\\nOverall, using open source models can provide numerous benefits, from cost savings to improved transparency and customizability. By leveraging open source models, you can accelerate your AI and machine learning projects while ensuring that your systems are transparent, explainable, and fair.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"input_text = \"Why should I use open source models?\"\n",
|
||||
"\n",
|
||||
"completion = llm.invoke(input_text)\n",
|
||||
"completion"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using open source models can have numerous benefits. Here are some reasons why you should consider using open source models:\n",
|
||||
"\n",
|
||||
"1. **Cost-effective**: Open source models are often free to use, modify, and distribute, which can significantly reduce costs compared to proprietary models.\n",
|
||||
"2. **Customizability**: Open source models can be modified to fit your specific needs, allowing you to tailor the model to your project's requirements.\n",
|
||||
"3. **Transparency**: Open source models provide complete transparency into the model's architecture, training data, and algorithms, which can be essential for understanding how the model works and identifying potential biases.\n",
|
||||
"4. **Community involvement**: Open source models are often maintained by a community of developers, researchers, and users, which can lead to faster bug fixes, new feature additions, and improved performance.\n",
|
||||
"5. **Flexibility**: Open source models can be used in a variety of applications, from research to production, and can be easily integrated into different workflows and systems.\n",
|
||||
"6. **Auditability**: With open source models, you can audit the model's performance, data, and algorithms, which is critical in regulated industries or when working with sensitive data.\n",
|
||||
"7. **No vendor lock-in**: By using open source models, you're not tied to a specific vendor or proprietary technology, giving you more freedom to switch or modify your approach as needed.\n",
|
||||
"8. **Improved security**: Open source models can be reviewed and audited by the community, which can help identify and fix security vulnerabilities more quickly.\n",
|
||||
"9. **Access to cutting-edge research**: Open source models can provide access to the latest research and advancements in AI and machine learning, allowing you to leverage the work of experts in the field.\n",
|
||||
"10. **Ethical considerations**: By using open source models, you can ensure that your AI systems are transparent, explainable, and fair, which is essential for building trust in AI applications.\n",
|
||||
"11. **Reduced risk of bias**: Open source models can help reduce the risk of bias by providing transparency into the model's development, training data, and algorithms.\n",
|
||||
"12. **Faster development**: Open source models can accelerate your development process by providing pre-trained models, datasets, and tools that can be easily integrated into your project.\n",
|
||||
"13. **Improved collaboration**: Open source models can facilitate collaboration among researchers, developers, and organizations, leading to faster progress and innovation in AI and machine learning.\n",
|
||||
"14. **Access to large datasets**: Open source models can provide access to large datasets, which can be essential for training and testing AI models.\n",
|
||||
"15. **Compliance with regulations**: In some cases, using open source models can help ensure compliance with regulations, such as GDPR, HIPAA, or CCPA, which require transparency and explainability in AI systems.\n",
|
||||
"\n",
|
||||
"Overall, using open source models can provide numerous benefits, from cost savings to improved transparency and customizability. By leveraging open source models, you can accelerate your AI and machine learning projects while ensuring that your systems are transparent, explainable, and fair."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Streaming response\n",
|
||||
"for chunk in llm.stream(\"Why should I use open source models?\"):\n",
|
||||
" print(chunk, end=\"\", flush=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Chaining"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'In German, you can say:\\n\\n\"Ich liebe das Programmieren.\"\\n\\nHere\\'s a breakdown of the sentence:\\n\\n* \"Ich\" means \"I\"\\n* \"liebe\" is the verb \"to love\" in the present tense, first person singular (I love)\\n* \"das\" is the definite article \"the\"\\n* \"Programmieren\" is the noun \"programming\"\\n\\nSo, \"Ich liebe das Programmieren\" literally means \"I love the programming\".\\n\\nIf you want to make it sound more casual, you can say:\\n\\n\"Ich liebe\\'s Programmieren.\"\\n\\nThe apostrophe in \"liebe\\'s\" is a contraction of \"liebe es\", which is a more informal way of saying \"I love it\".\\n\\nAlternatively, you can also say:\\n\\n\"Programmieren ist meine Leidenschaft.\"\\n\\nThis sentence means \"Programming is my passion\". Here, \"Programmieren\" is the subject, \"ist\" is the verb \"to be\" in the present tense, and \"meine Leidenschaft\" means \"my passion\".'"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_core.prompts import PromptTemplate\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate.from_template(\"How to say {input} in {output_language}:\\n\")\n",
|
||||
"\n",
|
||||
"chain = prompt | llm\n",
|
||||
"chain.invoke(\n",
|
||||
" {\n",
|
||||
" \"output_language\": \"German\",\n",
|
||||
" \"input\": \"I love programming.\",\n",
|
||||
" }\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## API reference\n",
|
||||
"\n",
|
||||
"For detailed documentation of all `SambaStudio` llm features and configurations head to the API reference: https://python.langchain.com/api_reference/community/llms/langchain_community.llms.sambanova.SambaStudio.html"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "multimodalenv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
@ -1,318 +1,195 @@
|
||||
import json
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Union
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.outputs import GenerationChunk
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
class SSEndpointHandler:
|
||||
"""
|
||||
SambaNova Systems Interface for SambaStudio model endpoints.
|
||||
|
||||
:param str host_url: Base URL of the DaaS API service
|
||||
"""
|
||||
|
||||
def __init__(self, host_url: str, api_base_uri: str):
|
||||
"""
|
||||
Initialize the SSEndpointHandler.
|
||||
|
||||
:param str host_url: Base URL of the DaaS API service
|
||||
:param str api_base_uri: Base URI of the DaaS API service
|
||||
"""
|
||||
self.host_url = host_url
|
||||
self.api_base_uri = api_base_uri
|
||||
self.http_session = requests.Session()
|
||||
|
||||
def _process_response(self, response: requests.Response) -> Dict:
|
||||
"""
|
||||
Processes the API response and returns the resulting dict.
|
||||
|
||||
All resulting dicts, regardless of success or failure, will contain the
|
||||
`status_code` key with the API response status code.
|
||||
|
||||
If the API returned an error, the resulting dict will contain the key
|
||||
`detail` with the error message.
|
||||
|
||||
If the API call was successful, the resulting dict will contain the key
|
||||
`data` with the response data.
|
||||
|
||||
:param requests.Response response: the response object to process
|
||||
:return: the response dict
|
||||
:type: dict
|
||||
"""
|
||||
result: Dict[str, Any] = {}
|
||||
try:
|
||||
result = response.json()
|
||||
except Exception as e:
|
||||
result["detail"] = str(e)
|
||||
if "status_code" not in result:
|
||||
result["status_code"] = response.status_code
|
||||
return result
|
||||
|
||||
def _process_streaming_response(
|
||||
self,
|
||||
response: requests.Response,
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""Process the streaming response"""
|
||||
if "api/predict/nlp" in self.api_base_uri:
|
||||
try:
|
||||
import sseclient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"could not import sseclient library"
|
||||
"Please install it with `pip install sseclient-py`."
|
||||
)
|
||||
client = sseclient.SSEClient(response)
|
||||
close_conn = False
|
||||
for event in client.events():
|
||||
if event.event == "error_event":
|
||||
close_conn = True
|
||||
chunk = {
|
||||
"event": event.event,
|
||||
"data": event.data,
|
||||
"status_code": response.status_code,
|
||||
}
|
||||
yield chunk
|
||||
if close_conn:
|
||||
client.close()
|
||||
elif (
|
||||
"api/v2/predict/generic" in self.api_base_uri
|
||||
or "api/predict/generic" in self.api_base_uri
|
||||
):
|
||||
try:
|
||||
for line in response.iter_lines():
|
||||
chunk = json.loads(line)
|
||||
if "status_code" not in chunk:
|
||||
chunk["status_code"] = response.status_code
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error processing streaming response: {e}")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"handling of endpoint uri: {self.api_base_uri} not implemented"
|
||||
)
|
||||
|
||||
def _get_full_url(self, path: str) -> str:
|
||||
"""
|
||||
Return the full API URL for a given path.
|
||||
|
||||
:param str path: the sub-path
|
||||
:returns: the full API URL for the sub-path
|
||||
:type: str
|
||||
"""
|
||||
return f"{self.host_url}/{self.api_base_uri}/{path}"
|
||||
|
||||
def nlp_predict(
|
||||
self,
|
||||
project: str,
|
||||
endpoint: str,
|
||||
key: str,
|
||||
input: Union[List[str], str],
|
||||
params: Optional[str] = "",
|
||||
stream: bool = False,
|
||||
) -> Dict:
|
||||
"""
|
||||
NLP predict using inline input string.
|
||||
|
||||
:param str project: Project ID in which the endpoint exists
|
||||
:param str endpoint: Endpoint ID
|
||||
:param str key: API Key
|
||||
:param str input_str: Input string
|
||||
:param str params: Input params string
|
||||
:returns: Prediction results
|
||||
:type: dict
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
if "api/predict/nlp" in self.api_base_uri:
|
||||
if params:
|
||||
data = {"inputs": input, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"inputs": input}
|
||||
elif "api/v2/predict/generic" in self.api_base_uri:
|
||||
items = [{"id": f"item{i}", "value": item} for i, item in enumerate(input)]
|
||||
if params:
|
||||
data = {"items": items, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"items": items}
|
||||
elif "api/predict/generic" in self.api_base_uri:
|
||||
if params:
|
||||
data = {"instances": input, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"instances": input}
|
||||
else:
|
||||
raise ValueError(
|
||||
f"handling of endpoint uri: {self.api_base_uri} not implemented"
|
||||
)
|
||||
response = self.http_session.post(
|
||||
self._get_full_url(f"{project}/{endpoint}"),
|
||||
headers={"key": key},
|
||||
json=data,
|
||||
)
|
||||
return self._process_response(response)
|
||||
|
||||
def nlp_predict_stream(
|
||||
self,
|
||||
project: str,
|
||||
endpoint: str,
|
||||
key: str,
|
||||
input: Union[List[str], str],
|
||||
params: Optional[str] = "",
|
||||
) -> Iterator[Dict]:
|
||||
"""
|
||||
NLP predict using inline input string.
|
||||
|
||||
:param str project: Project ID in which the endpoint exists
|
||||
:param str endpoint: Endpoint ID
|
||||
:param str key: API Key
|
||||
:param str input_str: Input string
|
||||
:param str params: Input params string
|
||||
:returns: Prediction results
|
||||
:type: dict
|
||||
"""
|
||||
if "api/predict/nlp" in self.api_base_uri:
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
if params:
|
||||
data = {"inputs": input, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"inputs": input}
|
||||
elif "api/v2/predict/generic" in self.api_base_uri:
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
items = [{"id": f"item{i}", "value": item} for i, item in enumerate(input)]
|
||||
if params:
|
||||
data = {"items": items, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"items": items}
|
||||
elif "api/predict/generic" in self.api_base_uri:
|
||||
if isinstance(input, list):
|
||||
input = input[0]
|
||||
if params:
|
||||
data = {"instance": input, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"instance": input}
|
||||
else:
|
||||
raise ValueError(
|
||||
f"handling of endpoint uri: {self.api_base_uri} not implemented"
|
||||
)
|
||||
# Streaming output
|
||||
response = self.http_session.post(
|
||||
self._get_full_url(f"stream/{project}/{endpoint}"),
|
||||
headers={"key": key},
|
||||
json=data,
|
||||
stream=True,
|
||||
)
|
||||
for chunk in self._process_streaming_response(response):
|
||||
yield chunk
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from pydantic import Field, SecretStr
|
||||
from requests import Response
|
||||
|
||||
|
||||
class SambaStudio(LLM):
|
||||
"""
|
||||
SambaStudio large language models.
|
||||
|
||||
To use, you should have the environment variables
|
||||
``SAMBASTUDIO_BASE_URL`` set with your SambaStudio environment URL.
|
||||
``SAMBASTUDIO_BASE_URI`` set with your SambaStudio api base URI.
|
||||
``SAMBASTUDIO_PROJECT_ID`` set with your SambaStudio project ID.
|
||||
``SAMBASTUDIO_ENDPOINT_ID`` set with your SambaStudio endpoint ID.
|
||||
``SAMBASTUDIO_API_KEY`` set with your SambaStudio endpoint API key.
|
||||
Setup:
|
||||
To use, you should have the environment variables
|
||||
``SAMBASTUDIO_URL`` set with your SambaStudio environment URL.
|
||||
``SAMBASTUDIO_API_KEY`` set with your SambaStudio endpoint API key.
|
||||
https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite
|
||||
read extra documentation in https://docs.sambanova.ai/sambastudio/latest/index.html
|
||||
Example:
|
||||
.. code-block:: python
|
||||
from langchain_community.llms.sambanova import SambaStudio
|
||||
SambaStudio(
|
||||
sambastudio_url="your-SambaStudio-environment-URL",
|
||||
sambastudio_api_key="your-SambaStudio-API-key,
|
||||
model_kwargs={
|
||||
"model" : model or expert name (set for CoE endpoints),
|
||||
"max_tokens" : max number of tokens to generate,
|
||||
"temperature" : model temperature,
|
||||
"top_p" : model top p,
|
||||
"top_k" : model top k,
|
||||
"do_sample" : wether to do sample
|
||||
"process_prompt": wether to process prompt
|
||||
(set for CoE generic v1 and v2 endpoints)
|
||||
},
|
||||
)
|
||||
Key init args — completion params:
|
||||
model: str
|
||||
The name of the model to use, e.g., Meta-Llama-3-70B-Instruct-4096
|
||||
(set for CoE endpoints).
|
||||
streaming: bool
|
||||
Whether to use streaming handler when using non streaming methods
|
||||
model_kwargs: dict
|
||||
Extra Key word arguments to pass to the model:
|
||||
max_tokens: int
|
||||
max tokens to generate
|
||||
temperature: float
|
||||
model temperature
|
||||
top_p: float
|
||||
model top p
|
||||
top_k: int
|
||||
model top k
|
||||
do_sample: bool
|
||||
wether to do sample
|
||||
process_prompt:
|
||||
wether to process prompt (set for CoE generic v1 and v2 endpoints)
|
||||
Key init args — client params:
|
||||
sambastudio_url: str
|
||||
SambaStudio endpoint Url
|
||||
sambastudio_api_key: str
|
||||
SambaStudio endpoint api key
|
||||
|
||||
https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite
|
||||
Instantiate:
|
||||
.. code-block:: python
|
||||
|
||||
read extra documentation in https://docs.sambanova.ai/sambastudio/latest/index.html
|
||||
from langchain_community.llms import SambaStudio
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
llm = SambaStudio=(
|
||||
sambastudio_url = set with your SambaStudio deployed endpoint URL,
|
||||
sambastudio_api_key = set with your SambaStudio deployed endpoint Key,
|
||||
model_kwargs = {
|
||||
"model" : model or expert name (set for CoE endpoints),
|
||||
"max_tokens" : max number of tokens to generate,
|
||||
"temperature" : model temperature,
|
||||
"top_p" : model top p,
|
||||
"top_k" : model top k,
|
||||
"do_sample" : wether to do sample
|
||||
"process_prompt" : wether to process prompt
|
||||
(set for CoE generic v1 and v2 endpoints)
|
||||
}
|
||||
)
|
||||
|
||||
Invoke:
|
||||
.. code-block:: python
|
||||
prompt = "tell me a joke"
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
Stream:
|
||||
.. code-block:: python
|
||||
|
||||
for chunk in llm.stream(prompt):
|
||||
print(chunk, end="", flush=True)
|
||||
|
||||
Async:
|
||||
.. code-block:: python
|
||||
|
||||
response = llm.ainvoke(prompt)
|
||||
await response
|
||||
|
||||
from langchain_community.llms.sambanova import SambaStudio
|
||||
SambaStudio(
|
||||
sambastudio_base_url="your-SambaStudio-environment-URL",
|
||||
sambastudio_base_uri="your-SambaStudio-base-URI",
|
||||
sambastudio_project_id="your-SambaStudio-project-ID",
|
||||
sambastudio_endpoint_id="your-SambaStudio-endpoint-ID",
|
||||
sambastudio_api_key="your-SambaStudio-endpoint-API-key,
|
||||
streaming=False
|
||||
model_kwargs={
|
||||
"do_sample": False,
|
||||
"max_tokens_to_generate": 1000,
|
||||
"temperature": 0.7,
|
||||
"top_p": 1.0,
|
||||
"repetition_penalty": 1,
|
||||
"top_k": 50,
|
||||
#"process_prompt": False,
|
||||
#"select_expert": "Meta-Llama-3-8B-Instruct"
|
||||
},
|
||||
)
|
||||
"""
|
||||
|
||||
sambastudio_base_url: str = ""
|
||||
"""Base url to use"""
|
||||
sambastudio_url: str = Field(default="")
|
||||
"""SambaStudio Url"""
|
||||
|
||||
sambastudio_base_uri: str = ""
|
||||
"""endpoint base uri"""
|
||||
sambastudio_api_key: SecretStr = Field(default="")
|
||||
"""SambaStudio api key"""
|
||||
|
||||
sambastudio_project_id: str = ""
|
||||
"""Project id on sambastudio for model"""
|
||||
base_url: str = Field(default="", exclude=True)
|
||||
"""SambaStudio non streaming URL"""
|
||||
|
||||
sambastudio_endpoint_id: str = ""
|
||||
"""endpoint id on sambastudio for model"""
|
||||
streaming_url: str = Field(default="", exclude=True)
|
||||
"""SambaStudio streaming URL"""
|
||||
|
||||
sambastudio_api_key: str = ""
|
||||
"""sambastudio api key"""
|
||||
streaming: bool = Field(default=False)
|
||||
"""Whether to use streaming handler when using non streaming methods"""
|
||||
|
||||
model_kwargs: Optional[dict] = None
|
||||
model_kwargs: Optional[Dict[str, Any]] = None
|
||||
"""Key word arguments to pass to the model."""
|
||||
|
||||
streaming: Optional[bool] = False
|
||||
"""Streaming flag to get streamed response."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
)
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this model can be serialized by Langchain."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {
|
||||
"sambastudio_url": "sambastudio_url",
|
||||
"sambastudio_api_key": "sambastudio_api_key",
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {**{"model_kwargs": self.model_kwargs}}
|
||||
"""Return a dictionary of identifying parameters.
|
||||
|
||||
This information is used by the LangChain callback system, which
|
||||
is used for tracing purposes make it possible to monitor LLMs.
|
||||
"""
|
||||
return {"streaming": self.streaming, **{"model_kwargs": self.model_kwargs}}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "Sambastudio LLM"
|
||||
return "sambastudio-llm"
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["sambastudio_base_url"] = get_from_dict_or_env(
|
||||
values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL"
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""init and validate environment variables"""
|
||||
kwargs["sambastudio_url"] = get_from_dict_or_env(
|
||||
kwargs, "sambastudio_url", "SAMBASTUDIO_URL"
|
||||
)
|
||||
values["sambastudio_base_uri"] = get_from_dict_or_env(
|
||||
values,
|
||||
"sambastudio_base_uri",
|
||||
"SAMBASTUDIO_BASE_URI",
|
||||
default="api/predict/generic",
|
||||
)
|
||||
values["sambastudio_project_id"] = get_from_dict_or_env(
|
||||
values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID"
|
||||
)
|
||||
values["sambastudio_endpoint_id"] = get_from_dict_or_env(
|
||||
values, "sambastudio_endpoint_id", "SAMBASTUDIO_ENDPOINT_ID"
|
||||
)
|
||||
values["sambastudio_api_key"] = get_from_dict_or_env(
|
||||
values, "sambastudio_api_key", "SAMBASTUDIO_API_KEY"
|
||||
)
|
||||
return values
|
||||
|
||||
def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
|
||||
kwargs["sambastudio_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(kwargs, "sambastudio_api_key", "SAMBASTUDIO_API_KEY")
|
||||
)
|
||||
kwargs["base_url"], kwargs["streaming_url"] = self._get_sambastudio_urls(
|
||||
kwargs["sambastudio_url"]
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _get_sambastudio_urls(self, url: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Get streaming and non streaming URLs from the given URL
|
||||
|
||||
Args:
|
||||
url: string with sambastudio base or streaming endpoint url
|
||||
|
||||
Returns:
|
||||
base_url: string with url to do non streaming calls
|
||||
streaming_url: string with url to do streaming calls
|
||||
"""
|
||||
if "openai" in url:
|
||||
base_url = url
|
||||
stream_url = url
|
||||
else:
|
||||
if "stream" in url:
|
||||
base_url = url.replace("stream/", "")
|
||||
stream_url = url
|
||||
else:
|
||||
base_url = url
|
||||
if "generic" in url:
|
||||
stream_url = "generic/stream".join(url.split("generic"))
|
||||
else:
|
||||
raise ValueError("Unsupported URL")
|
||||
return base_url, stream_url
|
||||
|
||||
def _get_tuning_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the tuning parameters to use when calling the LLM.
|
||||
|
||||
@ -321,151 +198,294 @@ class SambaStudio(LLM):
|
||||
first occurrence of any of the stop substrings.
|
||||
|
||||
Returns:
|
||||
The tuning parameters as a JSON string.
|
||||
The tuning parameters in the format required by api to use
|
||||
"""
|
||||
if stop is None:
|
||||
stop = []
|
||||
|
||||
# get the parameters to use when calling the LLM.
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
_kwarg_stop_sequences = _model_kwargs.get("stop_sequences", [])
|
||||
_stop_sequences = stop or _kwarg_stop_sequences
|
||||
# if not _kwarg_stop_sequences:
|
||||
# _model_kwargs["stop_sequences"] = ",".join(
|
||||
# f'"{x}"' for x in _stop_sequences
|
||||
# )
|
||||
if "api/v2/predict/generic" in self.sambastudio_base_uri:
|
||||
tuning_params_dict = _model_kwargs
|
||||
else:
|
||||
tuning_params_dict = {
|
||||
|
||||
# handle the case where stop sequences are send in the invocation
|
||||
# and stop sequences has been also set in the model parameters
|
||||
_stop_sequences = _model_kwargs.get("stop_sequences", []) + stop
|
||||
if len(_stop_sequences) > 0:
|
||||
_model_kwargs["stop_sequences"] = _stop_sequences
|
||||
|
||||
# set the parameters structure depending of the API
|
||||
if "openai" in self.sambastudio_url:
|
||||
if "select_expert" in _model_kwargs.keys():
|
||||
_model_kwargs["model"] = _model_kwargs.pop("select_expert")
|
||||
if "max_tokens_to_generate" in _model_kwargs.keys():
|
||||
_model_kwargs["max_tokens"] = _model_kwargs.pop(
|
||||
"max_tokens_to_generate"
|
||||
)
|
||||
if "process_prompt" in _model_kwargs.keys():
|
||||
_model_kwargs.pop("process_prompt")
|
||||
tuning_params = _model_kwargs
|
||||
|
||||
elif "api/v2/predict/generic" in self.sambastudio_url:
|
||||
if "model" in _model_kwargs.keys():
|
||||
_model_kwargs["select_expert"] = _model_kwargs.pop("model")
|
||||
if "max_tokens" in _model_kwargs.keys():
|
||||
_model_kwargs["max_tokens_to_generate"] = _model_kwargs.pop(
|
||||
"max_tokens"
|
||||
)
|
||||
tuning_params = _model_kwargs
|
||||
|
||||
elif "api/predict/generic" in self.sambastudio_url:
|
||||
if "model" in _model_kwargs.keys():
|
||||
_model_kwargs["select_expert"] = _model_kwargs.pop("model")
|
||||
if "max_tokens" in _model_kwargs.keys():
|
||||
_model_kwargs["max_tokens_to_generate"] = _model_kwargs.pop(
|
||||
"max_tokens"
|
||||
)
|
||||
|
||||
tuning_params = {
|
||||
k: {"type": type(v).__name__, "value": str(v)}
|
||||
for k, v in (_model_kwargs.items())
|
||||
}
|
||||
# _model_kwargs["stop_sequences"] = _kwarg_stop_sequences
|
||||
tuning_params = json.dumps(tuning_params_dict)
|
||||
return tuning_params
|
||||
|
||||
def _handle_nlp_predict(
|
||||
self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
|
||||
) -> str:
|
||||
"""
|
||||
Perform an NLP prediction using the SambaStudio endpoint handler.
|
||||
|
||||
Args:
|
||||
sdk: The SSEndpointHandler to use for the prediction.
|
||||
prompt: The prompt to use for the prediction.
|
||||
tuning_params: The tuning parameters to use for the prediction.
|
||||
|
||||
Returns:
|
||||
The prediction result.
|
||||
|
||||
Raises:
|
||||
ValueError: If the prediction fails.
|
||||
"""
|
||||
response = sdk.nlp_predict(
|
||||
self.sambastudio_project_id,
|
||||
self.sambastudio_endpoint_id,
|
||||
self.sambastudio_api_key,
|
||||
prompt,
|
||||
tuning_params,
|
||||
)
|
||||
if response["status_code"] != 200:
|
||||
optional_detail = response.get("detail")
|
||||
if optional_detail:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response['status_code']}.\n Details: {optional_detail}"
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response['status_code']}.\n response {response}"
|
||||
)
|
||||
if "api/predict/nlp" in self.sambastudio_base_uri:
|
||||
return response["data"][0]["completion"]
|
||||
elif "api/v2/predict/generic" in self.sambastudio_base_uri:
|
||||
return response["items"][0]["value"]["completion"]
|
||||
elif "api/predict/generic" in self.sambastudio_base_uri:
|
||||
return response["predictions"][0]["completion"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"handling of endpoint uri: {self.sambastudio_base_uri} not implemented"
|
||||
f"Unsupported URL{self.sambastudio_url}"
|
||||
"only openai, generic v1 and generic v2 APIs are supported"
|
||||
)
|
||||
|
||||
def _handle_completion_requests(
|
||||
self, prompt: Union[List[str], str], stop: Optional[List[str]]
|
||||
) -> str:
|
||||
return tuning_params
|
||||
|
||||
def _handle_request(
|
||||
self,
|
||||
prompt: Union[List[str], str],
|
||||
stop: Optional[List[str]] = None,
|
||||
streaming: Optional[bool] = False,
|
||||
) -> Response:
|
||||
"""
|
||||
Perform a prediction using the SambaStudio endpoint handler.
|
||||
Performs a post request to the LLM API.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to use for the prediction.
|
||||
stop: stop sequences.
|
||||
prompt: The prompt to pass into the model
|
||||
stop: list of stop tokens
|
||||
streaming: wether to do a streaming call
|
||||
|
||||
Returns:
|
||||
The prediction result.
|
||||
|
||||
Raises:
|
||||
ValueError: If the prediction fails.
|
||||
A request Response object
|
||||
"""
|
||||
ss_endpoint = SSEndpointHandler(
|
||||
self.sambastudio_base_url, self.sambastudio_base_uri
|
||||
)
|
||||
tuning_params = self._get_tuning_params(stop)
|
||||
return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)
|
||||
|
||||
def _handle_nlp_predict_stream(
|
||||
self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
|
||||
) -> Iterator[GenerationChunk]:
|
||||
"""
|
||||
Perform a streaming request to the LLM.
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
Args:
|
||||
sdk: The SVEndpointHandler to use for the prediction.
|
||||
prompt: The prompt to use for the prediction.
|
||||
tuning_params: The tuning parameters to use for the prediction.
|
||||
params = self._get_tuning_params(stop)
|
||||
|
||||
Returns:
|
||||
An iterator of GenerationChunks.
|
||||
"""
|
||||
for chunk in sdk.nlp_predict_stream(
|
||||
self.sambastudio_project_id,
|
||||
self.sambastudio_endpoint_id,
|
||||
self.sambastudio_api_key,
|
||||
prompt,
|
||||
tuning_params,
|
||||
):
|
||||
if chunk["status_code"] != 200:
|
||||
error = chunk.get("error")
|
||||
if error:
|
||||
optional_code = error.get("code")
|
||||
optional_details = error.get("details")
|
||||
optional_message = error.get("message")
|
||||
raise ValueError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{chunk['status_code']}.\n"
|
||||
f"Message: {optional_message}\n"
|
||||
f"Details: {optional_details}\n"
|
||||
f"Code: {optional_code}\n"
|
||||
# create request payload for openAI v1 API
|
||||
if "openai" in self.sambastudio_url:
|
||||
messages_dict = [{"role": "user", "content": prompt[0]}]
|
||||
data = {"messages": messages_dict, "stream": streaming, **params}
|
||||
data = {key: value for key, value in data.items() if value is not None}
|
||||
headers = {
|
||||
"Authorization": f"Bearer "
|
||||
f"{self.sambastudio_api_key.get_secret_value()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# create request payload for generic v1 API
|
||||
elif "api/v2/predict/generic" in self.sambastudio_url:
|
||||
if params.get("process_prompt", False):
|
||||
prompt = json.dumps(
|
||||
{
|
||||
"conversation_id": "sambaverse-conversation-id",
|
||||
"messages": [
|
||||
{"message_id": None, "role": "user", "content": prompt[0]}
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
prompt = prompt[0]
|
||||
items = [{"id": "item0", "value": prompt}]
|
||||
params = {key: value for key, value in params.items() if value is not None}
|
||||
data = {"items": items, "params": params}
|
||||
headers = {"key": self.sambastudio_api_key.get_secret_value()}
|
||||
|
||||
# create request payload for generic v1 API
|
||||
elif "api/predict/generic" in self.sambastudio_url:
|
||||
if params.get("process_prompt", False):
|
||||
if params["process_prompt"].get("value") == "True":
|
||||
prompt = json.dumps(
|
||||
{
|
||||
"conversation_id": "sambaverse-conversation-id",
|
||||
"messages": [
|
||||
{
|
||||
"message_id": None,
|
||||
"role": "user",
|
||||
"content": prompt[0],
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
prompt = prompt[0]
|
||||
else:
|
||||
prompt = prompt[0]
|
||||
if streaming:
|
||||
data = {"instance": prompt, "params": params}
|
||||
else:
|
||||
data = {"instances": [prompt], "params": params}
|
||||
headers = {"key": self.sambastudio_api_key.get_secret_value()}
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported URL{self.sambastudio_url}"
|
||||
"only openai, generic v1 and generic v2 APIs are supported"
|
||||
)
|
||||
|
||||
# make the request to SambaStudio API
|
||||
http_session = requests.Session()
|
||||
if streaming:
|
||||
response = http_session.post(
|
||||
self.streaming_url, headers=headers, json=data, stream=True
|
||||
)
|
||||
else:
|
||||
response = http_session.post(
|
||||
self.base_url, headers=headers, json=data, stream=False
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Sambanova / complete call failed with status code "
|
||||
f"{response.status_code}."
|
||||
f"{response.text}."
|
||||
)
|
||||
return response
|
||||
|
||||
def _process_response(self, response: Response) -> str:
|
||||
"""
|
||||
Process a non streaming response from the api
|
||||
|
||||
Args:
|
||||
response: A request Response object
|
||||
|
||||
Returns
|
||||
completion: a string with model generation
|
||||
"""
|
||||
|
||||
# Extract json payload form response
|
||||
try:
|
||||
response_dict = response.json()
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed couldn't get JSON response {e}"
|
||||
f"response: {response.text}"
|
||||
)
|
||||
|
||||
# process response payload for openai compatible API
|
||||
if "openai" in self.sambastudio_url:
|
||||
completion = response_dict["choices"][0]["message"]["content"]
|
||||
# process response payload for generic v2 API
|
||||
elif "api/v2/predict/generic" in self.sambastudio_url:
|
||||
completion = response_dict["items"][0]["value"]["completion"]
|
||||
# process response payload for generic v1 API
|
||||
elif "api/predict/generic" in self.sambastudio_url:
|
||||
completion = response_dict["predictions"][0]["completion"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported URL{self.sambastudio_url}"
|
||||
"only openai, generic v1 and generic v2 APIs are supported"
|
||||
)
|
||||
return completion
|
||||
|
||||
def _process_stream_response(self, response: Response) -> Iterator[GenerationChunk]:
|
||||
"""
|
||||
Process a streaming response from the api
|
||||
|
||||
Args:
|
||||
response: An iterable request Response object
|
||||
|
||||
Yields:
|
||||
GenerationChunk: a GenerationChunk with model partial generation
|
||||
"""
|
||||
|
||||
try:
|
||||
import sseclient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"could not import sseclient library"
|
||||
"Please install it with `pip install sseclient-py`."
|
||||
)
|
||||
|
||||
# process response payload for openai compatible API
|
||||
if "openai" in self.sambastudio_url:
|
||||
client = sseclient.SSEClient(response)
|
||||
for event in client.events():
|
||||
if event.event == "error_event":
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{chunk['status_code']}."
|
||||
f"{chunk}."
|
||||
f"{response.status_code}."
|
||||
f"{event.data}."
|
||||
)
|
||||
if "api/predict/nlp" in self.sambastudio_base_uri:
|
||||
text = json.loads(chunk["data"])["stream_token"]
|
||||
elif "api/v2/predict/generic" in self.sambastudio_base_uri:
|
||||
text = chunk["result"]["items"][0]["value"]["stream_token"]
|
||||
elif "api/predict/generic" in self.sambastudio_base_uri:
|
||||
if len(chunk["result"]["responses"]) > 0:
|
||||
text = chunk["result"]["responses"][0]["stream_token"]
|
||||
else:
|
||||
text = ""
|
||||
else:
|
||||
raise ValueError(
|
||||
f"handling of endpoint uri: {self.sambastudio_base_uri}"
|
||||
f"not implemented"
|
||||
)
|
||||
generated_chunk = GenerationChunk(text=text)
|
||||
yield generated_chunk
|
||||
try:
|
||||
# check if the response is not a final event ("[DONE]")
|
||||
if event.data != "[DONE]":
|
||||
if isinstance(event.data, str):
|
||||
data = json.loads(event.data)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response.status_code}."
|
||||
f"{event.data}."
|
||||
)
|
||||
if data.get("error"):
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response.status_code}."
|
||||
f"{event.data}."
|
||||
)
|
||||
if len(data["choices"]) > 0:
|
||||
content = data["choices"][0]["delta"]["content"]
|
||||
else:
|
||||
content = ""
|
||||
generated_chunk = GenerationChunk(text=content)
|
||||
yield generated_chunk
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Error getting content chunk raw streamed response: {e}"
|
||||
f"data: {event.data}"
|
||||
)
|
||||
|
||||
# process response payload for generic v2 API
|
||||
elif "api/v2/predict/generic" in self.sambastudio_url:
|
||||
for line in response.iter_lines():
|
||||
try:
|
||||
data = json.loads(line)
|
||||
content = data["result"]["items"][0]["value"]["stream_token"]
|
||||
generated_chunk = GenerationChunk(text=content)
|
||||
yield generated_chunk
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Error getting content chunk raw streamed response: {e}"
|
||||
f"line: {line}"
|
||||
)
|
||||
|
||||
# process response payload for generic v1 API
|
||||
elif "api/predict/generic" in self.sambastudio_url:
|
||||
for line in response.iter_lines():
|
||||
try:
|
||||
data = json.loads(line)
|
||||
content = data["result"]["responses"][0]["stream_token"]
|
||||
generated_chunk = GenerationChunk(text=content)
|
||||
yield generated_chunk
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Error getting content chunk raw streamed response: {e}"
|
||||
f"line: {line}"
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported URL{self.sambastudio_url}"
|
||||
"only openai, generic v1 and generic v2 APIs are supported"
|
||||
)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
@ -478,56 +498,16 @@ class SambaStudio(LLM):
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
stop: a list of strings on which the model should stop generating.
|
||||
run_manager: A run manager with callbacks for the LLM.
|
||||
Yields:
|
||||
chunk: GenerationChunk with model partial generation
|
||||
"""
|
||||
ss_endpoint = SSEndpointHandler(
|
||||
self.sambastudio_base_url, self.sambastudio_base_uri
|
||||
)
|
||||
tuning_params = self._get_tuning_params(stop)
|
||||
try:
|
||||
if self.streaming:
|
||||
for chunk in self._handle_nlp_predict_stream(
|
||||
ss_endpoint, prompt, tuning_params
|
||||
):
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.text)
|
||||
yield chunk
|
||||
else:
|
||||
return
|
||||
except Exception as e:
|
||||
# Handle any errors raised by the inference endpoint
|
||||
raise ValueError(f"Error raised by the inference endpoint: {e}") from e
|
||||
|
||||
def _handle_stream_request(
|
||||
self,
|
||||
prompt: Union[List[str], str],
|
||||
stop: Optional[List[str]],
|
||||
run_manager: Optional[CallbackManagerForLLMRun],
|
||||
kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""
|
||||
Perform a streaming request to the LLM.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to generate from.
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of the stop substrings.
|
||||
run_manager: Callback manager for the run.
|
||||
kwargs: Additional keyword arguments. directly passed
|
||||
to the sambastudio model in API call.
|
||||
|
||||
Returns:
|
||||
The model output as a string.
|
||||
"""
|
||||
completion = ""
|
||||
for chunk in self._stream(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
response = self._handle_request(prompt, stop, streaming=True)
|
||||
for chunk in self._process_stream_response(response):
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.text)
|
||||
yield chunk
|
||||
|
||||
def _call(
|
||||
self,
|
||||
@ -540,17 +520,20 @@ class SambaStudio(LLM):
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
stop: a list of strings on which the model should stop generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
result: string with model generation
|
||||
"""
|
||||
if stop is not None:
|
||||
raise Exception("stop not implemented")
|
||||
try:
|
||||
if self.streaming:
|
||||
return self._handle_stream_request(prompt, stop, run_manager, kwargs)
|
||||
return self._handle_completion_requests(prompt, stop)
|
||||
except Exception as e:
|
||||
# Handle any errors raised by the inference endpoint
|
||||
raise ValueError(f"Error raised by the inference endpoint: {e}") from e
|
||||
if self.streaming:
|
||||
completion = ""
|
||||
for chunk in self._stream(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
completion += chunk.text
|
||||
|
||||
return completion
|
||||
|
||||
response = self._handle_request(prompt, stop, streaming=False)
|
||||
completion = self._process_response(response)
|
||||
return completion
|
||||
|
@ -20,7 +20,7 @@ count=$(git grep -E '(@root_validator)|(@validator)|(@field_validator)|(@pre_ini
|
||||
# PRs that increase the current count will not be accepted.
|
||||
# PRs that decrease update the code in the repository
|
||||
# and allow decreasing the count of are welcome!
|
||||
current_count=127
|
||||
current_count=126
|
||||
|
||||
if [ "$count" -gt "$current_count" ]; then
|
||||
echo "The PR seems to be introducing new usage of @root_validator and/or @field_validator."
|
||||
@ -52,4 +52,4 @@ if [ "$count" -gt "$current_count" ]; then
|
||||
elif [ "$count" -lt "$current_count" ]; then
|
||||
echo "Please update the $current_count variable in ./scripts/check_pydantic.sh to $count"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
Loading…
Reference in New Issue
Block a user