mirror of https://github.com/hwchase17/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
526 lines
20 KiB
Python
526 lines
20 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from typing import (
|
|
Any,
|
|
AsyncIterator,
|
|
Callable,
|
|
Dict,
|
|
Generator,
|
|
Iterator,
|
|
List,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
Union,
|
|
)
|
|
|
|
import aiohttp
|
|
import requests
|
|
from langchain_core.messages import BaseMessage
|
|
from langchain_core.pydantic_v1 import (
|
|
BaseModel,
|
|
Field,
|
|
PrivateAttr,
|
|
SecretStr,
|
|
root_validator,
|
|
)
|
|
from langchain_core.utils import get_from_dict_or_env
|
|
from requests.models import Response
|
|
|
|
# Module logger, named after this module so handlers can target it explicitly.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class NVEModel(BaseModel):
    """
    Underlying Client for interacting with the AI Foundation Model Function API.
    Leveraged by the NVIDIABaseModel to provide a simple requests-oriented interface.
    Direct abstraction over NGC-recommended streaming/non-streaming Python solutions.

    NOTE: Models in the playground do not currently support raw text continuation.
    """

    ## Core defaults. These probably should not be changed
    # Status-polling endpoint; `_wait` appends the NVCF request id to it.
    fetch_url_format: str = Field("https://api.nvcf.nvidia.com/v2/nvcf/pexec/status/")
    # Base URL onto which bare function ids are joined in `_get_invoke_url`.
    call_invoke_base: str = Field("https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions")
    # Factories for the sync (requests) and async (aiohttp) HTTP sessions.
    get_session_fn: Callable = Field(requests.Session)
    get_asession_fn: Callable = Field(aiohttp.ClientSession)

    nvidia_api_key: SecretStr = Field(
        ...,
        description="API key for NVIDIA Foundation Endpoints. Starts with `nvapi-`",
    )
    # Derived from the key in `validate_model` (`nvapi-stg-` keys => staging).
    is_staging: bool = Field(False, description="Whether to use staging API")

    ## Generation arguments
    # Bound on status polls in `_wait` while the API keeps answering 202.
    max_tries: int = Field(5, ge=1)
    headers_tmpl: dict = Field(
        ...,
        description="Headers template for API calls."
        " Should contain `call` and `stream` keys.",
    )
    # Lazily-filled caches: raw function listing and name -> function-id map.
    _available_functions: Optional[List[dict]] = PrivateAttr(default=None)
    _available_models: Optional[dict] = PrivateAttr(default=None)

    @property
    def headers(self) -> dict:
        """Return headers with the API key injected.

        Each per-mode header dict (e.g. ``call`` / ``stream``) is copied before the
        ``{nvidia_api_key}`` placeholder is filled in. ``headers_tmpl`` therefore
        keeps its placeholder and never stores the secret. Header dicts without an
        ``Authorization`` entry are passed through unchanged.
        """
        api_key = self.nvidia_api_key.get_secret_value()
        headers_: Dict[str, Any] = {}
        for mode, template in self.headers_tmpl.items():
            header = dict(template)
            auth = header.get("Authorization", "")
            if "{nvidia_api_key}" in auth:
                header["Authorization"] = auth.format(nvidia_api_key=api_key)
            headers_[mode] = header
        return headers_

    @root_validator(pre=True)
    def validate_model(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate and update model arguments, including API key and formatting.

        - Reads the key from ``nvidia_api_key`` or the ``NVIDIA_API_KEY`` env var.
        - Rejects keys without ``nvapi-`` and flags ``nvapi-stg-`` keys as staging.
        - Fills in default header templates.
        - Rewrites the endpoint URLs for the staging/production environment.

        Raises:
            ValueError: if no key is found or the key lacks ``nvapi-``.
        """
        api_key = get_from_dict_or_env(values, "nvidia_api_key", "NVIDIA_API_KEY")
        # Accept an already-wrapped SecretStr as well as a plain string; the
        # substring checks below need the raw value.
        if isinstance(api_key, SecretStr):
            api_key = api_key.get_secret_value()
        values["nvidia_api_key"] = api_key
        if "nvapi-" not in api_key:
            raise ValueError("Invalid NVAPI key detected. Should start with `nvapi-`")
        is_staging = "nvapi-stg-" in api_key
        values["is_staging"] = is_staging
        if "headers_tmpl" not in values:
            values["headers_tmpl"] = {
                "call": {
                    "Authorization": "Bearer {nvidia_api_key}",
                    "Accept": "application/json",
                },
                "stream": {
                    "Authorization": "Bearer {nvidia_api_key}",
                    "Accept": "text/event-stream",
                    "content-type": "application/json",
                },
            }

        values["fetch_url_format"] = cls._stagify(
            is_staging,
            values.get(
                "fetch_url_format", "https://api.nvcf.nvidia.com/v2/nvcf/pexec/status/"
            ),
        )
        values["call_invoke_base"] = cls._stagify(
            is_staging,
            values.get(
                "call_invoke_base",
                "https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions",
            ),
        )
        return values

    @property
    def available_models(self) -> dict:
        """Map names of ``ACTIVE`` functions to their function ids (cached)."""
        if self._available_models is not None:
            return self._available_models
        live_fns = [v for v in self.available_functions if v.get("status") == "ACTIVE"]
        self._available_models = {v["name"]: v["id"] for v in live_fns}
        return self._available_models

    @property
    def available_functions(self) -> List[dict]:
        """List the available functions that can be invoked (cached).

        Raises:
            ValueError: if the listing response has no ``functions`` key.
        """
        if self._available_functions is not None:
            return self._available_functions
        invoke_url = self._stagify(
            self.is_staging, "https://api.nvcf.nvidia.com/v2/nvcf/functions"
        )
        query_res = self.query(invoke_url)
        if "functions" not in query_res:
            raise ValueError(
                f"Unexpected response when querying {invoke_url}\n{query_res}"
            )
        self._available_functions = query_res["functions"]
        return self._available_functions

    @classmethod
    def _stagify(cls, is_staging: bool, path: str) -> str:
        """Helper method to switch between staging and production endpoints.

        Staging URLs use a ``stg.api.`` host prefix; production URLs use ``api.``.
        """
        if is_staging and "stg.api" not in path:
            return path.replace("api.", "stg.api.")
        if not is_staging and "stg.api" in path:
            return path.replace("stg.api.", "api.")
        return path

    ####################################################################################
    ## Core utilities for posting and getting from NV Endpoints

    def _post(
        self, invoke_url: str, payload: Optional[dict] = None
    ) -> Tuple[Response, Any]:
        """POST ``payload`` as JSON with the ``call`` headers.

        Returns:
            The response and the session that produced it; `_wait` reuses that
            session for status polling.
        """
        call_inputs = {
            "url": invoke_url,
            "headers": self.headers["call"],
            "json": payload or {},
            "stream": False,
        }
        session = self.get_session_fn()
        response = session.post(**call_inputs)
        self._try_raise(response)
        return response, session

    def _get(
        self, invoke_url: str, payload: Optional[dict] = None
    ) -> Tuple[Response, Any]:
        """GET ``invoke_url`` (with ``payload`` as a JSON body) using ``call`` headers.

        Returns:
            The response and the session that produced it.
        """
        last_inputs = {
            "url": invoke_url,
            "headers": self.headers["call"],
            "json": payload or {},
            "stream": False,
        }
        session = self.get_session_fn()
        last_response = session.get(**last_inputs)
        self._try_raise(last_response)
        return last_response, session

    def _wait(self, response: Response, session: Any) -> Response:
        """Wait for a final response after the API has answered ``202``.

        While the latest response is ``202``, poll
        ``fetch_url_format + <NVCF-REQID>`` with the same session. The request id
        comes from the latest response's headers.

        Raises:
            ValueError: if the ``(max_tries + 1)``-th poll is still ``202``.
            Exception: via `_try_raise` if the final response is an HTTP error.
        """
        i = 1
        while response.status_code == 202:
            request_id = response.headers.get("NVCF-REQID", "")
            response = session.get(
                self.fetch_url_format + request_id,
                headers=self.headers["call"],
            )
            if response.status_code == 202:
                try:
                    body = response.json()
                except ValueError:
                    body = str(response)
                if i > self.max_tries:
                    raise ValueError(f"Failed to get response with {i} tries: {body}")
                # Count this pending poll. Without the increment, the limit above
                # could never be reached and polling would never terminate.
                i += 1
        self._try_raise(response)
        return response

    def _try_raise(self, response: Response) -> None:
        """Re-raise an HTTP error status as an ``Exception`` with API error details.

        The body is read as JSON when possible. Otherwise the raw content is
        decoded, an optional ``data:`` SSE prefix is stripped, and it is parsed
        again. The message is built from the ``status``/``title``/``detail``/
        ``type`` fields. Only ``requests.HTTPError`` is enriched; other errors
        from ``raise_for_status`` propagate unchanged.
        """
        try:
            response.raise_for_status()
        except requests.HTTPError as e:
            try:
                rd = response.json()
            # ValueError covers json/simplejson/requests JSONDecodeError variants.
            except ValueError:
                rd = response.__dict__
                rd = rd.get("_content", rd)
                if isinstance(rd, bytes):
                    rd = rd.decode("utf-8", errors="replace")
                    # Only strip the SSE prefix when it is really there; slicing
                    # unconditionally would corrupt plain-text/HTML error bodies.
                    if rd.startswith("data:"):
                        rd = rd[5:]
                try:
                    rd = json.loads(rd)
                except Exception:
                    rd = {"detail": rd}
            # The body may be valid JSON that is not an object (list, string...).
            if not isinstance(rd, dict):
                rd = {"detail": rd}
            title = f"[{rd.get('status', '###')}] {rd.get('title', 'Unknown Error')}"
            body = f"{rd.get('detail', rd.get('type', rd))}"
            raise Exception(f"{title}\n{body}") from e

    ####################################################################################
    ## Simple query interface to show the set of model options

    def query(self, invoke_url: str, payload: Optional[dict] = None) -> dict:
        """Simple method for an end-to-end get query. Returns result dictionary"""
        response, session = self._get(invoke_url, payload)
        response = self._wait(response, session)
        output = self._process_response(response)[0]
        return output

    def _process_response(self, response: Union[str, Response]) -> List[dict]:
        """General-purpose response processing for single responses and streams.

        Objects with ``.json()`` (non-streaming responses) yield a one-element list.
        Strings (streamed text) are split on blank lines. Each part containing a
        ``{`` is parsed as JSON from that brace onward.

        Raises:
            ValueError: if ``response`` is neither of the above.
        """
        if hasattr(response, "json"):  ## For single response (i.e. non-streaming)
            try:
                return [response.json()]
            except ValueError:
                response = str(response.__dict__)
        if isinstance(response, str):  ## For set of responses (i.e. streaming)
            msg_list = []
            for msg in response.split("\n\n"):
                if "{" not in msg:
                    continue
                msg_list += [json.loads(msg[msg.find("{") :])]
            return msg_list
        raise ValueError(f"Received ill-formed response: {response}")

    def _get_invoke_url(
        self, model_name: Optional[str] = None, invoke_url: Optional[str] = None
    ) -> str:
        """Resolve the invoke URL from a model name, full URL, or function-id stub.

        Raises:
            ValueError: if neither argument is given or the model name is unknown.
        """
        if not invoke_url:
            if not model_name:
                raise ValueError("URL or model name must be specified to invoke")
            if model_name in self.available_models:
                invoke_url = self.available_models[model_name]
            elif f"playground_{model_name}" in self.available_models:
                invoke_url = self.available_models[f"playground_{model_name}"]
            else:
                available_models_str = "\n".join(
                    [f"{k} - {v}" for k, v in self.available_models.items()]
                )
                raise ValueError(
                    f"Unknown model name {model_name} specified."
                    "\nAvailable models are:\n"
                    f"{available_models_str}"
                )
        if not invoke_url:
            # For mypy
            raise ValueError("URL or model name must be specified to invoke")
        # `available_models` maps names to bare function ids, not URLs, so ids (or
        # any caller-supplied stub) must be joined onto the invoke base.
        if "http" not in invoke_url:
            invoke_url = f"{self.call_invoke_base}/{invoke_url}"
        return invoke_url

    ####################################################################################
    ## Generation interface to allow users to generate new values from endpoints

    def get_req(
        self,
        model_name: Optional[str] = None,
        payload: Optional[dict] = None,
        invoke_url: Optional[str] = None,
        stop: Optional[Sequence[str]] = None,
    ) -> Response:
        """Post a non-streaming request and wait for its final response.

        ``stop`` is accepted for signature parity and is not used here.
        """
        invoke_url = self._get_invoke_url(model_name, invoke_url)
        payload = payload or {}
        # Force non-streaming mode without mutating the caller's dict.
        if payload.get("stream", False) is True:
            payload = {**payload, "stream": False}
        response, session = self._post(invoke_url, payload)
        return self._wait(response, session)

    def get_req_generation(
        self,
        model_name: Optional[str] = None,
        payload: Optional[dict] = None,
        invoke_url: Optional[str] = None,
        stop: Optional[Sequence[str]] = None,
    ) -> dict:
        """Method for an end-to-end post query with NVE post-processing."""
        response = self.get_req(model_name, payload, invoke_url)
        output, _ = self.postprocess(response, stop=stop)
        return output

    def postprocess(
        self, response: Union[str, Response], stop: Optional[Sequence[str]] = None
    ) -> Tuple[dict, bool]:
        """Parses a response from the AI Foundation Model Function API.

        Strongly assumes that the API will return a single response.

        Returns:
            The aggregated message dict and whether generation should stop.
        """
        msg_list = self._process_response(response)
        msg, is_stopped = self._aggregate_msgs(msg_list)
        msg, is_stopped = self._early_stop_msg(msg, is_stopped, stop=stop)
        return msg, is_stopped

    def _aggregate_msgs(self, msg_list: Sequence[dict]) -> Tuple[dict, bool]:
        """Merge parsed messages into a single dict.

        ``choices`` payloads contribute their first choice's ``delta``/``message``,
        and a ``finish_reason`` of ``"stop"`` marks the result as stopped. ``data``
        payloads contribute their first item. ``content`` values are concatenated
        across messages; other keys keep their last value. Iteration ends at the
        first stopped message.
        """
        content_buffer: Dict[str, Any] = dict()
        content_holder: Dict[Any, Any] = dict()
        is_stopped = False
        for msg in msg_list:
            if "choices" in msg:
                ## Tease out ['choices'][0]...['delta'/'message']
                msg = msg.get("choices", [{}])[0]
                is_stopped = msg.get("finish_reason", "") == "stop"
                msg = msg.get("delta", msg.get("message", {"content": ""}))
            elif "data" in msg:
                ## Tease out ['data'][0]...['embedding']
                msg = msg.get("data", [{}])[0]
            content_holder = msg
            for k, v in msg.items():
                if k in ("content",) and k in content_buffer:
                    content_buffer[k] += v
                else:
                    content_buffer[k] = v
            if is_stopped:
                break
        content_holder = {**content_holder, **content_buffer}
        return content_holder, is_stopped

    def _early_stop_msg(
        self, msg: dict, is_stopped: bool, stop: Optional[Sequence[str]] = None
    ) -> Tuple[dict, bool]:
        """Truncate ``msg["content"]`` before the earliest stop string, if any.

        The stop string itself is excluded. When a stop string matches, the
        returned flag is ``True``.
        """
        content = msg.get("content", "")
        if content and stop:
            cut_points = [content.find(s) for s in stop if s and s in content]
            if cut_points:
                # Earliest match wins regardless of the order of `stop`.
                msg["content"] = content[: min(cut_points)]
                is_stopped = True
        return msg, is_stopped

    ####################################################################################
    ## Streaming interface to allow you to iterate through progressive generations

    def get_req_stream(
        self,
        model: Optional[str] = None,
        payload: Optional[dict] = None,
        invoke_url: Optional[str] = None,
        stop: Optional[Sequence[str]] = None,
    ) -> Iterator:
        """Post a streaming request and return an iterator of parsed message dicts.

        Each non-empty line other than ``data: [DONE]`` is post-processed.
        Iteration ends when a line signals a stop. The streamed response is closed
        when the iterator finishes or is closed early.
        """
        invoke_url = self._get_invoke_url(model, invoke_url)
        payload = payload or {}
        if payload.get("stream", True) is False:
            payload = {**payload, "stream": True}
        last_inputs = {
            "url": invoke_url,
            "headers": self.headers["stream"],
            "json": payload,
            "stream": True,
        }
        response = self.get_session_fn().post(**last_inputs)
        self._try_raise(response)
        call = self.copy()

        def out_gen() -> Generator[dict, Any, Any]:
            ## Good for client, since it allows self.last_input
            try:
                for line in response.iter_lines():
                    if line and line.strip() != b"data: [DONE]":
                        msg, final_line = call.postprocess(
                            line.decode("utf-8"), stop=stop
                        )
                        yield msg
                        if final_line:
                            break
                self._try_raise(response)
            finally:
                # Release the streamed connection even on early exit.
                response.close()

        return out_gen()

    ####################################################################################
    ## Asynchronous streaming interface to allow multiple generations to happen at once.

    async def get_req_astream(
        self,
        model: Optional[str] = None,
        payload: Optional[dict] = None,
        invoke_url: Optional[str] = None,
        stop: Optional[Sequence[str]] = None,
    ) -> AsyncIterator:
        """Async counterpart of `get_req_stream`, yielding parsed message dicts.

        NOTE(review): with aiohttp, ``raise_for_status`` raises
        ``aiohttp.ClientResponseError``. `_try_raise` does not enrich that error,
        so it propagates as-is.
        """
        invoke_url = self._get_invoke_url(model, invoke_url)
        payload = payload or {}
        if payload.get("stream", True) is False:
            payload = {**payload, "stream": True}
        last_inputs = {
            "url": invoke_url,
            "headers": self.headers["stream"],
            "json": payload,
        }
        async with self.get_asession_fn() as session:
            async with session.post(**last_inputs) as response:
                self._try_raise(response)
                # Iterate line by line: `iter_any()` yields arbitrary network
                # chunks, which can split one SSE event (and its JSON) in two.
                async for raw_line in response.content:
                    line = raw_line.strip()
                    if not line or line == b"data: [DONE]":
                        continue
                    msg, final_line = self.postprocess(line.decode("utf-8"), stop=stop)
                    yield msg
                    if final_line:
                        break
|
|
|
|
|
|
class _NVIDIAClient(BaseModel):
    """
    Higher-Level AI Foundation Model Function API Client with argument defaults.
    Is subclassed by ChatNVIDIA to provide a simple LangChain interface.
    """

    # Low-level transport. `validate_client` builds one from the constructor
    # arguments whenever no (truthy) client is supplied.
    client: NVEModel = Field(NVEModel)

    model: str = Field(..., description="Name of the model to invoke")

    temperature: float = Field(0.2, le=1.0, gt=0.0)
    top_p: float = Field(0.7, le=1.0, ge=0.0)
    max_tokens: int = Field(1024, le=1024, ge=32)

    ####################################################################################

    @root_validator(pre=True)
    def validate_client(cls, values: Any) -> Any:
        """Validate and update client arguments, including API key and formatting.

        When no client is given, an `NVEModel` is built from the same values; its
        own validator resolves the API key and endpoints.
        """
        if not values.get("client"):
            values["client"] = NVEModel(**values)
        return values

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Mark this model as serializable by LangChain."""
        return True

    @property
    def available_functions(self) -> List[dict]:
        """Map the available functions that can be invoked."""
        return self.client.available_functions

    @property
    def available_models(self) -> dict:
        """Map the available models that can be invoked."""
        return self.client.available_models

    def get_model_details(self, model: Optional[str] = None) -> dict:
        """Get more meta-details about a model retrieved by a given name.

        Args:
            model: Model name; defaults to ``self.model``.

        Raises:
            ValueError: if no listed function has the resolved function id.
        """
        if model is None:
            model = self.model
        # The last URL segment of the invoke URL is the function id.
        model_key = self.client._get_invoke_url(model).split("/")[-1]
        known_fns = self.client.available_functions
        fn_spec = next((f for f in known_fns if f.get("id") == model_key), None)
        if fn_spec is None:
            raise ValueError(
                f"No function details found for model {model} (id {model_key})"
            )
        return fn_spec

    def get_generation(
        self,
        inputs: Sequence[Dict],
        labels: Optional[dict] = None,
        stop: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> dict:
        """Call to client generate method with call scope"""
        payload = self.get_payload(inputs=inputs, stream=False, labels=labels, **kwargs)
        out = self.client.get_req_generation(self.model, stop=stop, payload=payload)
        return out

    def get_stream(
        self,
        inputs: Sequence[Dict],
        labels: Optional[dict] = None,
        stop: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> Iterator:
        """Call to client stream method with call scope"""
        payload = self.get_payload(inputs=inputs, stream=True, labels=labels, **kwargs)
        return self.client.get_req_stream(self.model, stop=stop, payload=payload)

    def get_astream(
        self,
        inputs: Sequence[Dict],
        labels: Optional[dict] = None,
        stop: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator:
        """Call to client astream methods with call scope"""
        payload = self.get_payload(inputs=inputs, stream=True, labels=labels, **kwargs)
        return self.client.get_req_astream(self.model, stop=stop, payload=payload)

    def get_payload(
        self, inputs: Sequence[Dict], labels: Optional[dict] = None, **kwargs: Any
    ) -> dict:
        """Generates payload for the _NVIDIAClient API to send to service.

        Extra keyword arguments are merged in last, so they override ``messages``.
        """
        return {
            **self.preprocess(inputs=inputs, labels=labels),
            **kwargs,
        }

    def preprocess(self, inputs: Sequence[Dict], labels: Optional[dict] = None) -> dict:
        """Prepares a message or list of messages for the payload.

        Returns:
            ``{"messages": [...]}``, with an assistant ``labels`` message appended
            when ``labels`` is truthy.
        """
        messages = [self.prep_msg(m) for m in inputs]
        if labels:
            # (WFH) Labels are currently (?) always passed as an assistant
            # suffix message, but this API seems less stable.
            messages += [{"labels": labels, "role": "assistant"}]
        return {"messages": messages}

    def prep_msg(self, msg: Union[str, dict, BaseMessage]) -> dict:
        """Helper Method: Ensures a message is a dictionary with a role and content.

        Strings become user messages. Dicts must have non-None ``content`` and are
        returned as-is. Anything else, including ``BaseMessage`` instances, is
        rejected; callers presumably convert those to dicts beforehand (confirm).

        Raises:
            ValueError: for dicts without content or unsupported message types.
        """
        if isinstance(msg, str):
            # (WFH) this shouldn't ever be reached but leaving this here bcs
            # it's a Chesterton's fence I'm unwilling to touch
            return dict(role="user", content=msg)
        if isinstance(msg, dict):
            if msg.get("content", None) is None:
                raise ValueError(f"Message {msg} has no content")
            return msg
        raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")
|