From 852722ea45212936e6d10e775610af82bbd1a193 Mon Sep 17 00:00:00 2001 From: Toshish Jawale <986859+toshish@users.noreply.github.com> Date: Tue, 15 Aug 2023 15:33:07 -0700 Subject: [PATCH] Improvements in Nebula LLM (#9226) - Description: Added improvements in Nebula LLM to perform auto-retry; more generation parameters supported. Conversation is no longer required to be passed in the LLM object. Examples are updated. - Issue: N/A - Dependencies: N/A - Tag maintainer: @baskaryan - Twitter handle: symbldotai --------- Co-authored-by: toshishjawale --- .../integrations/llms/symblai_nebula.ipynb | 89 ++++---- .../integrations/providers/symblai_nebula.mdx | 12 +- .../langchain/llms/symblai_nebula.py | 198 +++++++++++------- .../llms/test_symblai_nebula.py | 80 ++++--- 4 files changed, 214 insertions(+), 165 deletions(-) diff --git a/docs/extras/integrations/llms/symblai_nebula.ipynb b/docs/extras/integrations/llms/symblai_nebula.ipynb index 1ca58697e0..304917a5cf 100644 --- a/docs/extras/integrations/llms/symblai_nebula.ipynb +++ b/docs/extras/integrations/llms/symblai_nebula.ipynb @@ -1,80 +1,83 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", - "id": "9597802c", - "metadata": {}, "source": [ - "# Nebula\n", + "# Nebula (Symbl.ai)\n", + "[Nebula](https://symbl.ai/nebula/) is a large language model (LLM) built by [Symbl.ai](https://symbl.ai). It is trained to perform generative tasks on human conversations. Nebula excels at modeling the nuanced details of a conversation and performing tasks on the conversation.\n", "\n", - "[Nebula](https://symbl.ai/nebula/) is a fully-managed Conversation platform, on which you can build, deploy, and manage scalable AI applications.\n", + "Nebula documentation: https://docs.symbl.ai/docs/nebula-llm\n", "\n", - "This example goes over how to use LangChain to interact with the [Nebula platform](https://docs.symbl.ai/docs/nebula-llm-overview). \n", - "\n", - "It will send the requests to Nebula Service endpoint, which concatenates `SYMBLAI_NEBULA_SERVICE_URL` and `SYMBLAI_NEBULA_SERVICE_PATH`, with a token defined in `SYMBLAI_NEBULA_SERVICE_TOKEN`" - ] + "This example goes over how to use LangChain to interact with the [Nebula platform](https://docs.symbl.ai/docs/nebula-llm)." + ], + "metadata": { + "collapsed": false + }, + "id": "bb8cd830db4a004e" }, { "cell_type": "markdown", - "id": "f15ebe0d", - "metadata": {}, "source": [ - "### Integrate with a LLMChain" - ] + "Make sure you have API Key with you. If you don't have one please [request one](https://info.symbl.ai/Nebula_Private_Beta.html)." + ], + "metadata": { + "collapsed": false + }, + "id": "519570b6539aa18c" }, { "cell_type": "code", "execution_count": null, - "id": "5472a7cd-af26-48ca-ae9b-5f6ae73c74d2", - "metadata": { - "tags": [] - }, "outputs": [], "source": [ - "import os\n", + "from langchain.llms.symblai_nebula import Nebula\n", "\n", - "os.environ[\"NEBULA_SERVICE_URL\"] = NEBULA_SERVICE_URL\n", - "os.environ[\"NEBULA_SERVICE_PATH\"] = NEBULA_SERVICE_PATH\n", - "os.environ[\"NEBULA_SERVICE_API_KEY\"] = NEBULA_SERVICE_API_KEY" - ] + "llm = Nebula(nebula_api_key='')" + ], + "metadata": { + "collapsed": false + }, + "id": "9f47bef45880aece" }, { - "cell_type": "code", - "execution_count": null, - "id": "6fb585dd", + "cell_type": "markdown", + "source": [ + "Use a conversation transcript and instruction to construct a prompt." + ], "metadata": { - "tags": [] + "collapsed": false }, - "outputs": [], - "source": [ - "from langchain.llms import OpenLLM\n", - "\n", - "llm = OpenLLM(\n", - " conversation=\"\",\n", - ")" - ] + "id": "88c6a516ef51c74b" }, { "cell_type": "code", "execution_count": null, - "id": "035dea0f", - "metadata": { - "tags": [] - }, "outputs": [], "source": [ "from langchain import PromptTemplate, LLMChain\n", "\n", - "template = \"Identify the {count} main objectives or goals mentioned in this context concisely in less points. Emphasize on key intents.\"\n", + "conversation = \"\"\"Sam: Good morning, team! Let's keep this standup concise. We'll go in the usual order: what you did yesterday, what you plan to do today, and any blockers. Alex, kick us off.\n", + "Alex: Morning! Yesterday, I wrapped up the UI for the user dashboard. The new charts and widgets are now responsive. I also had a sync with the design team to ensure the final touchups are in line with the brand guidelines. Today, I'll start integrating the frontend with the new API endpoints Rhea was working on. The only blocker is waiting for some final API documentation, but I guess Rhea can update on that.\n", + "Rhea: Hey, all! Yep, about the API documentation - I completed the majority of the backend work for user data retrieval yesterday. The endpoints are mostly set up, but I need to do a bit more testing today. I'll finalize the API documentation by noon, so that should unblock Alex. After that, I’ll be working on optimizing the database queries for faster data fetching. No other blockers on my end.\n", + "Sam: Great, thanks Rhea. Do reach out if you need any testing assistance or if there are any hitches with the database. Now, my update: Yesterday, I coordinated with the client to get clarity on some feature requirements. Today, I'll be updating our project roadmap and timelines based on their feedback. Additionally, I'll be sitting with the QA team in the afternoon for preliminary testing. Blocker: I might need both of you to be available for a quick call in case the client wants to discuss the changes live.\n", + "Alex: Sounds good, Sam. Just let us know a little in advance for the call.\n", + "Rhea: Agreed. We can make time for that.\n", + "Sam: Perfect! Let's keep the momentum going. Reach out if there are any sudden issues or support needed. Have a productive day!\n", + "Alex: You too.\n", + "Rhea: Thanks, bye!\"\"\"\n", + "\n", + "instruction = \"Identify the main objectives mentioned in this conversation.\"\n", "\n", - "prompt = PromptTemplate(template=template, input_variables=[\"count\"])\n", + "prompt = PromptTemplate.from_template(\"{instruction}\\n{conversation}\")\n", "\n", "llm_chain = LLMChain(prompt=prompt, llm=llm)\n", "\n", - "generated = llm_chain.run(count=\"five\")\n", - "print(generated)" - ] + "llm_chain.run(instruction=instruction, conversation=conversation)" + ], + "metadata": { + "collapsed": false + }, + "id": "5977ccc2d4432624" } ], "metadata": { diff --git a/docs/extras/integrations/providers/symblai_nebula.mdx b/docs/extras/integrations/providers/symblai_nebula.mdx index 24ecd76a0e..b716af6ff0 100644 --- a/docs/extras/integrations/providers/symblai_nebula.mdx +++ b/docs/extras/integrations/providers/symblai_nebula.mdx @@ -5,16 +5,14 @@ It is broken into two parts: installation and setup, and then references to spec ## Installation and Setup -- Get an Nebula API Key and set as environment variables (`SYMBLAI_NEBULA_SERVICE_URL`, `SYMBLAI_NEBULA_SERVICE_PATH`, `SYMBLAI_NEBULA_SERVICE_TOKEN`) - - Sign up for a FREE Symbl.ai/Nebula Account: [https://nebula.symbl.ai/playground/](https://nebula.symbl.ai/playground/) -- Please see the [Nebula documentation](https://docs.symbl.ai/docs/nebula-llm-overview) for more details. - - No time? Visit the [Nebula Quickstart Guide](https://docs.symbl.ai/docs/nebula-quickstart). - -## Wrappers +- Get an [Nebula API Key](https://info.symbl.ai/Nebula_Private_Beta.html) and set as environment variable `NEBULA_API_KEY` +- Please see the [Nebula documentation](https://docs.symbl.ai/docs/nebula-llm) for more details. +- No time? Visit the [Nebula Quickstart Guide](https://docs.symbl.ai/docs/nebula-quickstart). ### LLM -There exists an Nebula LLM wrapper, which you can access with +There exists an Nebula LLM wrapper, which you can access with ```python from langchain.llms import Nebula +llm = Nebula() ``` diff --git a/libs/langchain/langchain/llms/symblai_nebula.py b/libs/langchain/langchain/llms/symblai_nebula.py index 892269abc3..e52d53dc71 100644 --- a/libs/langchain/langchain/llms/symblai_nebula.py +++ b/libs/langchain/langchain/llms/symblai_nebula.py @@ -1,8 +1,17 @@ +import json import logging -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Callable, Dict, List, Mapping, Optional import requests -from pydantic_v1 import Extra, root_validator +from pydantic import Extra, root_validator +from requests import ConnectTimeout, ReadTimeout, RequestException +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM @@ -19,7 +28,7 @@ class Nebula(LLM): """Nebula Service models. To use, you should have the environment variable ``NEBULA_SERVICE_URL``, - ``NEBULA_SERVICE_PATH`` and ``NEBULA_SERVICE_API_KEY`` set with your Nebula + ``NEBULA_SERVICE_PATH`` and ``NEBULA_API_KEY`` set with your Nebula Service, or pass it as a named parameter to the constructor. Example: @@ -28,9 +37,9 @@ class Nebula(LLM): from langchain.llms import Nebula nebula = Nebula( - nebula_service_url="SERVICE_URL", - nebula_service_path="SERVICE_ROUTE", - nebula_api_key="SERVICE_TOKEN", + nebula_service_url="NEBULA_SERVICE_URL", + nebula_service_path="NEBULA_SERVICE_PATH", + nebula_api_key="NEBULA_API_KEY", ) """ # noqa: E501 @@ -38,14 +47,19 @@ class Nebula(LLM): model_kwargs: Optional[dict] = None """Optional""" + nebula_service_url: Optional[str] = None nebula_service_path: Optional[str] = None nebula_api_key: Optional[str] = None - conversation: str = "" - return_scores: Optional[str] = "false" - max_new_tokens: Optional[int] = 2048 - top_k: Optional[float] = 2 - penalty_alpha: Optional[float] = 0.1 + model: Optional[str] = None + max_new_tokens: Optional[int] = 128 + temperature: Optional[float] = 0.6 + top_p: Optional[float] = 0.95 + repetition_penalty: Optional[float] = 1.0 + top_k: Optional[int] = 0 + penalty_alpha: Optional[float] = 0.0 + stop_sequences: Optional[List[str]] = None + max_retries: Optional[int] = 10 class Config: """Configuration for this pydantic object.""" @@ -68,7 +82,7 @@ class Nebula(LLM): DEFAULT_NEBULA_SERVICE_PATH, ) nebula_api_key = get_from_dict_or_env( - values, "nebula_api_key", "NEBULA_SERVICE_API_KEY", "" + values, "nebula_api_key", "NEBULA_API_KEY", None ) if nebula_service_url.endswith("/"): @@ -76,25 +90,24 @@ class Nebula(LLM): if not nebula_service_path.startswith("/"): nebula_service_path = "/" + nebula_service_path - """ TODO: Future login""" - """ - try: - nebula_service_endpoint = f"{nebula_service_url}{nebula_service_path}" - headers = { - "Content-Type": "application/json", - "ApiKey": "{nebula_api_key}", - } - requests.get(nebula_service_endpoint, headers=headers) - except requests.exceptions.RequestException as e: - raise ValueError(e) - """ - values["nebula_service_url"] = nebula_service_url values["nebula_service_path"] = nebula_service_path values["nebula_api_key"] = nebula_api_key return values + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling Cohere API.""" + return { + "max_new_tokens": self.max_new_tokens, + "temperature": self.temperature, + "top_k": self.top_k, + "top_p": self.top_p, + "repetition_penalty": self.repetition_penalty, + "penalty_alpha": self.penalty_alpha, + } + @property def _identifying_params(self) -> Mapping[str, Any]: """Get the identifying parameters.""" @@ -103,7 +116,6 @@ class Nebula(LLM): "nebula_service_url": self.nebula_service_url, "nebula_service_path": self.nebula_service_path, **{"model_kwargs": _model_kwargs}, - "conversation": self.conversation, } @property @@ -111,6 +123,25 @@ class Nebula(LLM): """Return type of llm.""" return "nebula" + def _invocation_params( + self, stop_sequences: Optional[List[str]], **kwargs: Any + ) -> dict: + params = self._default_params + if self.stop_sequences is not None and stop_sequences is not None: + raise ValueError("`stop` found in both the input and default params.") + elif self.stop_sequences is not None: + params["stop_sequences"] = self.stop_sequences + else: + params["stop_sequences"] = stop_sequences + return {**params, **kwargs} + + @staticmethod + def _process_response(response: Any, stop: Optional[List[str]]) -> str: + text = response["output"]["text"] + if stop: + text = enforce_stop_tokens(text, stop) + return text + def _call( self, prompt: str, @@ -128,57 +159,84 @@ class Nebula(LLM): .. code-block:: python response = nebula("Tell me a joke.") """ - - _model_kwargs = self.model_kwargs or {} - - nebula_service_endpoint = f"{self.nebula_service_url}{self.nebula_service_path}" - - headers = { - "Content-Type": "application/json", - "ApiKey": f"{self.nebula_api_key}", + params = self._invocation_params(stop, **kwargs) + prompt = prompt.strip() + if "\n" in prompt: + instruction = prompt.split("\n")[0] + conversation = "\n".join(prompt.split("\n")[1:]) + else: + raise ValueError("Prompt must contain instruction and conversation.") + + response = completion_with_retry( + self, + instruction=instruction, + conversation=conversation, + params=params, + url=f"{self.nebula_service_url}{self.nebula_service_path}", + ) + _stop = params.get("stop_sequences") + return self._process_response(response, _stop) + + +def make_request( + self: Nebula, + instruction: str, + conversation: str, + url: str = f"{DEFAULT_NEBULA_SERVICE_URL}{DEFAULT_NEBULA_SERVICE_PATH}", + params: Dict = {}, +) -> Any: + """Generate text from the model.""" + headers = { + "Content-Type": "application/json", + "ApiKey": f"{self.nebula_api_key}", + } + + body = { + "prompt": { + "instruction": instruction, + "conversation": {"text": f"{conversation}"}, } + } - body = { - "prompt": { - "instruction": prompt, - "conversation": {"text": f"{self.conversation}"}, - }, - "return_scores": self.return_scores, - "max_new_tokens": self.max_new_tokens, - "top_k": self.top_k, - "penalty_alpha": self.penalty_alpha, - } + # add params to body + for key, value in params.items(): + body[key] = value - if len(self.conversation) == 0: - raise ValueError("Error conversation is empty.") + # make request + response = requests.post(url, headers=headers, json=body) - logger.debug(f"NEBULA _model_kwargs: {_model_kwargs}") - logger.debug(f"NEBULA body: {body}") - logger.debug(f"NEBULA kwargs: {kwargs}") - logger.debug(f"NEBULA conversation: {self.conversation}") + if response.status_code != 200: + raise Exception( + f"Request failed with status code {response.status_code}" + f" and message {response.text}" + ) - # call API - try: - response = requests.post( - nebula_service_endpoint, headers=headers, json=body - ) - except requests.exceptions.RequestException as e: - raise ValueError(f"Error raised by inference endpoint: {e}") + return json.loads(response.text) - logger.debug(f"NEBULA response: {response}") - if response.status_code != 200: - raise ValueError( - f"Error returned by service, status code {response.status_code}" - ) +def _create_retry_decorator(llm: Nebula) -> Callable[[Any], Any]: + min_seconds = 4 + max_seconds = 10 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterward + max_retries = llm.max_retries if llm.max_retries is not None else 3 + return retry( + reraise=True, + stop=stop_after_attempt(max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type((RequestException, ConnectTimeout, ReadTimeout)) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) - """ get the result """ - text = response.text - """ enforce stop """ - if stop is not None: - # This is required since the stop tokens - # are not enforced by the model parameters - text = enforce_stop_tokens(text, stop) +def completion_with_retry(llm: Nebula, **kwargs: Any) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator(llm) - return text + @retry_decorator + def _completion_with_retry(**_kwargs: Any) -> Any: + return make_request(llm, **_kwargs) + + return _completion_with_retry(**kwargs) diff --git a/libs/langchain/tests/integration_tests/llms/test_symblai_nebula.py b/libs/langchain/tests/integration_tests/llms/test_symblai_nebula.py index 618ad4ed0d..35ebd96838 100644 --- a/libs/langchain/tests/integration_tests/llms/test_symblai_nebula.py +++ b/libs/langchain/tests/integration_tests/llms/test_symblai_nebula.py @@ -1,56 +1,46 @@ """Test Nebula API wrapper.""" - from langchain import LLMChain, PromptTemplate from langchain.llms.symblai_nebula import Nebula def test_symblai_nebula_call() -> None: """Test valid call to Nebula.""" - conversation = """Speaker 1: Thank you for calling ABC, company.Speaker 1: My name -is Mary.Speaker 1: How may I help you?Speaker 2: Today?Speaker 1: All right, -Madam.Speaker 1: I really apologize for this inconvenient.Speaker 1: I will be happy -to assist you in this matter.Speaker 1: Could you please offer me Yuri your account -number?Speaker 1: Alright Madam, thank you very much.Speaker 1: Let me check that -for confirmation.Speaker 1: Did you say 534 00 365?Speaker 2: 48?Speaker 1: Very good -man.Speaker 1: Now for verification purposes, can I please get your full?Speaker -2: Name?Speaker 1: Alright, thank you.Speaker 1: Very much.Speaker 1: Madam.Speaker -1: Can I, please get your birthdate now?Speaker 1: I am sorry madam.Speaker 1: I -didn't make this clear is for verification.Speaker 1: Purposes is the company -request.Speaker 1: The system requires me, your name, your complete name and your -date of.Speaker 2: Birth.Speaker 2: Alright, thank you very much madam.Speaker 1: -All right.Speaker 1: Thank you very much, Madam.Speaker 1: Thank you for that -information.Speaker 1: Let me check what happens.Speaker 2: Here.Speaker 1: So -according to our data space them, you did pay your last bill last August the 12, -which was two days ago in one of our Affiliated payment centers.Speaker 1: So, at the -moment you currently, We have zero balance.Speaker 1: So however, the bill that you -received was generated a week before you made the pavement, this is reason why you -already make this payment, have not been reflected yet.Speaker 1: So what we do in -this case, you just simply disregard the amount indicated in the field and you -continue to enjoy our service man.Speaker 1: Sure, Madam.Speaker 1: And I am sure -you need your cell phone for everything for life, right?Speaker 1: So I really -apologize for this inconvenience.Speaker 1: And let me tell you that delays in the -bill is usually caused by delays in our Courier Service.Speaker 1: That is to say -that it'''s a problem, not with the company, but with a courier service, For a more -updated, feel of your account, you can visit our website and log into your account, -and they'''re in the system.Speaker 1: On the website, you are going to have the -possibility to pay the bill.Speaker 1: That is more.Speaker 2: Updated.Speaker 2: -Of course, Madam I can definitely assist you with that.Speaker 2: Once you have, -you want to see your bill updated, please go to www.hsn BC campus, any.com after -that.Speaker 2: You will see in the tale.Speaker 1: All right corner.Speaker 1: So -you're going to see a pay now button.Speaker 1: Please click on the pay now button -and the serve.Speaker 1: The system is going to ask you for personal -information.Speaker 1: Such as your first name, your ID account, your the number of -your account, your email address, and your phone number once you complete this personal -information.""" - llm = Nebula( - conversation=conversation, - ) + conversation = """Sam: Good morning, team! Let's keep this standup concise. + We'll go in the usual order: what you did yesterday, + what you plan to do today, and any blockers. Alex, kick us off. +Alex: Morning! Yesterday, I wrapped up the UI for the user dashboard. +The new charts and widgets are now responsive. +I also had a sync with the design team to ensure the final touchups are in +line with the brand guidelines. Today, I'll start integrating the frontend with +the new API endpoints Rhea was working on. +The only blocker is waiting for some final API documentation, +but I guess Rhea can update on that. +Rhea: Hey, all! Yep, about the API documentation - I completed the majority of + the backend work for user data retrieval yesterday. + The endpoints are mostly set up, but I need to do a bit more testing today. + I'll finalize the API documentation by noon, so that should unblock Alex. + After that, I’ll be working on optimizing the database queries + for faster data fetching. No other blockers on my end. +Sam: Great, thanks Rhea. Do reach out if you need any testing assistance + or if there are any hitches with the database. + Now, my update: Yesterday, I coordinated with the client to get clarity + on some feature requirements. Today, I'll be updating our project roadmap + and timelines based on their feedback. Additionally, I'll be sitting with + the QA team in the afternoon for preliminary testing. + Blocker: I might need both of you to be available for a quick call + in case the client wants to discuss the changes live. +Alex: Sounds good, Sam. Just let us know a little in advance for the call. +Rhea: Agreed. We can make time for that. +Sam: Perfect! Let's keep the momentum going. Reach out if there are any +sudden issues or support needed. Have a productive day! +Alex: You too. +Rhea: Thanks, bye!""" + llm = Nebula(nebula_api_key="") - template = """Identify the {count} main objectives or goals mentioned in this -context concisely in less points. Emphasize on key intents.""" - prompt = PromptTemplate.from_template(template) + instruction = """Identify the main objectives mentioned in this +conversation.""" + prompt = PromptTemplate.from_template("{instruction}\n{conversation}") llm_chain = LLMChain(prompt=prompt, llm=llm) - output = llm_chain.run(count="five") - + output = llm_chain.run(instruction=instruction, conversation=conversation) assert isinstance(output, str)