Fixes to the Nebula LLM Integration (#8918)

This addresses some issues with introducing the Nebula LLM to LangChain
in this PR:
https://github.com/langchain-ai/langchain/pull/8876

This fixes the following:
- Removes `SYMBLAI` from variable names
- Fixes bug with `Bearer` for the API KEY


Thanks again in advance for your help!
cc: @hwchase17, @baskaryan

---------

Co-authored-by: dvonthenen <david.vonthenen@gmail.com>
pull/8926/head
David vonThenen 1 year ago committed by GitHub
parent d1e305028f
commit bf4a112aa6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -34,9 +34,9 @@
"source": [
"import os\n",
"\n",
"os.environ[\"SYMBLAI_NEBULA_SERVICE_URL\"] = SYMBLAI_NEBULA_SERVICE_URL\n",
"os.environ[\"SYMBLAI_NEBULA_SERVICE_PATH\"] = SYMBLAI_NEBULA_SERVICE_PATH\n",
"os.environ[\"SYMBLAI_NEBULA_SERVICE_TOKEN\"] = SYMBLAI_NEBULA_SERVICE_TOKEN"
"os.environ[\"NEBULA_SERVICE_URL\"] = NEBULA_SERVICE_URL\n",
"os.environ[\"NEBULA_SERVICE_PATH\"] = NEBULA_SERVICE_PATH\n",
"os.environ[\"NEBULA_SERVICE_API_KEY\"] = NEBULA_SERVICE_API_KEY"
]
},
{

@ -9,8 +9,8 @@ from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
DEFAULT_SYMBLAI_NEBULA_SERVICE_URL = "https://api-nebula.symbl.ai"
DEFAULT_SYMBLAI_NEBULA_SERVICE_PATH = "/v1/model/generate"
DEFAULT_NEBULA_SERVICE_URL = "https://api-nebula.symbl.ai"
DEFAULT_NEBULA_SERVICE_PATH = "/v1/model/generate"
logger = logging.getLogger(__name__)
@ -18,8 +18,8 @@ logger = logging.getLogger(__name__)
class Nebula(LLM):
"""Nebula Service models.
To use, you should have the environment variable ``SYMBLAI_NEBULA_SERVICE_URL``,
``SYMBLAI_NEBULA_SERVICE_PATH`` and ``SYMBLAI_NEBULA_SERVICE_TOKEN`` set with your Nebula
To use, you should have the environment variable ``NEBULA_SERVICE_URL``,
``NEBULA_SERVICE_PATH`` and ``NEBULA_SERVICE_API_KEY`` set with your Nebula
Service, or pass it as a named parameter to the constructor.
Example:
@ -30,21 +30,8 @@ class Nebula(LLM):
nebula = Nebula(
nebula_service_url="SERVICE_URL",
nebula_service_path="SERVICE_ROUTE",
nebula_service_token="SERVICE_TOKEN",
nebula_api_key="SERVICE_TOKEN",
)
# Use Ray for distributed processing
import ray
prompt_list=[]
@ray.remote
def send_query(llm, prompt):
resp = llm(prompt)
return resp
futures = [send_query.remote(nebula, prompt) for prompt in prompt_list]
results = ray.get(futures)
""" # noqa: E501
"""Key/value arguments to pass to the model. Reserved for future use"""
@ -53,7 +40,7 @@ class Nebula(LLM):
"""Optional"""
nebula_service_url: Optional[str] = None
nebula_service_path: Optional[str] = None
nebula_service_token: Optional[str] = None
nebula_api_key: Optional[str] = None
conversation: str = ""
return_scores: Optional[str] = "false"
max_new_tokens: Optional[int] = 2048
@ -69,20 +56,21 @@ class Nebula(LLM):
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
nebula_service_url = get_from_dict_or_env(
values, "nebula_service_url", "SYMBLAI_NEBULA_SERVICE_URL"
values,
"nebula_service_url",
"NEBULA_SERVICE_URL",
DEFAULT_NEBULA_SERVICE_URL,
)
nebula_service_path = get_from_dict_or_env(
values, "nebula_service_path", "SYMBLAI_NEBULA_SERVICE_PATH"
values,
"nebula_service_path",
"NEBULA_SERVICE_PATH",
DEFAULT_NEBULA_SERVICE_PATH,
)
nebula_service_token = get_from_dict_or_env(
values, "nebula_service_token", "SYMBLAI_NEBULA_SERVICE_TOKEN"
nebula_api_key = get_from_dict_or_env(
values, "nebula_api_key", "NEBULA_SERVICE_API_KEY", ""
)
if len(nebula_service_url) == 0:
nebula_service_url = DEFAULT_SYMBLAI_NEBULA_SERVICE_URL
if len(nebula_service_path) == 0:
nebula_service_path = DEFAULT_SYMBLAI_NEBULA_SERVICE_PATH
if nebula_service_url.endswith("/"):
nebula_service_url = nebula_service_url[:-1]
if not nebula_service_path.startswith("/"):
@ -94,7 +82,7 @@ class Nebula(LLM):
nebula_service_endpoint = f"{nebula_service_url}{nebula_service_path}"
headers = {
"Content-Type": "application/json",
"ApiKey": f"Bearer {nebula_service_token}",
"ApiKey": "{nebula_api_key}",
}
requests.get(nebula_service_endpoint, headers=headers)
except requests.exceptions.RequestException as e:
@ -103,7 +91,7 @@ class Nebula(LLM):
values["nebula_service_url"] = nebula_service_url
values["nebula_service_path"] = nebula_service_path
values["nebula_service_token"] = nebula_service_token
values["nebula_api_key"] = nebula_api_key
return values
@ -147,7 +135,7 @@ class Nebula(LLM):
headers = {
"Content-Type": "application/json",
"ApiKey": f"Bearer {self.nebula_service_token}",
"ApiKey": f"{self.nebula_api_key}",
}
body = {

Loading…
Cancel
Save