@@ -9,8 +9,8 @@ from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
-DEFAULT_SYMBLAI_NEBULA_SERVICE_URL = "https://api-nebula.symbl.ai"
-DEFAULT_SYMBLAI_NEBULA_SERVICE_PATH = "/v1/model/generate"
+DEFAULT_NEBULA_SERVICE_URL = "https://api-nebula.symbl.ai"
+DEFAULT_NEBULA_SERVICE_PATH = "/v1/model/generate"
 
 logger = logging.getLogger(__name__)
@@ -18,8 +18,8 @@ logger = logging.getLogger(__name__)
 class Nebula(LLM):
     """Nebula Service models.
 
-    To use, you should have the environment variable ``SYMBLAI_NEBULA_SERVICE_URL``,
-    ``SYMBLAI_NEBULA_SERVICE_PATH`` and ``SYMBLAI_NEBULA_SERVICE_TOKEN`` set with your Nebula
+    To use, you should have the environment variables ``NEBULA_SERVICE_URL``,
+    ``NEBULA_SERVICE_PATH`` and ``NEBULA_SERVICE_API_KEY`` set with your Nebula
     Service, or pass it as a named parameter to the constructor.
 
     Example:
@@ -30,21 +30,8 @@ class Nebula(LLM):
             nebula = Nebula(
                 nebula_service_url="SERVICE_URL",
                 nebula_service_path="SERVICE_ROUTE",
-                nebula_service_token="SERVICE_TOKEN",
+                nebula_api_key="SERVICE_TOKEN",
             )
-
-            # Use Ray for distributed processing
-            import ray
-
-            prompt_list = []
-
-            @ray.remote
-            def send_query(llm, prompt):
-                resp = llm(prompt)
-                return resp
-
-            futures = [send_query.remote(nebula, prompt) for prompt in prompt_list]
-            results = ray.get(futures)
     """  # noqa: E501
 
     """Key/value arguments to pass to the model. Reserved for future use"""
@@ -53,7 +40,7 @@ class Nebula(LLM):
     """Optional"""
     nebula_service_url: Optional[str] = None
     nebula_service_path: Optional[str] = None
-    nebula_service_token: Optional[str] = None
+    nebula_api_key: Optional[str] = None
     conversation: str = ""
     return_scores: Optional[str] = "false"
     max_new_tokens: Optional[int] = 2048
@@ -69,20 +56,21 @@ class Nebula(LLM):
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
         nebula_service_url = get_from_dict_or_env(
-            values, "nebula_service_url", "SYMBLAI_NEBULA_SERVICE_URL"
+            values,
+            "nebula_service_url",
+            "NEBULA_SERVICE_URL",
+            DEFAULT_NEBULA_SERVICE_URL,
         )
         nebula_service_path = get_from_dict_or_env(
-            values, "nebula_service_path", "SYMBLAI_NEBULA_SERVICE_PATH"
+            values,
+            "nebula_service_path",
+            "NEBULA_SERVICE_PATH",
+            DEFAULT_NEBULA_SERVICE_PATH,
         )
-        nebula_service_token = get_from_dict_or_env(
-            values, "nebula_service_token", "SYMBLAI_NEBULA_SERVICE_TOKEN"
+        nebula_api_key = get_from_dict_or_env(
+            values, "nebula_api_key", "NEBULA_SERVICE_API_KEY", ""
         )
-        if len(nebula_service_url) == 0:
-            nebula_service_url = DEFAULT_SYMBLAI_NEBULA_SERVICE_URL
-        if len(nebula_service_path) == 0:
-            nebula_service_path = DEFAULT_SYMBLAI_NEBULA_SERVICE_PATH
         if nebula_service_url.endswith("/"):
             nebula_service_url = nebula_service_url[:-1]
         if not nebula_service_path.startswith("/"):
@@ -94,7 +82,7 @@ class Nebula(LLM):
             nebula_service_endpoint = f"{nebula_service_url}{nebula_service_path}"
             headers = {
                 "Content-Type": "application/json",
-                "ApiKey": f"Bearer {nebula_service_token}",
+                "ApiKey": f"{nebula_api_key}",
             }
             requests.get(nebula_service_endpoint, headers=headers)
         except requests.exceptions.RequestException as e:
@@ -103,7 +91,7 @@ class Nebula(LLM):
         values["nebula_service_url"] = nebula_service_url
         values["nebula_service_path"] = nebula_service_path
-        values["nebula_service_token"] = nebula_service_token
+        values["nebula_api_key"] = nebula_api_key
         return values
@@ -147,7 +135,7 @@ class Nebula(LLM):
         headers = {
             "Content-Type": "application/json",
-            "ApiKey": f"Bearer {self.nebula_service_token}",
+            "ApiKey": f"{self.nebula_api_key}",
         }
         body = {
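
A minimal usage sketch against the renamed configuration (not part of the patch). The import path is an assumption based on the file's location under langchain/llms and is not shown in this diff; the endpoint, path, and key values below are placeholders.

    import os

    from langchain.llms import Nebula  # assumed import path, not confirmed by this diff

    # Option 1: rely on the renamed environment variables. The service URL and
    # path now fall back to DEFAULT_NEBULA_SERVICE_URL / DEFAULT_NEBULA_SERVICE_PATH
    # when unset, so only the API key needs to be provided.
    os.environ["NEBULA_SERVICE_API_KEY"] = "<placeholder-api-key>"
    llm = Nebula()

    # Option 2: pass named parameters directly (note the rename:
    # nebula_service_token -> nebula_api_key).
    llm = Nebula(
        nebula_service_url="https://api-nebula.symbl.ai",
        nebula_service_path="/v1/model/generate",
        nebula_api_key="<placeholder-api-key>",
    )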