community[patch]: sambanova llm integration improvement (#23137)

- **Description:** sambanova sambaverse integration improvement: removed
input parsing that was changing raw user input, and was making to use
process prompt parameter as true mandatory
pull/22546/head
Jorge Piedrahita Ortiz 2 weeks ago committed by GitHub
parent e162893d7f
commit b3e53ffca0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -87,7 +87,6 @@
" \"do_sample\": True,\n",
" \"max_tokens_to_generate\": 1000,\n",
" \"temperature\": 0.01,\n",
" \"process_prompt\": True,\n",
" \"select_expert\": \"llama-2-7b-chat-hf\",\n",
" # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
" # \"repetition_penalty\": 1.0,\n",
@ -116,7 +115,6 @@
" \"do_sample\": True,\n",
" \"max_tokens_to_generate\": 1000,\n",
" \"temperature\": 0.01,\n",
" \"process_prompt\": True,\n",
" \"select_expert\": \"llama-2-7b-chat-hf\",\n",
" # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
" # \"repetition_penalty\": 1.0,\n",
@ -177,14 +175,16 @@
"import os\n",
"\n",
"sambastudio_base_url = \"<Your SambaStudio environment URL>\"\n",
"# sambastudio_base_uri = \"<Your SambaStudio endpoint base URI>\" # optional, \"api/predict/nlp\" set as default\n",
"sambastudio_base_uri = (\n",
" \"<Your SambaStudio endpoint base URI>\" # optional, \"api/predict/nlp\" set as default\n",
")\n",
"sambastudio_project_id = \"<Your SambaStudio project id>\"\n",
"sambastudio_endpoint_id = \"<Your SambaStudio endpoint id>\"\n",
"sambastudio_api_key = \"<Your SambaStudio endpoint API key>\"\n",
"\n",
"# Set the environment variables\n",
"os.environ[\"SAMBASTUDIO_BASE_URL\"] = sambastudio_base_url\n",
"# os.environ[\"SAMBASTUDIO_BASE_URI\"] = sambastudio_base_uri\n",
"os.environ[\"SAMBASTUDIO_BASE_URI\"] = sambastudio_base_uri\n",
"os.environ[\"SAMBASTUDIO_PROJECT_ID\"] = sambastudio_project_id\n",
"os.environ[\"SAMBASTUDIO_ENDPOINT_ID\"] = sambastudio_endpoint_id\n",
"os.environ[\"SAMBASTUDIO_API_KEY\"] = sambastudio_api_key"
@ -247,6 +247,40 @@
"for chunk in llm.stream(\"Why should I use open source models?\"):\n",
" print(chunk, end=\"\", flush=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also call a CoE endpoint expert model "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Using a CoE endpoint\n",
"\n",
"from langchain_community.llms.sambanova import SambaStudio\n",
"\n",
"llm = SambaStudio(\n",
" streaming=False,\n",
" model_kwargs={\n",
" \"do_sample\": True,\n",
" \"max_tokens_to_generate\": 1000,\n",
" \"temperature\": 0.01,\n",
" \"select_expert\": \"Meta-Llama-3-8B-Instruct\",\n",
" # \"repetition_penalty\": 1.0,\n",
" # \"top_k\": 50,\n",
" # \"top_logprobs\": 0,\n",
" # \"top_p\": 1.0\n",
" },\n",
")\n",
"\n",
"print(llm.invoke(\"Why should I use open source models?\"))"
]
}
],
"metadata": {

@ -43,7 +43,7 @@ class SVEndpointHandler:
:param requests.Response response: the response object to process
:return: the response dict
:rtype: dict
:type: dict
"""
result: Dict[str, Any] = {}
try:
@ -87,7 +87,7 @@ class SVEndpointHandler:
"""
Return the full API URL for a given path.
:returns: the full API URL for the sub-path
:rtype: str
:type: str
"""
return f"{self.host_url}{self.API_BASE_PATH}"
@ -108,23 +108,12 @@ class SVEndpointHandler:
:param str input_str: Input string
:param str params: Input params string
:returns: Prediction results
:rtype: dict
"""
parsed_element = {
"conversation_id": "sambaverse-conversation-id",
"messages": [
{
"message_id": 0,
"role": "user",
"content": input,
}
],
}
parsed_input = json.dumps(parsed_element)
:type: dict
"""
if params:
data = {"instance": parsed_input, "params": json.loads(params)}
data = {"instance": input, "params": json.loads(params)}
else:
data = {"instance": parsed_input}
data = {"instance": input}
response = self.http_session.post(
self._get_full_url(),
headers={
@ -152,23 +141,12 @@ class SVEndpointHandler:
:param str input_str: Input string
:param str params: Input params string
:returns: Prediction results
:rtype: dict
"""
parsed_element = {
"conversation_id": "sambaverse-conversation-id",
"messages": [
{
"message_id": 0,
"role": "user",
"content": input,
}
],
}
parsed_input = json.dumps(parsed_element)
:type: dict
"""
if params:
data = {"instance": parsed_input, "params": json.loads(params)}
data = {"instance": input, "params": json.loads(params)}
else:
data = {"instance": parsed_input}
data = {"instance": input}
# Streaming output
response = self.http_session.post(
self._get_full_url(),
@ -522,7 +500,7 @@ class SSEndpointHandler:
:param requests.Response response: the response object to process
:return: the response dict
:rtype: dict
:type: dict
"""
result: Dict[str, Any] = {}
try:
@ -581,7 +559,7 @@ class SSEndpointHandler:
:param str path: the sub-path
:returns: the full API URL for the sub-path
:rtype: str
:type: str
"""
return f"{self.host_url}/{self.api_base_uri}/{path}"
@ -603,7 +581,7 @@ class SSEndpointHandler:
:param str input_str: Input string
:param str params: Input params string
:returns: Prediction results
:rtype: dict
:type: dict
"""
if isinstance(input, str):
input = [input]
@ -645,7 +623,7 @@ class SSEndpointHandler:
:param str input_str: Input string
:param str params: Input params string
:returns: Prediction results
:rtype: dict
:type: dict
"""
if "nlp" in self.api_base_uri:
if isinstance(input, str):

Loading…
Cancel
Save