mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
3a2eb6e12b
Added noqa for existing prints. Can slowly remove / will prevent more being intro'd
668 lines
21 KiB
Python
668 lines
21 KiB
Python
"""A common module for NVIDIA Riva Runnables."""
|
|
|
|
import asyncio
|
|
import logging
|
|
import pathlib
|
|
import queue
|
|
import tempfile
|
|
import threading
|
|
import wave
|
|
from enum import Enum
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
AsyncGenerator,
|
|
AsyncIterator,
|
|
Dict,
|
|
Generator,
|
|
Iterator,
|
|
List,
|
|
Optional,
|
|
Tuple,
|
|
Union,
|
|
cast,
|
|
)
|
|
|
|
from langchain_core.messages import AnyMessage, BaseMessage
|
|
from langchain_core.prompt_values import PromptValue
|
|
from langchain_core.pydantic_v1 import (
|
|
AnyHttpUrl,
|
|
BaseModel,
|
|
Field,
|
|
parse_obj_as,
|
|
root_validator,
|
|
validator,
|
|
)
|
|
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
|
|
|
if TYPE_CHECKING:
|
|
import riva.client
|
|
import riva.client.proto.riva_asr_pb2 as rasr
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
_QUEUE_GET_TIMEOUT = 0.5
|
|
_MAX_TEXT_LENGTH = 400
|
|
_SENTENCE_TERMINATORS = ("\n", ".", "!", "?", "¡", "¿")
|
|
|
|
|
|
# COMMON utilities used by all Riva Runnables
|
|
def _import_riva_client() -> "riva.client":
|
|
"""Import the riva client and raise an error on failure."""
|
|
try:
|
|
# pylint: disable-next=import-outside-toplevel # this client library is optional
|
|
import riva.client
|
|
except ImportError as err:
|
|
raise ImportError(
|
|
"Could not import the NVIDIA Riva client library. "
|
|
"Please install it with `pip install nvidia-riva-client`."
|
|
) from err
|
|
return riva.client
|
|
|
|
|
|
class SentinelT: # pylint: disable=too-few-public-methods
|
|
"""An empty Sentinel type."""
|
|
|
|
|
|
HANGUP = SentinelT()
|
|
_TRANSFORM_END = SentinelT()
|
|
|
|
|
|
class RivaAudioEncoding(str, Enum):
|
|
"""An enum of the possible choices for Riva audio encoding.
|
|
|
|
The list of types exposed by the Riva GRPC Protobuf files can be found
|
|
with the following commands:
|
|
```python
|
|
import riva.client
|
|
print(riva.client.AudioEncoding.keys()) # noqa: T201
|
|
```
|
|
"""
|
|
|
|
ALAW = "ALAW"
|
|
ENCODING_UNSPECIFIED = "ENCODING_UNSPECIFIED"
|
|
FLAC = "FLAC"
|
|
LINEAR_PCM = "LINEAR_PCM"
|
|
MULAW = "MULAW"
|
|
OGGOPUS = "OGGOPUS"
|
|
|
|
@classmethod
|
|
def from_wave_format_code(cls, format_code: int) -> "RivaAudioEncoding":
|
|
"""Return the audio encoding specified by the format code in the wave file.
|
|
|
|
ref: https://mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
|
|
"""
|
|
try:
|
|
return {1: cls.LINEAR_PCM, 6: cls.ALAW, 7: cls.MULAW}[format_code]
|
|
except KeyError as err:
|
|
raise NotImplementedError(
|
|
"The following wave file format code is "
|
|
f"not supported by Riva: {format_code}"
|
|
) from err
|
|
|
|
@property
|
|
def riva_pb2(self) -> "riva.client.AudioEncoding":
|
|
"""Returns the Riva API object for the encoding."""
|
|
riva_client = _import_riva_client()
|
|
return getattr(riva_client.AudioEncoding, self)
|
|
|
|
|
|
class RivaAuthMixin(BaseModel):
|
|
"""Configuration for the authentication to a Riva service connection."""
|
|
|
|
url: Union[AnyHttpUrl, str] = Field(
|
|
AnyHttpUrl("http://localhost:50051", scheme="http"),
|
|
description="The full URL where the Riva service can be found.",
|
|
examples=["http://localhost:50051", "https://user@pass:riva.example.com"],
|
|
)
|
|
ssl_cert: Optional[str] = Field(
|
|
None,
|
|
description="A full path to the file where Riva's public ssl key can be read.",
|
|
)
|
|
|
|
@property
|
|
def auth(self) -> "riva.client.Auth":
|
|
"""Return a riva client auth object."""
|
|
riva_client = _import_riva_client()
|
|
url = cast(AnyHttpUrl, self.url)
|
|
use_ssl = url.scheme == "https" # pylint: disable=no-member # false positive
|
|
url_no_scheme = str(self.url).split("/")[2]
|
|
return riva_client.Auth(
|
|
ssl_cert=self.ssl_cert, use_ssl=use_ssl, uri=url_no_scheme
|
|
)
|
|
|
|
@validator("url", pre=True, allow_reuse=True)
|
|
@classmethod
|
|
def _validate_url(cls, val: Any) -> AnyHttpUrl:
|
|
"""Do some initial conversations for the URL before checking."""
|
|
if isinstance(val, str):
|
|
return cast(AnyHttpUrl, parse_obj_as(AnyHttpUrl, val))
|
|
return cast(AnyHttpUrl, val)
|
|
|
|
|
|
class RivaCommonConfigMixin(BaseModel):
|
|
"""A collection of common Riva settings."""
|
|
|
|
encoding: RivaAudioEncoding = Field(
|
|
default=RivaAudioEncoding.LINEAR_PCM,
|
|
description="The encoding on the audio stream.",
|
|
)
|
|
sample_rate_hertz: int = Field(
|
|
default=8000, description="The sample rate frequency of audio stream."
|
|
)
|
|
language_code: str = Field(
|
|
default="en-US",
|
|
description=(
|
|
"The [BCP-47 language code]"
|
|
"(https://www.rfc-editor.org/rfc/bcp/bcp47.txt) for "
|
|
"the target language."
|
|
),
|
|
)
|
|
|
|
|
|
class _Event:
|
|
"""A combined event that is threadsafe and async safe."""
|
|
|
|
_event: threading.Event
|
|
_aevent: asyncio.Event
|
|
|
|
def __init__(self) -> None:
|
|
"""Initialize the event."""
|
|
self._event = threading.Event()
|
|
self._aevent = asyncio.Event()
|
|
|
|
def set(self) -> None:
|
|
"""Set the event."""
|
|
self._event.set()
|
|
self._aevent.set()
|
|
|
|
def clear(self) -> None:
|
|
"""Set the event."""
|
|
self._event.clear()
|
|
self._aevent.clear()
|
|
|
|
def is_set(self) -> bool:
|
|
"""Indicate if the event is set."""
|
|
return self._event.is_set()
|
|
|
|
def wait(self) -> None:
|
|
"""Wait for the event to be set."""
|
|
self._event.wait()
|
|
|
|
async def async_wait(self) -> None:
|
|
"""Async wait for the event to be set."""
|
|
await self._aevent.wait()
|
|
|
|
|
|
def _mk_wave_file(
|
|
output_directory: Optional[str], sample_rate: float
|
|
) -> Tuple[Optional[str], Optional[wave.Wave_write]]:
|
|
"""Create a new wave file and return the wave write object and filename."""
|
|
if output_directory:
|
|
with tempfile.NamedTemporaryFile(
|
|
mode="bx", suffix=".wav", delete=False, dir=output_directory
|
|
) as f:
|
|
wav_file_name = f.name
|
|
wav_file = wave.open(wav_file_name, "wb")
|
|
wav_file.setnchannels(1)
|
|
wav_file.setsampwidth(2)
|
|
wav_file.setframerate(sample_rate)
|
|
return (wav_file_name, wav_file)
|
|
return (None, None)
|
|
|
|
|
|
def _coerce_string(val: "TTSInputType") -> str:
|
|
"""Attempt to coerce the input value to a string.
|
|
|
|
This is particularly useful for converting LangChain message to strings.
|
|
"""
|
|
if isinstance(val, PromptValue):
|
|
return val.to_string()
|
|
if isinstance(val, BaseMessage):
|
|
return str(val.content)
|
|
return str(val)
|
|
|
|
|
|
def _process_chunks(inputs: Iterator["TTSInputType"]) -> Generator[str, None, None]:
|
|
"""Filter the input chunks are return strings ready for TTS."""
|
|
buffer = ""
|
|
for chunk in inputs:
|
|
chunk = _coerce_string(chunk)
|
|
|
|
# return the buffer if an end of sentence character is detected
|
|
for terminator in _SENTENCE_TERMINATORS:
|
|
while terminator in chunk:
|
|
last_sentence, chunk = chunk.split(terminator, 1)
|
|
yield buffer + last_sentence + terminator
|
|
buffer = ""
|
|
|
|
buffer += chunk
|
|
|
|
# return the buffer if is too long
|
|
if len(buffer) > _MAX_TEXT_LENGTH:
|
|
for idx in range(0, len(buffer), _MAX_TEXT_LENGTH):
|
|
yield buffer[idx : idx + 5]
|
|
buffer = ""
|
|
|
|
# return remaining buffer
|
|
if buffer:
|
|
yield buffer
|
|
|
|
|
|
# Riva AudioStream Type
|
|
StreamInputType = Union[bytes, SentinelT]
|
|
StreamOutputType = str
|
|
|
|
|
|
class AudioStream:
|
|
"""A message containing streaming audio."""
|
|
|
|
_put_lock: threading.Lock
|
|
_queue: queue.Queue
|
|
output: queue.Queue
|
|
hangup: _Event
|
|
user_talking: _Event
|
|
user_quiet: _Event
|
|
_worker: Optional[threading.Thread]
|
|
|
|
def __init__(self, maxsize: int = 0) -> None:
|
|
"""Initialize the queue."""
|
|
self._put_lock = threading.Lock()
|
|
self._queue = queue.Queue(maxsize=maxsize)
|
|
self.output = queue.Queue()
|
|
self.hangup = _Event()
|
|
self.user_quiet = _Event()
|
|
self.user_talking = _Event()
|
|
self._worker = None
|
|
|
|
def __iter__(self) -> Generator[bytes, None, None]:
|
|
"""Return an error."""
|
|
while True:
|
|
# get next item
|
|
try:
|
|
next_val = self._queue.get(True, _QUEUE_GET_TIMEOUT)
|
|
except queue.Empty:
|
|
continue
|
|
|
|
# hangup when requested
|
|
if next_val == HANGUP:
|
|
break
|
|
|
|
# yield next item
|
|
yield next_val
|
|
self._queue.task_done()
|
|
|
|
async def __aiter__(self) -> AsyncIterator[StreamInputType]:
|
|
"""Iterate through all items in the queue until HANGUP."""
|
|
while True:
|
|
# get next item
|
|
try:
|
|
next_val = await asyncio.get_event_loop().run_in_executor(
|
|
None, self._queue.get, True, _QUEUE_GET_TIMEOUT
|
|
)
|
|
except queue.Empty:
|
|
continue
|
|
|
|
# hangup when requested
|
|
if next_val == HANGUP:
|
|
break
|
|
|
|
# yield next item
|
|
yield next_val
|
|
self._queue.task_done()
|
|
|
|
@property
|
|
def hungup(self) -> bool:
|
|
"""Indicate if the audio stream has hungup."""
|
|
return self.hangup.is_set()
|
|
|
|
@property
|
|
def empty(self) -> bool:
|
|
"""Indicate in the input stream buffer is empty."""
|
|
return self._queue.empty()
|
|
|
|
@property
|
|
def complete(self) -> bool:
|
|
"""Indicate if the audio stream has hungup and been processed."""
|
|
input_done = self.hungup and self.empty
|
|
output_done = (
|
|
self._worker is not None
|
|
and not self._worker.is_alive()
|
|
and self.output.empty()
|
|
)
|
|
return input_done and output_done
|
|
|
|
@property
|
|
def running(self) -> bool:
|
|
"""Indicate if the ASR stream is running."""
|
|
if self._worker:
|
|
return self._worker.is_alive()
|
|
return False
|
|
|
|
def put(self, item: StreamInputType, timeout: Optional[int] = None) -> None:
|
|
"""Put a new item into the queue."""
|
|
with self._put_lock:
|
|
if self.hungup:
|
|
raise RuntimeError(
|
|
"The audio stream has already been hungup. Cannot put more data."
|
|
)
|
|
if item is HANGUP:
|
|
self.hangup.set()
|
|
self._queue.put(item, timeout=timeout)
|
|
|
|
async def aput(self, item: StreamInputType, timeout: Optional[int] = None) -> None:
|
|
"""Async put a new item into the queue."""
|
|
loop = asyncio.get_event_loop()
|
|
await asyncio.wait_for(loop.run_in_executor(None, self.put, item), timeout)
|
|
|
|
def close(self, timeout: Optional[int] = None) -> None:
|
|
"""Send the hangup signal."""
|
|
self.put(HANGUP, timeout)
|
|
|
|
async def aclose(self, timeout: Optional[int] = None) -> None:
|
|
"""Async send the hangup signal."""
|
|
await self.aput(HANGUP, timeout)
|
|
|
|
def register(self, responses: Iterator["rasr.StreamingRecognizeResponse"]) -> None:
|
|
"""Drain the responses from the provided iterator and put them into a queue."""
|
|
if self.running:
|
|
raise RuntimeError("An ASR instance has already been registered.")
|
|
|
|
has_started = threading.Barrier(2, timeout=5)
|
|
|
|
def worker() -> None:
|
|
"""Consume the ASR Generator."""
|
|
has_started.wait()
|
|
for response in responses:
|
|
if not response.results:
|
|
continue
|
|
|
|
for result in response.results:
|
|
if not result.alternatives:
|
|
continue
|
|
|
|
if result.is_final:
|
|
self.user_talking.clear()
|
|
self.user_quiet.set()
|
|
transcript = cast(str, result.alternatives[0].transcript)
|
|
self.output.put(transcript)
|
|
|
|
elif not self.user_talking.is_set():
|
|
self.user_talking.set()
|
|
self.user_quiet.clear()
|
|
|
|
self._worker = threading.Thread(target=worker)
|
|
self._worker.daemon = True
|
|
self._worker.start()
|
|
has_started.wait()
|
|
|
|
|
|
# RivaASR Runnable
|
|
ASRInputType = AudioStream
|
|
ASROutputType = str
|
|
|
|
|
|
class RivaASR(
|
|
RivaAuthMixin,
|
|
RivaCommonConfigMixin,
|
|
RunnableSerializable[ASRInputType, ASROutputType],
|
|
):
|
|
"""A runnable that performs Automatic Speech Recognition (ASR) using NVIDIA Riva."""
|
|
|
|
name: str = "nvidia_riva_asr"
|
|
description: str = (
|
|
"A Runnable for converting audio bytes to a string."
|
|
"This is useful for feeding an audio stream into a chain and"
|
|
"preprocessing that audio to create an LLM prompt."
|
|
)
|
|
|
|
# riva options
|
|
audio_channel_count: int = Field(
|
|
1, description="The number of audio channels in the input audio stream."
|
|
)
|
|
profanity_filter: bool = Field(
|
|
True,
|
|
description=(
|
|
"Controls whether or not Riva should attempt to filter "
|
|
"profanity out of the transcribed text."
|
|
),
|
|
)
|
|
enable_automatic_punctuation: bool = Field(
|
|
True,
|
|
description=(
|
|
"Controls whether Riva should attempt to correct "
|
|
"senetence puncuation in the transcribed text."
|
|
),
|
|
)
|
|
|
|
@root_validator(pre=True)
|
|
@classmethod
|
|
def _validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Validate the Python environment and input arguments."""
|
|
_ = _import_riva_client()
|
|
return values
|
|
|
|
@property
|
|
def config(self) -> "riva.client.StreamingRecognitionConfig":
|
|
"""Create and return the riva config object."""
|
|
riva_client = _import_riva_client()
|
|
return riva_client.StreamingRecognitionConfig(
|
|
interim_results=True,
|
|
config=riva_client.RecognitionConfig(
|
|
encoding=self.encoding,
|
|
sample_rate_hertz=self.sample_rate_hertz,
|
|
audio_channel_count=self.audio_channel_count,
|
|
max_alternatives=1,
|
|
profanity_filter=self.profanity_filter,
|
|
enable_automatic_punctuation=self.enable_automatic_punctuation,
|
|
language_code=self.language_code,
|
|
),
|
|
)
|
|
|
|
def _get_service(self) -> "riva.client.ASRService":
|
|
"""Connect to the riva service and return the a client object."""
|
|
riva_client = _import_riva_client()
|
|
try:
|
|
return riva_client.ASRService(self.auth)
|
|
except Exception as err:
|
|
raise ValueError(
|
|
"Error raised while connecting to the Riva ASR server."
|
|
) from err
|
|
|
|
def invoke(
|
|
self,
|
|
input: ASRInputType,
|
|
_: Optional[RunnableConfig] = None,
|
|
) -> ASROutputType:
|
|
"""Transcribe the audio bytes into a string with Riva."""
|
|
# create an output text generator with Riva
|
|
if not input.running:
|
|
service = self._get_service()
|
|
responses = service.streaming_response_generator(
|
|
audio_chunks=input,
|
|
streaming_config=self.config,
|
|
)
|
|
input.register(responses)
|
|
|
|
# return the first valid result
|
|
full_response: List[str] = []
|
|
while not input.complete:
|
|
with input.output.not_empty:
|
|
ready = input.output.not_empty.wait(0.1)
|
|
|
|
if ready:
|
|
while not input.output.empty():
|
|
try:
|
|
full_response += [input.output.get_nowait()]
|
|
except queue.Empty:
|
|
continue
|
|
input.output.task_done()
|
|
_LOGGER.debug("Riva ASR returning: %s", repr(full_response))
|
|
return " ".join(full_response).strip()
|
|
|
|
return ""
|
|
|
|
|
|
# RivaTTS Runnable
|
|
# pylint: disable-next=invalid-name
|
|
TTSInputType = Union[str, AnyMessage, PromptValue]
|
|
TTSOutputType = bytes
|
|
|
|
|
|
class RivaTTS(
|
|
RivaAuthMixin,
|
|
RivaCommonConfigMixin,
|
|
RunnableSerializable[TTSInputType, TTSOutputType],
|
|
):
|
|
"""A runnable that performs Text-to-Speech (TTS) with NVIDIA Riva."""
|
|
|
|
name: str = "nvidia_riva_tts"
|
|
description: str = (
|
|
"A tool for converting text to speech."
|
|
"This is useful for converting LLM output into audio bytes."
|
|
)
|
|
|
|
# riva options
|
|
voice_name: str = Field(
|
|
"English-US.Female-1",
|
|
description=(
|
|
"The voice model in Riva to use for speech. "
|
|
"Pre-trained models are documented in "
|
|
"[the Riva documentation]"
|
|
"(https://docs.nvidia.com/deeplearning/riva/user-guide/docs/tts/tts-overview.html)."
|
|
),
|
|
)
|
|
output_directory: Optional[str] = Field(
|
|
None,
|
|
description=(
|
|
"The directory where all audio files should be saved. "
|
|
"A null value indicates that wave files should not be saved. "
|
|
"This is useful for debugging purposes."
|
|
),
|
|
)
|
|
|
|
@root_validator(pre=True)
|
|
@classmethod
|
|
def _validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Validate the Python environment and input arguments."""
|
|
_ = _import_riva_client()
|
|
return values
|
|
|
|
@validator("output_directory")
|
|
@classmethod
|
|
def _output_directory_validator(cls, v: str) -> str:
|
|
if v:
|
|
dirpath = pathlib.Path(v)
|
|
dirpath.mkdir(parents=True, exist_ok=True)
|
|
return str(dirpath.absolute())
|
|
return v
|
|
|
|
def _get_service(self) -> "riva.client.SpeechSynthesisService":
|
|
"""Connect to the riva service and return the a client object."""
|
|
riva_client = _import_riva_client()
|
|
try:
|
|
return riva_client.SpeechSynthesisService(self.auth)
|
|
except Exception as err:
|
|
raise ValueError(
|
|
"Error raised while connecting to the Riva TTS server."
|
|
) from err
|
|
|
|
def invoke(
|
|
self, input: TTSInputType, _: Union[RunnableConfig, None] = None
|
|
) -> TTSOutputType:
|
|
"""Perform TTS by taking a string and outputting the entire audio file."""
|
|
return b"".join(self.transform(iter([input])))
|
|
|
|
def transform(
|
|
self,
|
|
input: Iterator[TTSInputType],
|
|
config: Optional[RunnableConfig] = None,
|
|
**kwargs: Optional[Any],
|
|
) -> Iterator[TTSOutputType]:
|
|
"""Perform TTS by taking a stream of characters and streaming output bytes."""
|
|
service = self._get_service()
|
|
|
|
# create an output wave file
|
|
wav_file_name, wav_file = _mk_wave_file(
|
|
self.output_directory, self.sample_rate_hertz
|
|
)
|
|
|
|
# split the input text and perform tts
|
|
for chunk in _process_chunks(input):
|
|
_LOGGER.debug("Riva TTS chunk: %s", chunk)
|
|
|
|
# start riva tts streaming
|
|
responses = service.synthesize_online(
|
|
text=chunk,
|
|
voice_name=self.voice_name,
|
|
language_code=self.language_code,
|
|
encoding=self.encoding.riva_pb2,
|
|
sample_rate_hz=self.sample_rate_hertz,
|
|
)
|
|
|
|
# stream audio bytes out
|
|
for resp in responses:
|
|
audio = cast(bytes, resp.audio)
|
|
if wav_file:
|
|
wav_file.writeframesraw(audio)
|
|
yield audio
|
|
|
|
# close the wave file when we are done
|
|
if wav_file:
|
|
wav_file.close()
|
|
_LOGGER.debug("Riva TTS wrote file: %s", wav_file_name)
|
|
|
|
async def atransform(
|
|
self,
|
|
input: AsyncIterator[TTSInputType],
|
|
config: Optional[RunnableConfig] = None,
|
|
**kwargs: Optional[Any],
|
|
) -> AsyncGenerator[TTSOutputType, None]:
|
|
"""Intercept async transforms and route them to the synchronous transform."""
|
|
loop = asyncio.get_running_loop()
|
|
input_queue: queue.Queue = queue.Queue()
|
|
out_queue: asyncio.Queue = asyncio.Queue()
|
|
|
|
async def _producer() -> None:
|
|
"""Produce input into the input queue."""
|
|
async for val in input:
|
|
input_queue.put_nowait(val)
|
|
input_queue.put_nowait(_TRANSFORM_END)
|
|
|
|
def _input_iterator() -> Iterator[TTSInputType]:
|
|
"""Iterate over the input_queue."""
|
|
while True:
|
|
try:
|
|
val = input_queue.get(timeout=0.5)
|
|
except queue.Empty:
|
|
continue
|
|
if val == _TRANSFORM_END:
|
|
break
|
|
yield val
|
|
|
|
def _consumer() -> None:
|
|
"""Consume the input with transform."""
|
|
for val in self.transform(_input_iterator()):
|
|
out_queue.put_nowait(val)
|
|
out_queue.put_nowait(_TRANSFORM_END)
|
|
|
|
async def _consumer_coro() -> None:
|
|
"""Coroutine that wraps the consumer."""
|
|
await loop.run_in_executor(None, _consumer)
|
|
|
|
producer = loop.create_task(_producer())
|
|
consumer = loop.create_task(_consumer_coro())
|
|
|
|
while True:
|
|
try:
|
|
val = await asyncio.wait_for(out_queue.get(), 0.5)
|
|
except asyncio.exceptions.TimeoutError:
|
|
continue
|
|
out_queue.task_done()
|
|
|
|
if val is _TRANSFORM_END:
|
|
break
|
|
yield val
|
|
|
|
await producer
|
|
await consumer
|