Improvements in Nebula LLM (#9226)

- Description: Improves the Nebula LLM wrapper: requests now auto-retry on transient failures, more generation parameters are supported, and the conversation no longer needs to be passed into the LLM object (it is supplied through the prompt instead). The notebook and test examples are updated accordingly; see the usage sketch below.
  - Issue: N/A
  - Dependencies: N/A
  - Tag maintainer: @baskaryan 
  - Twitter handle: symbldotai

---------

Co-authored-by: toshishjawale <toshish@symbl.ai>
Commit 852722ea45 (parent 358562769a), authored by Toshish Jawale, committed via GitHub.
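
For quick reference, a minimal sketch of the updated usage, assembled from the notebook and test diffs below (the API key is a placeholder and the transcript is abbreviated):

```python
from langchain import LLMChain, PromptTemplate
from langchain.llms.symblai_nebula import Nebula

# The conversation is no longer passed to the constructor; only the API key is needed.
llm = Nebula(nebula_api_key="<your_api_key>")

# Abbreviated transcript; any multi-turn conversation text works here.
conversation = """Sam: Good morning, team! Let's keep this standup concise.
Alex: Morning! Yesterday, I wrapped up the UI for the user dashboard."""
instruction = "Identify the main objectives mentioned in this conversation."

# The wrapper expects the instruction on the first line and the conversation below it.
prompt = PromptTemplate.from_template("{instruction}\n{conversation}")
llm_chain = LLMChain(prompt=prompt, llm=llm)
print(llm_chain.run(instruction=instruction, conversation=conversation))
```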

@@ -1,80 +1,83 @@
 {
  "cells": [
   {
-   "attachments": {},
    "cell_type": "markdown",
-   "id": "9597802c",
-   "metadata": {},
    "source": [
-    "# Nebula\n",
-    "\n",
-    "[Nebula](https://symbl.ai/nebula/) is a fully-managed Conversation platform, on which you can build, deploy, and manage scalable AI applications.\n",
-    "\n",
-    "This example goes over how to use LangChain to interact with the [Nebula platform](https://docs.symbl.ai/docs/nebula-llm-overview). \n",
-    "\n",
-    "It will send the requests to Nebula Service endpoint, which concatenates `SYMBLAI_NEBULA_SERVICE_URL` and `SYMBLAI_NEBULA_SERVICE_PATH`, with a token defined in `SYMBLAI_NEBULA_SERVICE_TOKEN`"
-   ]
+    "# Nebula (Symbl.ai)\n",
+    "[Nebula](https://symbl.ai/nebula/) is a large language model (LLM) built by [Symbl.ai](https://symbl.ai). It is trained to perform generative tasks on human conversations. Nebula excels at modeling the nuanced details of a conversation and performing tasks on the conversation.\n",
+    "\n",
+    "Nebula documentation: https://docs.symbl.ai/docs/nebula-llm\n",
+    "\n",
+    "This example goes over how to use LangChain to interact with the [Nebula platform](https://docs.symbl.ai/docs/nebula-llm)."
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "bb8cd830db4a004e"
   },
   {
    "cell_type": "markdown",
-   "id": "f15ebe0d",
-   "metadata": {},
    "source": [
-    "### Integrate with a LLMChain"
-   ]
+    "Make sure you have API Key with you. If you don't have one please [request one](https://info.symbl.ai/Nebula_Private_Beta.html)."
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "519570b6539aa18c"
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "5472a7cd-af26-48ca-ae9b-5f6ae73c74d2",
-   "metadata": {
-    "tags": []
-   },
    "outputs": [],
    "source": [
-    "import os\n",
+    "from langchain.llms.symblai_nebula import Nebula\n",
     "\n",
-    "os.environ[\"NEBULA_SERVICE_URL\"] = NEBULA_SERVICE_URL\n",
-    "os.environ[\"NEBULA_SERVICE_PATH\"] = NEBULA_SERVICE_PATH\n",
-    "os.environ[\"NEBULA_SERVICE_API_KEY\"] = NEBULA_SERVICE_API_KEY"
-   ]
+    "llm = Nebula(nebula_api_key='<your_api_key>')"
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "9f47bef45880aece"
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "6fb585dd",
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [],
+   "cell_type": "markdown",
    "source": [
-    "from langchain.llms import OpenLLM\n",
-    "\n",
-    "llm = OpenLLM(\n",
-    " conversation=\"<Drop your text conversation that you want to ask Nebula to analyze here>\",\n",
-    ")"
-   ]
+    "Use a conversation transcript and instruction to construct a prompt."
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "88c6a516ef51c74b"
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "035dea0f",
-   "metadata": {
-    "tags": []
-   },
    "outputs": [],
    "source": [
     "from langchain import PromptTemplate, LLMChain\n",
     "\n",
-    "template = \"Identify the {count} main objectives or goals mentioned in this context concisely in less points. Emphasize on key intents.\"\n",
+    "conversation = \"\"\"Sam: Good morning, team! Let's keep this standup concise. We'll go in the usual order: what you did yesterday, what you plan to do today, and any blockers. Alex, kick us off.\n",
+    "Alex: Morning! Yesterday, I wrapped up the UI for the user dashboard. The new charts and widgets are now responsive. I also had a sync with the design team to ensure the final touchups are in line with the brand guidelines. Today, I'll start integrating the frontend with the new API endpoints Rhea was working on. The only blocker is waiting for some final API documentation, but I guess Rhea can update on that.\n",
+    "Rhea: Hey, all! Yep, about the API documentation - I completed the majority of the backend work for user data retrieval yesterday. The endpoints are mostly set up, but I need to do a bit more testing today. I'll finalize the API documentation by noon, so that should unblock Alex. After that, Ill be working on optimizing the database queries for faster data fetching. No other blockers on my end.\n",
+    "Sam: Great, thanks Rhea. Do reach out if you need any testing assistance or if there are any hitches with the database. Now, my update: Yesterday, I coordinated with the client to get clarity on some feature requirements. Today, I'll be updating our project roadmap and timelines based on their feedback. Additionally, I'll be sitting with the QA team in the afternoon for preliminary testing. Blocker: I might need both of you to be available for a quick call in case the client wants to discuss the changes live.\n",
+    "Alex: Sounds good, Sam. Just let us know a little in advance for the call.\n",
+    "Rhea: Agreed. We can make time for that.\n",
+    "Sam: Perfect! Let's keep the momentum going. Reach out if there are any sudden issues or support needed. Have a productive day!\n",
+    "Alex: You too.\n",
+    "Rhea: Thanks, bye!\"\"\"\n",
+    "\n",
+    "instruction = \"Identify the main objectives mentioned in this conversation.\"\n",
     "\n",
-    "prompt = PromptTemplate(template=template, input_variables=[\"count\"])\n",
+    "prompt = PromptTemplate.from_template(\"{instruction}\\n{conversation}\")\n",
     "\n",
     "llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
     "\n",
-    "generated = llm_chain.run(count=\"five\")\n",
-    "print(generated)"
-   ]
+    "llm_chain.run(instruction=instruction, conversation=conversation)"
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "5977ccc2d4432624"
   }
  ],
  "metadata": {

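Note the prompt template `{instruction}\n{conversation}`: as the `_call` change further down shows, the wrapper splits the rendered prompt at the first newline, treating the first line as the instruction and the remainder as the conversation, and raises a `ValueError` for a single-line prompt. A minimal sketch of that contract:

```python
# Mirrors the splitting logic in Nebula._call (see the wrapper diff below).
prompt = (
    "Identify the main objectives mentioned in this conversation.\n"
    "Sam: Good morning, team!\n"
    "Alex: Morning!"
)
instruction = prompt.split("\n")[0]               # first line
conversation = "\n".join(prompt.split("\n")[1:])  # everything after it
```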
@@ -5,16 +5,14 @@ It is broken into two parts: installation and setup, and then references to spec
 ## Installation and Setup
-- Get an Nebula API Key and set as environment variables (`SYMBLAI_NEBULA_SERVICE_URL`, `SYMBLAI_NEBULA_SERVICE_PATH`, `SYMBLAI_NEBULA_SERVICE_TOKEN`)
-- Sign up for a FREE Symbl.ai/Nebula Account: [https://nebula.symbl.ai/playground/](https://nebula.symbl.ai/playground/)
-- Please see the [Nebula documentation](https://docs.symbl.ai/docs/nebula-llm-overview) for more details.
+- Get an [Nebula API Key](https://info.symbl.ai/Nebula_Private_Beta.html) and set as environment variable `NEBULA_API_KEY`
+- Please see the [Nebula documentation](https://docs.symbl.ai/docs/nebula-llm) for more details.
 - No time? Visit the [Nebula Quickstart Guide](https://docs.symbl.ai/docs/nebula-quickstart).
-## Wrappers
 ### LLM
 There exists an Nebula LLM wrapper, which you can access with
 ```python
 from langchain.llms import Nebula
+llm = Nebula()
 ```

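Per the setup notes above, the key can also come from the environment, in which case `Nebula()` needs no arguments. A short sketch (the key value is a placeholder):

```python
import os

from langchain.llms import Nebula

# With NEBULA_API_KEY set, the wrapper's validator picks the key up automatically.
os.environ["NEBULA_API_KEY"] = "<your_api_key>"
llm = Nebula()
```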
@ -1,8 +1,17 @@
import json
import logging import logging
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Callable, Dict, List, Mapping, Optional
import requests import requests
from pydantic_v1 import Extra, root_validator from pydantic import Extra, root_validator
from requests import ConnectTimeout, ReadTimeout, RequestException
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM from langchain.llms.base import LLM
@@ -19,7 +28,7 @@ class Nebula(LLM):
     """Nebula Service models.
 
     To use, you should have the environment variable ``NEBULA_SERVICE_URL``,
-    ``NEBULA_SERVICE_PATH`` and ``NEBULA_SERVICE_API_KEY`` set with your Nebula
+    ``NEBULA_SERVICE_PATH`` and ``NEBULA_API_KEY`` set with your Nebula
     Service, or pass it as a named parameter to the constructor.
 
     Example:
@@ -28,9 +37,9 @@ class Nebula(LLM):
             from langchain.llms import Nebula
 
             nebula = Nebula(
-                nebula_service_url="SERVICE_URL",
-                nebula_service_path="SERVICE_ROUTE",
-                nebula_api_key="SERVICE_TOKEN",
+                nebula_service_url="NEBULA_SERVICE_URL",
+                nebula_service_path="NEBULA_SERVICE_PATH",
+                nebula_api_key="NEBULA_API_KEY",
             )
     """  # noqa: E501
@@ -38,14 +47,19 @@ class Nebula(LLM):
     model_kwargs: Optional[dict] = None
     """Optional"""
 
     nebula_service_url: Optional[str] = None
     nebula_service_path: Optional[str] = None
     nebula_api_key: Optional[str] = None
-    conversation: str = ""
-    return_scores: Optional[str] = "false"
-    max_new_tokens: Optional[int] = 2048
-    top_k: Optional[float] = 2
-    penalty_alpha: Optional[float] = 0.1
+    model: Optional[str] = None
+    max_new_tokens: Optional[int] = 128
+    temperature: Optional[float] = 0.6
+    top_p: Optional[float] = 0.95
+    repetition_penalty: Optional[float] = 1.0
+    top_k: Optional[int] = 0
+    penalty_alpha: Optional[float] = 0.0
+    stop_sequences: Optional[List[str]] = None
+    max_retries: Optional[int] = 10
 
     class Config:
         """Configuration for this pydantic object."""
@@ -68,7 +82,7 @@ class Nebula(LLM):
             DEFAULT_NEBULA_SERVICE_PATH,
         )
         nebula_api_key = get_from_dict_or_env(
-            values, "nebula_api_key", "NEBULA_SERVICE_API_KEY", ""
+            values, "nebula_api_key", "NEBULA_API_KEY", None
         )
 
         if nebula_service_url.endswith("/"):
@@ -76,25 +90,24 @@ class Nebula(LLM):
         if not nebula_service_path.startswith("/"):
             nebula_service_path = "/" + nebula_service_path
 
-        """ TODO: Future login"""
-        """
-        try:
-            nebula_service_endpoint = f"{nebula_service_url}{nebula_service_path}"
-            headers = {
-                "Content-Type": "application/json",
-                "ApiKey": "{nebula_api_key}",
-            }
-            requests.get(nebula_service_endpoint, headers=headers)
-        except requests.exceptions.RequestException as e:
-            raise ValueError(e)
-        """
-
         values["nebula_service_url"] = nebula_service_url
         values["nebula_service_path"] = nebula_service_path
         values["nebula_api_key"] = nebula_api_key
 
         return values
 
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling Cohere API."""
+        return {
+            "max_new_tokens": self.max_new_tokens,
+            "temperature": self.temperature,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "repetition_penalty": self.repetition_penalty,
+            "penalty_alpha": self.penalty_alpha,
+        }
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
@@ -103,7 +116,6 @@ class Nebula(LLM):
             "nebula_service_url": self.nebula_service_url,
             "nebula_service_path": self.nebula_service_path,
             **{"model_kwargs": _model_kwargs},
-            "conversation": self.conversation,
         }
 
     @property
@@ -111,6 +123,25 @@ class Nebula(LLM):
         """Return type of llm."""
         return "nebula"
 
+    def _invocation_params(
+        self, stop_sequences: Optional[List[str]], **kwargs: Any
+    ) -> dict:
+        params = self._default_params
+        if self.stop_sequences is not None and stop_sequences is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+        elif self.stop_sequences is not None:
+            params["stop_sequences"] = self.stop_sequences
+        else:
+            params["stop_sequences"] = stop_sequences
+        return {**params, **kwargs}
+
+    @staticmethod
+    def _process_response(response: Any, stop: Optional[List[str]]) -> str:
+        text = response["output"]["text"]
+        if stop:
+            text = enforce_stop_tokens(text, stop)
+        return text
+
     def _call(
         self,
         prompt: str,
@@ -128,57 +159,84 @@ class Nebula(LLM):
             .. code-block:: python
 
                 response = nebula("Tell me a joke.")
         """
+        params = self._invocation_params(stop, **kwargs)
+        prompt = prompt.strip()
+        if "\n" in prompt:
+            instruction = prompt.split("\n")[0]
+            conversation = "\n".join(prompt.split("\n")[1:])
+        else:
+            raise ValueError("Prompt must contain instruction and conversation.")
 
-        _model_kwargs = self.model_kwargs or {}
-
-        nebula_service_endpoint = f"{self.nebula_service_url}{self.nebula_service_path}"
-
-        headers = {
-            "Content-Type": "application/json",
-            "ApiKey": f"{self.nebula_api_key}",
-        }
+        response = completion_with_retry(
+            self,
+            instruction=instruction,
+            conversation=conversation,
+            params=params,
+            url=f"{self.nebula_service_url}{self.nebula_service_path}",
+        )
+        _stop = params.get("stop_sequences")
+        return self._process_response(response, _stop)
 
-        body = {
-            "prompt": {
-                "instruction": prompt,
-                "conversation": {"text": f"{self.conversation}"},
-            },
-            "return_scores": self.return_scores,
-            "max_new_tokens": self.max_new_tokens,
-            "top_k": self.top_k,
-            "penalty_alpha": self.penalty_alpha,
-        }
 
-        if len(self.conversation) == 0:
-            raise ValueError("Error conversation is empty.")
+def make_request(
+    self: Nebula,
+    instruction: str,
+    conversation: str,
+    url: str = f"{DEFAULT_NEBULA_SERVICE_URL}{DEFAULT_NEBULA_SERVICE_PATH}",
+    params: Dict = {},
+) -> Any:
+    """Generate text from the model."""
+    headers = {
+        "Content-Type": "application/json",
+        "ApiKey": f"{self.nebula_api_key}",
+    }
 
-        logger.debug(f"NEBULA _model_kwargs: {_model_kwargs}")
-        logger.debug(f"NEBULA body: {body}")
-        logger.debug(f"NEBULA kwargs: {kwargs}")
-        logger.debug(f"NEBULA conversation: {self.conversation}")
+    body = {
+        "prompt": {
+            "instruction": instruction,
+            "conversation": {"text": f"{conversation}"},
+        }
+    }
 
-        # call API
-        try:
-            response = requests.post(
-                nebula_service_endpoint, headers=headers, json=body
-            )
-        except requests.exceptions.RequestException as e:
-            raise ValueError(f"Error raised by inference endpoint: {e}")
+    # add params to body
+    for key, value in params.items():
+        body[key] = value
 
-        logger.debug(f"NEBULA response: {response}")
+    # make request
+    response = requests.post(url, headers=headers, json=body)
 
-        if response.status_code != 200:
-            raise ValueError(
-                f"Error returned by service, status code {response.status_code}"
-            )
+    if response.status_code != 200:
+        raise Exception(
+            f"Request failed with status code {response.status_code}"
+            f" and message {response.text}"
+        )
 
-        """ get the result """
-        text = response.text
+    return json.loads(response.text)
 
-        """ enforce stop """
-        if stop is not None:
-            # This is required since the stop tokens
-            # are not enforced by the model parameters
-            text = enforce_stop_tokens(text, stop)
 
-        return text
+def _create_retry_decorator(llm: Nebula) -> Callable[[Any], Any]:
+    min_seconds = 4
+    max_seconds = 10
+    # Wait 2^x * 1 second between each retry starting with
+    # 4 seconds, then up to 10 seconds, then 10 seconds afterward
+    max_retries = llm.max_retries if llm.max_retries is not None else 3
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(max_retries),
+        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+        retry=(
+            retry_if_exception_type((RequestException, ConnectTimeout, ReadTimeout))
+        ),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
+def completion_with_retry(llm: Nebula, **kwargs: Any) -> Any:
+    """Use tenacity to retry the completion call."""
+    retry_decorator = _create_retry_decorator(llm)
+
+    @retry_decorator
+    def _completion_with_retry(**_kwargs: Any) -> Any:
+        return make_request(llm, **_kwargs)
+
+    return _completion_with_retry(**kwargs)

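The retry policy above is plain tenacity: exponential back-off bounded between 4 and 10 seconds, retrying only on `requests` connection/timeout errors and re-raising once the attempts are exhausted. A self-contained sketch of the same pattern, independent of the wrapper (`fetch` and its URL are hypothetical stand-ins for the HTTP call):

```python
import logging

import requests
from requests import ConnectTimeout, ReadTimeout, RequestException
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


@retry(
    reraise=True,                 # surface the last error after the final attempt
    stop=stop_after_attempt(10),  # mirrors the default max_retries above
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RequestException, ConnectTimeout, ReadTimeout)),
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def fetch(url: str) -> dict:
    """Hypothetical call that may fail transiently."""
    response = requests.get(url, timeout=5)
    response.raise_for_status()  # HTTPError is a RequestException, so it is retried
    return response.json()
```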
@ -1,56 +1,46 @@
"""Test Nebula API wrapper.""" """Test Nebula API wrapper."""
from langchain import LLMChain, PromptTemplate from langchain import LLMChain, PromptTemplate
from langchain.llms.symblai_nebula import Nebula from langchain.llms.symblai_nebula import Nebula
def test_symblai_nebula_call() -> None: def test_symblai_nebula_call() -> None:
"""Test valid call to Nebula.""" """Test valid call to Nebula."""
conversation = """Speaker 1: Thank you for calling ABC, company.Speaker 1: My name conversation = """Sam: Good morning, team! Let's keep this standup concise.
is Mary.Speaker 1: How may I help you?Speaker 2: Today?Speaker 1: All right, We'll go in the usual order: what you did yesterday,
Madam.Speaker 1: I really apologize for this inconvenient.Speaker 1: I will be happy what you plan to do today, and any blockers. Alex, kick us off.
to assist you in this matter.Speaker 1: Could you please offer me Yuri your account Alex: Morning! Yesterday, I wrapped up the UI for the user dashboard.
number?Speaker 1: Alright Madam, thank you very much.Speaker 1: Let me check that The new charts and widgets are now responsive.
for confirmation.Speaker 1: Did you say 534 00 365?Speaker 2: 48?Speaker 1: Very good I also had a sync with the design team to ensure the final touchups are in
man.Speaker 1: Now for verification purposes, can I please get your full?Speaker line with the brand guidelines. Today, I'll start integrating the frontend with
2: Name?Speaker 1: Alright, thank you.Speaker 1: Very much.Speaker 1: Madam.Speaker the new API endpoints Rhea was working on.
1: Can I, please get your birthdate now?Speaker 1: I am sorry madam.Speaker 1: I The only blocker is waiting for some final API documentation,
didn't make this clear is for verification.Speaker 1: Purposes is the company but I guess Rhea can update on that.
request.Speaker 1: The system requires me, your name, your complete name and your Rhea: Hey, all! Yep, about the API documentation - I completed the majority of
date of.Speaker 2: Birth.Speaker 2: Alright, thank you very much madam.Speaker 1: the backend work for user data retrieval yesterday.
All right.Speaker 1: Thank you very much, Madam.Speaker 1: Thank you for that The endpoints are mostly set up, but I need to do a bit more testing today.
information.Speaker 1: Let me check what happens.Speaker 2: Here.Speaker 1: So I'll finalize the API documentation by noon, so that should unblock Alex.
according to our data space them, you did pay your last bill last August the 12, After that, Ill be working on optimizing the database queries
which was two days ago in one of our Affiliated payment centers.Speaker 1: So, at the for faster data fetching. No other blockers on my end.
moment you currently, We have zero balance.Speaker 1: So however, the bill that you Sam: Great, thanks Rhea. Do reach out if you need any testing assistance
received was generated a week before you made the pavement, this is reason why you or if there are any hitches with the database.
already make this payment, have not been reflected yet.Speaker 1: So what we do in Now, my update: Yesterday, I coordinated with the client to get clarity
this case, you just simply disregard the amount indicated in the field and you on some feature requirements. Today, I'll be updating our project roadmap
continue to enjoy our service man.Speaker 1: Sure, Madam.Speaker 1: And I am sure and timelines based on their feedback. Additionally, I'll be sitting with
you need your cell phone for everything for life, right?Speaker 1: So I really the QA team in the afternoon for preliminary testing.
apologize for this inconvenience.Speaker 1: And let me tell you that delays in the Blocker: I might need both of you to be available for a quick call
bill is usually caused by delays in our Courier Service.Speaker 1: That is to say in case the client wants to discuss the changes live.
that it'''s a problem, not with the company, but with a courier service, For a more Alex: Sounds good, Sam. Just let us know a little in advance for the call.
updated, feel of your account, you can visit our website and log into your account, Rhea: Agreed. We can make time for that.
and they'''re in the system.Speaker 1: On the website, you are going to have the Sam: Perfect! Let's keep the momentum going. Reach out if there are any
possibility to pay the bill.Speaker 1: That is more.Speaker 2: Updated.Speaker 2: sudden issues or support needed. Have a productive day!
Of course, Madam I can definitely assist you with that.Speaker 2: Once you have, Alex: You too.
you want to see your bill updated, please go to www.hsn BC campus, any.com after Rhea: Thanks, bye!"""
that.Speaker 2: You will see in the tale.Speaker 1: All right corner.Speaker 1: So llm = Nebula(nebula_api_key="<your_api_key>")
you're going to see a pay now button.Speaker 1: Please click on the pay now button
and the serve.Speaker 1: The system is going to ask you for personal
information.Speaker 1: Such as your first name, your ID account, your the number of
your account, your email address, and your phone number once you complete this personal
information."""
llm = Nebula(
conversation=conversation,
)
template = """Identify the {count} main objectives or goals mentioned in this instruction = """Identify the main objectives mentioned in this
context concisely in less points. Emphasize on key intents.""" conversation."""
prompt = PromptTemplate.from_template(template) prompt = PromptTemplate.from_template("{instruction}\n{conversation}")
llm_chain = LLMChain(prompt=prompt, llm=llm) llm_chain = LLMChain(prompt=prompt, llm=llm)
output = llm_chain.run(count="five") output = llm_chain.run(instruction=instruction, conversation=conversation)
assert isinstance(output, str) assert isinstance(output, str)
