@@ -377,14 +377,14 @@ class StreamingResponseGenerator(queue.Queue):
 
     def __init__(
         self,
-        client: grpcclient.InferenceServerClient,
+        llm: TritonTensorRTLLM,
         request_id: str,
         force_batch: bool,
         stop_words: Sequence[str],
     ) -> None:
         """Instantiate the generator class."""
         super().__init__()
-        self.client = client
+        self.llm = llm
         self.request_id = request_id
         self._batch = force_batch
         self._stop_words = stop_words
@@ -397,8 +397,8 @@ class StreamingResponseGenerator(queue.Queue):
         """Return the next retrieved token."""
         val = self.get()
         if val is None or val in self._stop_words:
-            self.client.stop_stream(
-                "tensorrt_llm", self.request_id, signal=not self._batch
+            self.llm.stop_stream(
+                self.llm.model_name, self.request_id, signal=not self._batch
             )
             raise StopIteration()
         return val
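For context, a minimal usage sketch of the generator after this change. It assumes `llm` is an already-constructed `TritonTensorRTLLM` exposing the `model_name` attribute and `stop_stream` method referenced in the hunks above; the request id and stop word below are illustrative placeholders, not values taken from this change.

```python
# Minimal sketch (assumptions as noted above): driving StreamingResponseGenerator
# after the client -> llm swap. `llm` is a pre-built TritonTensorRTLLM instance;
# the request id and stop word are placeholders for illustration only.
generator = StreamingResponseGenerator(
    llm=llm,
    request_id="example-request-0",
    force_batch=False,
    stop_words=["</s>"],
)

# When a stop word (or the None sentinel) arrives, __next__ calls
# llm.stop_stream(llm.model_name, request_id, signal=not force_batch) and raises
# StopIteration, so iteration ends cleanly (assuming the class's existing
# __iter__ returns self, as in the surrounding file).
for token in generator:
    print(token, end="", flush=True)
```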