From b3e53ffca03f1e8052e935e0026a8d4a41134402 Mon Sep 17 00:00:00 2001
From: Jorge Piedrahita Ortiz
Date: Wed, 19 Jun 2024 12:30:14 -0500
Subject: [PATCH] community[patch]: sambanova llm integration improvement
 (#23137)

- **Description:** SambaNova Sambaverse integration improvement: removed the
  input parsing that rewrote the raw user input and that effectively made it
  mandatory to set the `process_prompt` parameter to true.
---
 docs/docs/integrations/llms/sambanova.ipynb | 42 ++++++++++++++--
 .../langchain_community/llms/sambanova.py   | 50 ++++++-------
 2 files changed, 52 insertions(+), 40 deletions(-)

diff --git a/docs/docs/integrations/llms/sambanova.ipynb b/docs/docs/integrations/llms/sambanova.ipynb
index 522b9bb959..64ed346066 100644
--- a/docs/docs/integrations/llms/sambanova.ipynb
+++ b/docs/docs/integrations/llms/sambanova.ipynb
@@ -87,7 +87,6 @@
     "        \"do_sample\": True,\n",
     "        \"max_tokens_to_generate\": 1000,\n",
     "        \"temperature\": 0.01,\n",
-    "        \"process_prompt\": True,\n",
     "        \"select_expert\": \"llama-2-7b-chat-hf\",\n",
     "        # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
     "        # \"repetition_penalty\": 1.0,\n",
@@ -116,7 +115,6 @@
     "        \"do_sample\": True,\n",
     "        \"max_tokens_to_generate\": 1000,\n",
     "        \"temperature\": 0.01,\n",
-    "        \"process_prompt\": True,\n",
     "        \"select_expert\": \"llama-2-7b-chat-hf\",\n",
     "        # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
     "        # \"repetition_penalty\": 1.0,\n",
@@ -177,14 +175,16 @@
     "import os\n",
     "\n",
     "sambastudio_base_url = \"\"\n",
-    "# sambastudio_base_uri = \"\" # optional, \"api/predict/nlp\" set as default\n",
+    "sambastudio_base_uri = (\n",
+    "    \"\"  # optional, \"api/predict/nlp\" set as default\n",
+    ")\n",
     "sambastudio_project_id = \"\"\n",
     "sambastudio_endpoint_id = \"\"\n",
     "sambastudio_api_key = \"\"\n",
     "\n",
     "# Set the environment variables\n",
     "os.environ[\"SAMBASTUDIO_BASE_URL\"] = sambastudio_base_url\n",
-    "# os.environ[\"SAMBASTUDIO_BASE_URI\"] = sambastudio_base_uri\n",
+    "os.environ[\"SAMBASTUDIO_BASE_URI\"] = sambastudio_base_uri\n",
     "os.environ[\"SAMBASTUDIO_PROJECT_ID\"] = sambastudio_project_id\n",
     "os.environ[\"SAMBASTUDIO_ENDPOINT_ID\"] = sambastudio_endpoint_id\n",
     "os.environ[\"SAMBASTUDIO_API_KEY\"] = sambastudio_api_key"
@@ -247,6 +247,40 @@
     "for chunk in llm.stream(\"Why should I use open source models?\"):\n",
     "    print(chunk, end=\"\", flush=True)"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can also call a CoE endpoint expert model "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Using a CoE endpoint\n",
+    "\n",
+    "from langchain_community.llms.sambanova import SambaStudio\n",
+    "\n",
+    "llm = SambaStudio(\n",
+    "    streaming=False,\n",
+    "    model_kwargs={\n",
+    "        \"do_sample\": True,\n",
+    "        \"max_tokens_to_generate\": 1000,\n",
+    "        \"temperature\": 0.01,\n",
+    "        \"select_expert\": \"Meta-Llama-3-8B-Instruct\",\n",
+    "        # \"repetition_penalty\": 1.0,\n",
+    "        # \"top_k\": 50,\n",
+    "        # \"top_logprobs\": 0,\n",
+    "        # \"top_p\": 1.0\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "print(llm.invoke(\"Why should I use open source models?\"))"
+   ]
   }
  ],
  "metadata": {
diff --git a/libs/community/langchain_community/llms/sambanova.py b/libs/community/langchain_community/llms/sambanova.py
index e7e89e1522..6017f9b8b3 100644
--- a/libs/community/langchain_community/llms/sambanova.py
+++ b/libs/community/langchain_community/llms/sambanova.py
@@ -43,7 +43,7 @@ class SVEndpointHandler:
         :param requests.Response response: the response
             object to process
         :return: the response dict
-        :rtype: dict
+        :type: dict
         """
         result: Dict[str, Any] = {}
         try:
@@ -87,7 +87,7 @@ class SVEndpointHandler:
         """
         Return the full API URL for a given path.
         :returns: the full API URL for the sub-path
-        :rtype: str
+        :type: str
         """
         return f"{self.host_url}{self.API_BASE_PATH}"
 
@@ -108,23 +108,12 @@ class SVEndpointHandler:
         :param str input_str: Input string
         :param str params: Input params string
         :returns: Prediction results
-        :rtype: dict
-        """
-        parsed_element = {
-            "conversation_id": "sambaverse-conversation-id",
-            "messages": [
-                {
-                    "message_id": 0,
-                    "role": "user",
-                    "content": input,
-                }
-            ],
-        }
-        parsed_input = json.dumps(parsed_element)
+        :type: dict
+        """
         if params:
-            data = {"instance": parsed_input, "params": json.loads(params)}
+            data = {"instance": input, "params": json.loads(params)}
         else:
-            data = {"instance": parsed_input}
+            data = {"instance": input}
         response = self.http_session.post(
             self._get_full_url(),
             headers={
@@ -152,23 +141,12 @@ class SVEndpointHandler:
         :param str input_str: Input string
         :param str params: Input params string
         :returns: Prediction results
-        :rtype: dict
-        """
-        parsed_element = {
-            "conversation_id": "sambaverse-conversation-id",
-            "messages": [
-                {
-                    "message_id": 0,
-                    "role": "user",
-                    "content": input,
-                }
-            ],
-        }
-        parsed_input = json.dumps(parsed_element)
+        :type: dict
+        """
         if params:
-            data = {"instance": parsed_input, "params": json.loads(params)}
+            data = {"instance": input, "params": json.loads(params)}
         else:
-            data = {"instance": parsed_input}
+            data = {"instance": input}
         # Streaming output
         response = self.http_session.post(
             self._get_full_url(),
@@ -522,7 +500,7 @@ class SSEndpointHandler:
         :param requests.Response response: the response
             object to process
         :return: the response dict
-        :rtype: dict
+        :type: dict
         """
         result: Dict[str, Any] = {}
         try:
@@ -581,7 +559,7 @@ class SSEndpointHandler:
         :param str path: the sub-path
 
         :returns: the full API URL for the sub-path
-        :rtype: str
+        :type: str
         """
         return f"{self.host_url}/{self.api_base_uri}/{path}"
 
@@ -603,7 +581,7 @@ class SSEndpointHandler:
         :param str input_str: Input string
         :param str params: Input params string
         :returns: Prediction results
-        :rtype: dict
+        :type: dict
         """
         if isinstance(input, str):
             input = [input]
@@ -645,7 +623,7 @@ class SSEndpointHandler:
         :param str input_str: Input string
         :param str params: Input params string
         :returns: Prediction results
-        :rtype: dict
+        :type: dict
         """
         if "nlp" in self.api_base_uri:
             if isinstance(input, str):
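
For reviewers, a minimal usage sketch of the behavior this patch enables (mirroring the Sambaverse notebook cells touched above; the model and expert names are the ones used there, and a valid `SAMBAVERSE_API_KEY` is assumed to be available). Because the handler no longer wraps the prompt in a conversation JSON envelope, the raw prompt is sent to the endpoint as-is and `process_prompt` becomes an ordinary optional entry in `model_kwargs`:

```python
import os

from langchain_community.llms.sambanova import Sambaverse

# Assumption: a valid Sambaverse API key is exported before running.
os.environ["SAMBAVERSE_API_KEY"] = "<your-sambaverse-api-key>"

llm = Sambaverse(
    sambaverse_model_name="Meta/llama-2-7b-chat-hf",
    streaming=False,
    model_kwargs={
        "do_sample": True,
        "max_tokens_to_generate": 1000,
        "temperature": 0.01,
        "select_expert": "llama-2-7b-chat-hf",
        # "process_prompt": True,  # optional after this patch; the raw
        # prompt is no longer rewritten into a conversation envelope
    },
)

print(llm.invoke("Why should I use open source models?"))
```

Prompts built by upstream templates therefore reach the Sambaverse endpoint unmodified, which is the point of the change.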