langchain/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
William FH 583696732c
[Partner] NVIDIA TRT Package (#14733)
Simplify #13976 and add as a separate package.

- [] Add README
- [X] Add doc notebook
- [X] Add simple LLM integration

---------

Co-authored-by: Jeremy Dyer <jdye64@gmail.com>
2023-12-18 19:08:25 -08:00

405 lines
14 KiB
Python

from __future__ import annotations
import json
import queue
import random
import time
from functools import partial
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
import google.protobuf.json_format
import numpy as np
import tritonclient.grpc as grpcclient
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Field, root_validator
from tritonclient.grpc.service_pb2 import ModelInferResponse
from tritonclient.utils import np_to_triton_dtype
class TritonTensorRTError(Exception):
"""Base exception for TritonTensorRT."""
class TritonTensorRTRuntimeError(TritonTensorRTError, RuntimeError):
"""Runtime error for TritonTensorRT."""
class TritonTensorRTLLM(BaseLLM):
"""TRTLLM triton models.
Arguments:
server_url: (str) The URL of the Triton inference server to use.
model_name: (str) The name of the Triton TRT model to use.
temperature: (str) Temperature to use for sampling
top_p: (float) The top-p value to use for sampling
top_k: (float) The top k values use for sampling
beam_width: (int) Last n number of tokens to penalize
repetition_penalty: (int) Last n number of tokens to penalize
length_penalty: (float) The penalty to apply repeated tokens
tokens: (int) The maximum number of tokens to generate.
client: The client object used to communicate with the inference server
Example:
.. code-block:: python
from langchain_nvidia_trt import TritonTensorRTLLM
model = TritonTensorRTLLM()
"""
server_url: Optional[str] = Field(None, alias="server_url")
model_name: str = Field(
..., description="The name of the model to use, such as 'ensemble'."
)
## Optional args for the model
temperature: float = 1.0
top_p: float = 0
top_k: int = 1
tokens: int = 100
beam_width: int = 1
repetition_penalty: float = 1.0
length_penalty: float = 1.0
client: grpcclient.InferenceServerClient
stop: List[str] = Field(
default_factory=lambda: ["</s>"], description="Stop tokens."
)
seed: int = Field(42, description="The seed to use for random generation.")
load_model: bool = Field(
True,
description="Request the inference server to load the specified model.\
Certain Triton configurations do not allow for this operation.",
)
def __del__(self):
"""Ensure the client streaming connection is properly shutdown"""
self.client.close()
@root_validator(pre=True, allow_reuse=True)
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that python package exists in environment."""
if not values.get("client"):
values["client"] = grpcclient.InferenceServerClient(values["server_url"])
return values
@property
def _llm_type(self) -> str:
"""Return type of LLM."""
return "nvidia-trt-llm"
@property
def _model_default_parameters(self) -> Dict[str, Any]:
return {
"tokens": self.tokens,
"top_k": self.top_k,
"top_p": self.top_p,
"temperature": self.temperature,
"repetition_penalty": self.repetition_penalty,
"length_penalty": self.length_penalty,
"beam_width": self.beam_width,
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get all the identifying parameters."""
return {
"server_url": self.server_url,
"model_name": self.model_name,
**self._model_default_parameters,
}
def _get_invocation_params(self, **kwargs: Any) -> Dict[str, Any]:
return {**self._model_default_parameters, **kwargs}
def get_model_list(self) -> List[str]:
"""Get a list of models loaded in the triton server."""
res = self.client.get_model_repository_index(as_json=True)
return [model["name"] for model in res["models"]]
def _load_model(self, model_name: str, timeout: int = 1000) -> None:
"""Load a model into the server."""
if self.client.is_model_ready(model_name):
return
self.client.load_model(model_name)
t0 = time.perf_counter()
t1 = t0
while not self.client.is_model_ready(model_name) and t1 - t0 < timeout:
t1 = time.perf_counter()
if not self.client.is_model_ready(model_name):
raise TritonTensorRTRuntimeError(
f"Failed to load {model_name} on Triton in {timeout}s"
)
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
self._load_model(self.model_name)
invocation_params = self._get_invocation_params(**kwargs)
stop_words = stop if stop is not None else self.stop
generations = []
# TODO: We should handle the native batching instead.
for prompt in prompts:
invoc_params = {**invocation_params, "prompt": [[prompt]]}
result: str = self._request(
self.model_name,
stop=stop_words,
**invoc_params,
)
generations.append([Generation(text=result, generation_info={})])
return LLMResult(generations=generations)
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
self._load_model(self.model_name)
invocation_params = self._get_invocation_params(**kwargs, prompt=[[prompt]])
stop_words = stop if stop is not None else self.stop
inputs = self._generate_inputs(stream=True, **invocation_params)
outputs = self._generate_outputs()
result_queue = self._invoke_triton(self.model_name, inputs, outputs, stop_words)
for token in result_queue:
yield GenerationChunk(text=token)
if run_manager:
run_manager.on_llm_new_token(token)
self.client.stop_stream()
##### BELOW ARE METHODS PREVIOUSLY ONLY IN THE GRPC CLIENT
def _request(
self,
model_name: str,
prompt: Sequence[Sequence[str]],
stop: Optional[List[str]] = None,
**params: Any,
) -> str:
"""Request inferencing from the triton server."""
# create model inputs and outputs
inputs = self._generate_inputs(stream=False, prompt=prompt, **params)
outputs = self._generate_outputs()
result_queue = self._invoke_triton(self.model_name, inputs, outputs, stop)
result_str = ""
for token in result_queue:
result_str += token
self.client.stop_stream()
return result_str
def _invoke_triton(self, model_name, inputs, outputs, stop_words):
if not self.client.is_model_ready(model_name):
raise RuntimeError("Cannot request streaming, model is not loaded")
request_id = str(random.randint(1, 9999999)) # nosec
result_queue = StreamingResponseGenerator(
self,
request_id,
force_batch=False,
stop_words=stop_words,
)
self.client.start_stream(
callback=partial(
self._stream_callback,
result_queue,
stop_words=stop_words,
)
)
# Even though this request may not be a streaming request certain configurations
# in Triton prevent the GRPC server from accepting none streaming connections.
# Therefore we call the streaming API and combine the streamed results.
self.client.async_stream_infer(
model_name=model_name,
inputs=inputs,
outputs=outputs,
request_id=request_id,
)
return result_queue
def _generate_outputs(
self,
) -> List[grpcclient.InferRequestedOutput]:
"""Generate the expected output structure."""
return [grpcclient.InferRequestedOutput("text_output")]
def _prepare_tensor(
self, name: str, input_data: np.ndarray
) -> grpcclient.InferInput:
"""Prepare an input data structure."""
t = grpcclient.InferInput(
name, input_data.shape, np_to_triton_dtype(input_data.dtype)
)
t.set_data_from_numpy(input_data)
return t
def _generate_inputs(
self,
prompt: Sequence[Sequence[str]],
tokens: int = 300,
temperature: float = 1.0,
top_k: float = 1,
top_p: float = 0,
beam_width: int = 1,
repetition_penalty: float = 1,
length_penalty: float = 1.0,
stream: bool = True,
) -> List[grpcclient.InferRequestedOutput]:
"""Create the input for the triton inference server."""
query = np.array(prompt).astype(object)
request_output_len = np.array([tokens]).astype(np.uint32).reshape((1, -1))
runtime_top_k = np.array([top_k]).astype(np.uint32).reshape((1, -1))
runtime_top_p = np.array([top_p]).astype(np.float32).reshape((1, -1))
temperature_array = np.array([temperature]).astype(np.float32).reshape((1, -1))
len_penalty = np.array([length_penalty]).astype(np.float32).reshape((1, -1))
repetition_penalty_array = (
np.array([repetition_penalty]).astype(np.float32).reshape((1, -1))
)
random_seed = np.array([self.seed]).astype(np.uint64).reshape((1, -1))
beam_width_array = np.array([beam_width]).astype(np.uint32).reshape((1, -1))
streaming_data = np.array([[stream]], dtype=bool)
inputs = [
self._prepare_tensor("text_input", query),
self._prepare_tensor("max_tokens", request_output_len),
self._prepare_tensor("top_k", runtime_top_k),
self._prepare_tensor("top_p", runtime_top_p),
self._prepare_tensor("temperature", temperature_array),
self._prepare_tensor("length_penalty", len_penalty),
self._prepare_tensor("repetition_penalty", repetition_penalty_array),
self._prepare_tensor("random_seed", random_seed),
self._prepare_tensor("beam_width", beam_width_array),
self._prepare_tensor("stream", streaming_data),
]
return inputs
def _send_stop_signals(self, model_name: str, request_id: str) -> None:
"""Send the stop signal to the Triton Inference server."""
stop_inputs = self._generate_stop_signals()
self.client.async_stream_infer(
model_name,
stop_inputs,
request_id=request_id,
parameters={"Streaming": True},
)
def _generate_stop_signals(
self,
) -> List[grpcclient.InferInput]:
"""Generate the signal to stop the stream."""
inputs = [
grpcclient.InferInput("input_ids", [1, 1], "INT32"),
grpcclient.InferInput("input_lengths", [1, 1], "INT32"),
grpcclient.InferInput("request_output_len", [1, 1], "UINT32"),
grpcclient.InferInput("stop", [1, 1], "BOOL"),
]
inputs[0].set_data_from_numpy(np.empty([1, 1], dtype=np.int32))
inputs[1].set_data_from_numpy(np.zeros([1, 1], dtype=np.int32))
inputs[2].set_data_from_numpy(np.array([[0]], dtype=np.uint32))
inputs[3].set_data_from_numpy(np.array([[True]], dtype="bool"))
return inputs
@staticmethod
def _process_result(result: Dict[str, str]) -> str:
"""Post-process the result from the server."""
message = ModelInferResponse()
google.protobuf.json_format.Parse(json.dumps(result), message)
infer_result = grpcclient.InferResult(message)
np_res = infer_result.as_numpy("text_output")
generated_text = ""
if np_res is not None:
generated_text = "".join([token.decode() for token in np_res])
return generated_text
def _stream_callback(
self,
result_queue: queue.Queue[Union[Optional[Dict[str, str]], str]],
result: grpcclient.InferResult,
error: str,
stop_words: List[str],
) -> None:
"""Add streamed result to queue."""
if error:
result_queue.put(error)
else:
response_raw: dict = result.get_response(as_json=True)
# TODO: Check the response is a map rather than a string
if "outputs" in response_raw:
# the very last response might have no output, just the final flag
response = self._process_result(response_raw)
if response in stop_words:
result_queue.put(None)
else:
result_queue.put(response)
if response_raw["parameters"]["triton_final_response"]["bool_param"]:
# end of the generation
result_queue.put(None)
def stop_stream(
self, model_name: str, request_id: str, signal: bool = True
) -> None:
"""Close the streaming connection."""
if signal:
self._send_stop_signals(model_name, request_id)
self.client.stop_stream()
class StreamingResponseGenerator(queue.Queue):
"""A Generator that provides the inference results from an LLM."""
def __init__(
self,
client: grpcclient.InferenceServerClient,
request_id: str,
force_batch: bool,
stop_words: Sequence[str],
) -> None:
"""Instantiate the generator class."""
super().__init__()
self.client = client
self.request_id = request_id
self._batch = force_batch
self._stop_words = stop_words
def __iter__(self) -> StreamingResponseGenerator:
"""Return self as a generator."""
return self
def __next__(self) -> str:
"""Return the next retrieved token."""
val = self.get()
if val is None or val in self._stop_words:
self.client.stop_stream(
"tensorrt_llm", self.request_id, signal=not self._batch
)
raise StopIteration()
return val