From 2145636f1d7af49cdecdde139377cc6c37e6765f Mon Sep 17 00:00:00 2001
From: Mikhail Khludnev
Date: Mon, 5 Feb 2024 21:45:06 +0300
Subject: [PATCH] Nvidia trt model name for stop_stream() (#16997)

just removing some legacy leftover.
---
 libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py b/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
index f56eeb9a1d..36e1e6e5ca 100644
--- a/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
+++ b/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
@@ -377,14 +377,14 @@ class StreamingResponseGenerator(queue.Queue):
 
     def __init__(
         self,
-        client: grpcclient.InferenceServerClient,
+        llm: TritonTensorRTLLM,
         request_id: str,
         force_batch: bool,
         stop_words: Sequence[str],
     ) -> None:
         """Instantiate the generator class."""
         super().__init__()
-        self.client = client
+        self.llm = llm
         self.request_id = request_id
         self._batch = force_batch
         self._stop_words = stop_words
@@ -397,8 +397,8 @@ class StreamingResponseGenerator(queue.Queue):
         """Return the next retrieved token."""
         val = self.get()
         if val is None or val in self._stop_words:
-            self.client.stop_stream(
-                "tensorrt_llm", self.request_id, signal=not self._batch
+            self.llm.stop_stream(
+                self.llm.model_name, self.request_id, signal=not self._batch
             )
             raise StopIteration()
         return val
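
For context on why the hard-coded name mattered, here is a minimal usage sketch. It is not part of the patch: the server URL and the model name "ensemble_llm" are hypothetical, and it assumes a Triton server serving a model under that name. Before this change, StreamingResponseGenerator always signalled stop against the literal "tensorrt_llm" model, so deployments using any other model name sent the stop request to the wrong model; with the patch, the generator uses the LLM's configured model_name.

    from langchain_nvidia_trt.llms import TritonTensorRTLLM

    # Hypothetical deployment where the Triton model is NOT named "tensorrt_llm".
    llm = TritonTensorRTLLM(
        server_url="localhost:8001",  # assumed Triton gRPC endpoint
        model_name="ensemble_llm",    # hypothetical non-default model name
    )

    # With this patch, stopping the stream calls
    #   llm.stop_stream(llm.model_name, request_id, ...)
    # so the stop signal reaches "ensemble_llm" rather than the legacy
    # hard-coded "tensorrt_llm".
    for token in llm.stream("Write a haiku about GPUs."):
        print(token, end="", flush=True)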