diff --git a/docs/docs/integrations/tools/nvidia_riva.ipynb b/docs/docs/integrations/tools/nvidia_riva.ipynb
new file mode 100644
index 0000000000..a4cf2f299a
--- /dev/null
+++ b/docs/docs/integrations/tools/nvidia_riva.ipynb
@@ -0,0 +1,706 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "cc6caafa",
+ "metadata": {
+ "id": "cc6caafa"
+ },
+ "source": [
+ "# NVIDIA Riva: ASR and TTS\n",
+ "\n",
+ "## NVIDIA Riva\n",
+ "[NVIDIA Riva](https://www.nvidia.com/en-us/ai-data-science/products/riva/) is a GPU-accelerated multilingual speech and translation AI software development kit for building fully customizable, real-time conversational AI pipelines—including automatic speech recognition (ASR), text-to-speech (TTS), and neural machine translation (NMT) applications—that can be deployed in clouds, in data centers, at the edge, or on embedded devices.\n",
+ "\n",
+ "The Riva Speech API server exposes a simple API for performing speech recognition, speech synthesis, and a variety of natural language processing inferences and is integrated into LangChain for ASR and TTS. See instructions on how to [setup a Riva Speech API](#3-setup) server below. \n",
+ "\n",
+ "## Integrating NVIDIA Riva to LangChain Chains\n",
+ "The `NVIDIARivaASR`, `NVIDIARivaTTS` utility runnables are LangChain runnables that integrate [NVIDIA Riva](https://www.nvidia.com/en-us/ai-data-science/products/riva/) into LCEL chains for Automatic Speech Recognition (ASR) and Text To Speech (TTS).\n",
+ "\n",
+ "This example goes over how to use these LangChain runnables to:\n",
+ "1. Accept streamed audio,\n",
+ "2. convert the audio to text, \n",
+ "3. send the text to an LLM, \n",
+ "4. stream a textual LLM response, and\n",
+ "5. convert the response to streamed human-sounding audio. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b603439f",
+ "metadata": {},
+ "source": [
+ "## 1. NVIDIA Riva Runnables\n",
+ "There are 2 Riva Runnables:\n",
+ "\n",
+ "a. **RivaASR**: Converts audio bytes into text for an LLM using NVIDIA Riva. \n",
+ "\n",
+ "b. **RivaTTS**: Converts text into audio bytes using NVIDIA Riva.\n",
+ "\n",
+ "### a. RivaASR\n",
+ "The [**RivaASR**](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/utilities/nvidia_riva.py#L404) runnable converts audio bytes into a string for an LLM using NVIDIA Riva. \n",
+ "\n",
+ "It's useful for sending an audio stream (a message containing streaming audio) into a chain and preprocessing that audio by converting it to a string to create an LLM prompt. \n",
+ "\n",
+ "```\n",
+ "ASRInputType = AudioStream # the AudioStream type is a custom type for a message queue containing streaming audio\n",
+ "ASROutputType = str\n",
+ "\n",
+ "class RivaASR(\n",
+ " RivaAuthMixin,\n",
+ " RivaCommonConfigMixin,\n",
+ " RunnableSerializable[ASRInputType, ASROutputType],\n",
+ "):\n",
+ " \"\"\"A runnable that performs Automatic Speech Recognition (ASR) using NVIDIA Riva.\"\"\"\n",
+ "\n",
+ " name: str = \"nvidia_riva_asr\"\n",
+ " description: str = (\n",
+ " \"A Runnable for converting audio bytes to a string.\"\n",
+ " \"This is useful for feeding an audio stream into a chain and\"\n",
+ " \"preprocessing that audio to create an LLM prompt.\"\n",
+ " )\n",
+ "\n",
+ " # riva options\n",
+ " audio_channel_count: int = Field(\n",
+ " 1, description=\"The number of audio channels in the input audio stream.\"\n",
+ " )\n",
+ " profanity_filter: bool = Field(\n",
+ " True,\n",
+ " description=(\n",
+ " \"Controls whether or not Riva should attempt to filter \"\n",
+ " \"profanity out of the transcribed text.\"\n",
+ " ),\n",
+ " )\n",
+ " enable_automatic_punctuation: bool = Field(\n",
+ " True,\n",
+ " description=(\n",
+ " \"Controls whether Riva should attempt to correct \"\n",
+ " \"senetence puncuation in the transcribed text.\"\n",
+ " ),\n",
+ " )\n",
+ "```\n",
+ "\n",
+ "When this runnable is called on an input, it takes an input audio stream that acts as a queue and concatenates transcription as chunks are returned.After a response is fully generated, a string is returned. \n",
+ "* Note that since the LLM requires a full query the ASR is concatenated and not streamed in token-by-token.\n",
+ "\n",
+ "\n",
+ "### b. RivaTTS\n",
+ "The [**RivaTTS**](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/utilities/nvidia_riva.py#L511) runnable converts text output to audio bytes. \n",
+ "\n",
+ "It's useful for processing the streamed textual response from an LLM by converting the text to audio bytes. These audio bytes sound like a natural human voice to be played back to the user. \n",
+ "\n",
+ "```\n",
+ "TTSInputType = Union[str, AnyMessage, PromptValue]\n",
+ "TTSOutputType = byte\n",
+ "\n",
+ "class RivaTTS(\n",
+ " RivaAuthMixin,\n",
+ " RivaCommonConfigMixin,\n",
+ " RunnableSerializable[TTSInputType, TTSOutputType],\n",
+ "):\n",
+ " \"\"\"A runnable that performs Text-to-Speech (TTS) with NVIDIA Riva.\"\"\"\n",
+ "\n",
+ " name: str = \"nvidia_riva_tts\"\n",
+ " description: str = (\n",
+ " \"A tool for converting text to speech.\"\n",
+ " \"This is useful for converting LLM output into audio bytes.\"\n",
+ " )\n",
+ "\n",
+ " # riva options\n",
+ " voice_name: str = Field(\n",
+ " \"English-US.Female-1\",\n",
+ " description=(\n",
+ " \"The voice model in Riva to use for speech. \"\n",
+ " \"Pre-trained models are documented in \"\n",
+ " \"[the Riva documentation]\"\n",
+ " \"(https://docs.nvidia.com/deeplearning/riva/user-guide/docs/tts/tts-overview.html).\"\n",
+ " ),\n",
+ " )\n",
+ " output_directory: Optional[str] = Field(\n",
+ " None,\n",
+ " description=(\n",
+ " \"The directory where all audio files should be saved. \"\n",
+ " \"A null value indicates that wave files should not be saved. \"\n",
+ " \"This is useful for debugging purposes.\"\n",
+ " ),\n",
+ "```\n",
+ "\n",
+ "When this runnable is called on an input, it takes iterable text chunks and streams them into output audio bytes that are either written to a `.wav` file or played out loud."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f2be90a9",
+ "metadata": {},
+ "source": [
+ "## 2. Installation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1ef87a40",
+ "metadata": {},
+ "source": [
+ "The NVIDIA Riva client library must be installed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "70410821",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Note: you may need to restart the kernel to use updated packages.\n"
+ ]
+ }
+ ],
+ "source": [
+ "%pip install --upgrade --quiet nvidia-riva-client"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ccff689e",
+ "metadata": {
+ "id": "ccff689e"
+ },
+ "source": [
+ "## 3. Setup\n",
+ "\n",
+ "**To get started with NVIDIA Riva:**\n",
+ "\n",
+ "1. Follow the Riva Quick Start setup instructions for [Local Deployment Using Quick Start Scripts](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide.html#local-deployment-using-quick-start-scripts)."
+ ]
+ },
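+ {
+ "cell_type": "markdown",
+ "id": "riva-quickstart-sketch",
+ "metadata": {},
+ "source": [
+ "In rough outline, the quick start flow from the linked guide looks like the following (the resource version below is illustrative; use the latest version listed in the guide):\n",
+ "\n",
+ "```bash\n",
+ "# download the Riva quick start scripts from NGC (version is illustrative)\n",
+ "ngc registry resource download-version nvidia/riva/riva_quickstart:2.14.0\n",
+ "cd riva_quickstart_v2.14.0\n",
+ "# initialize (downloads models and containers), then start the speech server\n",
+ "bash riva_init.sh\n",
+ "bash riva_start.sh\n",
+ "```"
+ ]
+ },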
+ {
+ "cell_type": "markdown",
+ "id": "57b6741b",
+ "metadata": {},
+ "source": [
+ "## 4. Import and Inspect Runnables\n",
+ "Import the RivaASR and RivaTTS runnables and inspect their schemas to understand their fields. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "2d6fa641",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "\n",
+ "from langchain_community.utilities.nvidia_riva import (\n",
+ " RivaASR,\n",
+ " RivaTTS,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0e6dd656",
+ "metadata": {},
+ "source": [
+ "Let's view the schemas."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "69460762",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{\n",
+ " \"title\": \"RivaASR\",\n",
+ " \"description\": \"A runnable that performs Automatic Speech Recognition (ASR) using NVIDIA Riva.\",\n",
+ " \"type\": \"object\",\n",
+ " \"properties\": {\n",
+ " \"name\": {\n",
+ " \"title\": \"Name\",\n",
+ " \"default\": \"nvidia_riva_asr\",\n",
+ " \"type\": \"string\"\n",
+ " },\n",
+ " \"encoding\": {\n",
+ " \"description\": \"The encoding on the audio stream.\",\n",
+ " \"default\": \"LINEAR_PCM\",\n",
+ " \"allOf\": [\n",
+ " {\n",
+ " \"$ref\": \"#/definitions/RivaAudioEncoding\"\n",
+ " }\n",
+ " ]\n",
+ " },\n",
+ " \"sample_rate_hertz\": {\n",
+ " \"title\": \"Sample Rate Hertz\",\n",
+ " \"description\": \"The sample rate frequency of audio stream.\",\n",
+ " \"default\": 8000,\n",
+ " \"type\": \"integer\"\n",
+ " },\n",
+ " \"language_code\": {\n",
+ " \"title\": \"Language Code\",\n",
+ " \"description\": \"The [BCP-47 language code](https://www.rfc-editor.org/rfc/bcp/bcp47.txt) for the target language.\",\n",
+ " \"default\": \"en-US\",\n",
+ " \"type\": \"string\"\n",
+ " },\n",
+ " \"url\": {\n",
+ " \"title\": \"Url\",\n",
+ " \"description\": \"The full URL where the Riva service can be found.\",\n",
+ " \"default\": \"http://localhost:50051\",\n",
+ " \"examples\": [\n",
+ " \"http://localhost:50051\",\n",
+ " \"https://user@pass:riva.example.com\"\n",
+ " ],\n",
+ " \"anyOf\": [\n",
+ " {\n",
+ " \"type\": \"string\",\n",
+ " \"minLength\": 1,\n",
+ " \"maxLength\": 65536,\n",
+ " \"format\": \"uri\"\n",
+ " },\n",
+ " {\n",
+ " \"type\": \"string\"\n",
+ " }\n",
+ " ]\n",
+ " },\n",
+ " \"ssl_cert\": {\n",
+ " \"title\": \"Ssl Cert\",\n",
+ " \"description\": \"A full path to the file where Riva's public ssl key can be read.\",\n",
+ " \"type\": \"string\"\n",
+ " },\n",
+ " \"description\": {\n",
+ " \"title\": \"Description\",\n",
+ " \"default\": \"A Runnable for converting audio bytes to a string.This is useful for feeding an audio stream into a chain andpreprocessing that audio to create an LLM prompt.\",\n",
+ " \"type\": \"string\"\n",
+ " },\n",
+ " \"audio_channel_count\": {\n",
+ " \"title\": \"Audio Channel Count\",\n",
+ " \"description\": \"The number of audio channels in the input audio stream.\",\n",
+ " \"default\": 1,\n",
+ " \"type\": \"integer\"\n",
+ " },\n",
+ " \"profanity_filter\": {\n",
+ " \"title\": \"Profanity Filter\",\n",
+ " \"description\": \"Controls whether or not Riva should attempt to filter profanity out of the transcribed text.\",\n",
+ " \"default\": true,\n",
+ " \"type\": \"boolean\"\n",
+ " },\n",
+ " \"enable_automatic_punctuation\": {\n",
+ " \"title\": \"Enable Automatic Punctuation\",\n",
+ " \"description\": \"Controls whether Riva should attempt to correct senetence puncuation in the transcribed text.\",\n",
+ " \"default\": true,\n",
+ " \"type\": \"boolean\"\n",
+ " }\n",
+ " },\n",
+ " \"definitions\": {\n",
+ " \"RivaAudioEncoding\": {\n",
+ " \"title\": \"RivaAudioEncoding\",\n",
+ " \"description\": \"An enum of the possible choices for Riva audio encoding.\\n\\nThe list of types exposed by the Riva GRPC Protobuf files can be found\\nwith the following commands:\\n```python\\nimport riva.client\\nprint(riva.client.AudioEncoding.keys()) # noqa: T201\\n```\",\n",
+ " \"enum\": [\n",
+ " \"ALAW\",\n",
+ " \"ENCODING_UNSPECIFIED\",\n",
+ " \"FLAC\",\n",
+ " \"LINEAR_PCM\",\n",
+ " \"MULAW\",\n",
+ " \"OGGOPUS\"\n",
+ " ],\n",
+ " \"type\": \"string\"\n",
+ " }\n",
+ " }\n",
+ "}\n",
+ "{\n",
+ " \"title\": \"RivaTTS\",\n",
+ " \"description\": \"A runnable that performs Text-to-Speech (TTS) with NVIDIA Riva.\",\n",
+ " \"type\": \"object\",\n",
+ " \"properties\": {\n",
+ " \"name\": {\n",
+ " \"title\": \"Name\",\n",
+ " \"default\": \"nvidia_riva_tts\",\n",
+ " \"type\": \"string\"\n",
+ " },\n",
+ " \"encoding\": {\n",
+ " \"description\": \"The encoding on the audio stream.\",\n",
+ " \"default\": \"LINEAR_PCM\",\n",
+ " \"allOf\": [\n",
+ " {\n",
+ " \"$ref\": \"#/definitions/RivaAudioEncoding\"\n",
+ " }\n",
+ " ]\n",
+ " },\n",
+ " \"sample_rate_hertz\": {\n",
+ " \"title\": \"Sample Rate Hertz\",\n",
+ " \"description\": \"The sample rate frequency of audio stream.\",\n",
+ " \"default\": 8000,\n",
+ " \"type\": \"integer\"\n",
+ " },\n",
+ " \"language_code\": {\n",
+ " \"title\": \"Language Code\",\n",
+ " \"description\": \"The [BCP-47 language code](https://www.rfc-editor.org/rfc/bcp/bcp47.txt) for the target language.\",\n",
+ " \"default\": \"en-US\",\n",
+ " \"type\": \"string\"\n",
+ " },\n",
+ " \"url\": {\n",
+ " \"title\": \"Url\",\n",
+ " \"description\": \"The full URL where the Riva service can be found.\",\n",
+ " \"default\": \"http://localhost:50051\",\n",
+ " \"examples\": [\n",
+ " \"http://localhost:50051\",\n",
+ " \"https://user@pass:riva.example.com\"\n",
+ " ],\n",
+ " \"anyOf\": [\n",
+ " {\n",
+ " \"type\": \"string\",\n",
+ " \"minLength\": 1,\n",
+ " \"maxLength\": 65536,\n",
+ " \"format\": \"uri\"\n",
+ " },\n",
+ " {\n",
+ " \"type\": \"string\"\n",
+ " }\n",
+ " ]\n",
+ " },\n",
+ " \"ssl_cert\": {\n",
+ " \"title\": \"Ssl Cert\",\n",
+ " \"description\": \"A full path to the file where Riva's public ssl key can be read.\",\n",
+ " \"type\": \"string\"\n",
+ " },\n",
+ " \"description\": {\n",
+ " \"title\": \"Description\",\n",
+ " \"default\": \"A tool for converting text to speech.This is useful for converting LLM output into audio bytes.\",\n",
+ " \"type\": \"string\"\n",
+ " },\n",
+ " \"voice_name\": {\n",
+ " \"title\": \"Voice Name\",\n",
+ " \"description\": \"The voice model in Riva to use for speech. Pre-trained models are documented in [the Riva documentation](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/tts/tts-overview.html).\",\n",
+ " \"default\": \"English-US.Female-1\",\n",
+ " \"type\": \"string\"\n",
+ " },\n",
+ " \"output_directory\": {\n",
+ " \"title\": \"Output Directory\",\n",
+ " \"description\": \"The directory where all audio files should be saved. A null value indicates that wave files should not be saved. This is useful for debugging purposes.\",\n",
+ " \"type\": \"string\"\n",
+ " }\n",
+ " },\n",
+ " \"definitions\": {\n",
+ " \"RivaAudioEncoding\": {\n",
+ " \"title\": \"RivaAudioEncoding\",\n",
+ " \"description\": \"An enum of the possible choices for Riva audio encoding.\\n\\nThe list of types exposed by the Riva GRPC Protobuf files can be found\\nwith the following commands:\\n```python\\nimport riva.client\\nprint(riva.client.AudioEncoding.keys()) # noqa: T201\\n```\",\n",
+ " \"enum\": [\n",
+ " \"ALAW\",\n",
+ " \"ENCODING_UNSPECIFIED\",\n",
+ " \"FLAC\",\n",
+ " \"LINEAR_PCM\",\n",
+ " \"MULAW\",\n",
+ " \"OGGOPUS\"\n",
+ " ],\n",
+ " \"type\": \"string\"\n",
+ " }\n",
+ " }\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(json.dumps(RivaASR.schema(), indent=2))\n",
+ "print(json.dumps(RivaTTS.schema(), indent=2))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2f128f27",
+ "metadata": {},
+ "source": [
+ "## 5. Declare Riva ASR and Riva TTS Runnables\n",
+ "\n",
+ "For this example, a single-channel audio file (mulaw format, so `.wav`) is used.\n",
+ "\n",
+ "You will need a Riva speech server setup, so if you don't have a Riva speech server, go to [Setup](#3-setup).\n",
+ "\n",
+ "### a. Set Audio Parameters\n",
+ "Some parameters of audio can be inferred by the mulaw file, but others are set explicitly.\n",
+ "\n",
+ "Replace `audio_file` with the path of your audio file."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "5c75995a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pywav # pywav is used instead of built-in wave because of mulaw support\n",
+ "from langchain_community.utilities.nvidia_riva import RivaAudioEncoding\n",
+ "\n",
+ "audio_file = \"./audio_files/en-US_sample2.wav\"\n",
+ "wav_file = pywav.WavRead(audio_file)\n",
+ "audio_data = wav_file.getdata()\n",
+ "audio_encoding = RivaAudioEncoding.from_wave_format_code(wav_file.getaudioformat())\n",
+ "sample_rate = wav_file.getsamplerate()\n",
+ "delay_time = 1 / 4\n",
+ "chunk_size = int(sample_rate * delay_time)\n",
+ "delay_time = 1 / 8\n",
+ "num_channels = wav_file.getnumofchannels()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "a3b29f36",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import IPython\n",
+ "\n",
+ "IPython.display.Audio(audio_file)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fb294e19",
+ "metadata": {},
+ "source": [
+ "### b. Set the Speech Server and Declare Riva LangChain Runnables\n",
+ "\n",
+ "Be sure to set `RIVA_SPEECH_URL` to be the URI of your Riva speech server.\n",
+ "\n",
+ "The runnables act as clients to the speech server. Many of the fields set in this example are configured based on the sample audio data. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "cf1108af",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "RIVA_SPEECH_URL = \"http://localhost:50051/\"\n",
+ "\n",
+ "riva_asr = RivaASR(\n",
+ " url=RIVA_SPEECH_URL, # the location of the Riva ASR server\n",
+ " encoding=audio_encoding,\n",
+ " audio_channel_count=num_channels,\n",
+ " sample_rate_hertz=sample_rate,\n",
+ " profanity_filter=True,\n",
+ " enable_automatic_punctuation=True,\n",
+ " language_code=\"en-US\",\n",
+ ")\n",
+ "\n",
+ "riva_tts = RivaTTS(\n",
+ " url=RIVA_SPEECH_URL, # the location of the Riva TTS server\n",
+ " output_directory=\"./scratch\", # location of the output .wav files\n",
+ " language_code=\"en-US\",\n",
+ " voice_name=\"English-US.Female-1\",\n",
+ ")"
+ ]
+ },
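+ {
+ "cell_type": "markdown",
+ "id": "tts-sanity-check",
+ "metadata": {},
+ "source": [
+ "As a quick sanity check (assuming the Riva server configured above is reachable), each runnable can also be invoked on its own. For example, `RivaTTS` accepts a plain string and returns raw audio bytes; this is a minimal sketch, not part of the chain below:\n",
+ "\n",
+ "```python\n",
+ "# sketch: synthesize a single sentence directly, without the full chain\n",
+ "audio_bytes = riva_tts.invoke(\"Hello from NVIDIA Riva!\")\n",
+ "```"
+ ]
+ },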
+ {
+ "cell_type": "markdown",
+ "id": "f12049a2",
+ "metadata": {},
+ "source": [
+ "## 6. Create Additional Chain Components\n",
+ "As usual, declare the other parts of the chain. In this case, it's just a prompt template and an LLM.\n",
+ "\n",
+ "LangChain compatible NVIDIA LLMs from [NVIDIA AI Foundation Endpoints](https://www.nvidia.com/en-us/ai-data-science/foundation-models/) can also be used by following these [instructions](https://python.langchain.com/docs/integrations/chat/nvidia_ai_endpoints). "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "a6deb471",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from langchain_core.prompts import PromptTemplate\n",
+ "from langchain_openai import OpenAI\n",
+ "\n",
+ "prompt = PromptTemplate.from_template(\"{user_input}\")\n",
+ "llm = OpenAI(openai_api_key=\"sk-xxx\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5cca78f1",
+ "metadata": {},
+ "source": [
+ "Now, tie together all the parts of the chain including RivaASR and RivaTTS."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "c8de3b75",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "chain = {\"user_input\": riva_asr} | prompt | llm | riva_tts"
+ ]
+ },
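+ {
+ "cell_type": "markdown",
+ "id": "chain-dict-note",
+ "metadata": {},
+ "source": [
+ "The dictionary at the head of the chain maps the ASR transcript onto the prompt's `user_input` variable. In LCEL, a dict literal is shorthand for a `RunnableParallel`, so an equivalent, more explicit spelling is:\n",
+ "\n",
+ "```python\n",
+ "from langchain_core.runnables import RunnableParallel\n",
+ "\n",
+ "chain = RunnableParallel(user_input=riva_asr) | prompt | llm | riva_tts\n",
+ "```"
+ ]
+ },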
+ {
+ "cell_type": "markdown",
+ "id": "84c2c6dc",
+ "metadata": {},
+ "source": [
+ "## 7. Run the Chain with Streamed Inputs and Outputs\n",
+ "\n",
+ "### a. Mimic Audio Streaming\n",
+ "To mimic streaming, first convert the processed audio data to iterable chunks of audio bytes. \n",
+ "\n",
+ "Two functions, `producer` and `consumer`, respectively handle asynchronously passing audio data into the chain and consuming audio data out of the chain.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "745ee427",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import asyncio\n",
+ "\n",
+ "from langchain_community.utilities.nvidia_riva import AudioStream\n",
+ "\n",
+ "audio_chunks = [\n",
+ " audio_data[0 + i : chunk_size + i] for i in range(0, len(audio_data), chunk_size)\n",
+ "]\n",
+ "\n",
+ "\n",
+ "async def producer(input_stream) -> None:\n",
+ " \"\"\"Produces audio chunk bytes into an AudioStream as streaming audio input.\"\"\"\n",
+ " for chunk in audio_chunks:\n",
+ " await input_stream.aput(chunk)\n",
+ " input_stream.close()\n",
+ "\n",
+ "\n",
+ "async def consumer(input_stream, output_stream) -> None:\n",
+ " \"\"\"\n",
+ " Consumes audio chunks from input stream and passes them along the chain\n",
+ " constructed comprised of ASR -> text based prompt for an LLM -> TTS chunks\n",
+ " with synthesized voice of LLM response put in an output stream.\n",
+ " \"\"\"\n",
+ " while not input_stream.complete:\n",
+ " async for chunk in chain.astream(input_stream):\n",
+ " await output_stream.put(\n",
+ " chunk\n",
+ " ) # for production code don't forget to add a timeout\n",
+ "\n",
+ "\n",
+ "input_stream = AudioStream(maxsize=1000)\n",
+ "output_stream = asyncio.Queue()\n",
+ "\n",
+ "# send data into the chain\n",
+ "producer_task = asyncio.create_task(producer(input_stream))\n",
+ "# get data out of the chain\n",
+ "consumer_task = asyncio.create_task(consumer(input_stream, output_stream))\n",
+ "\n",
+ "while not consumer_task.done():\n",
+ " try:\n",
+ " generated_audio = await asyncio.wait_for(\n",
+ " output_stream.get(), timeout=2\n",
+ " ) # for production code don't forget to add a timeout\n",
+ " except asyncio.TimeoutError:\n",
+ " continue\n",
+ "\n",
+ "await producer_task\n",
+ "await consumer_task"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "76b8f175",
+ "metadata": {},
+ "source": [
+ "## 8. Listen to Voice Response\n",
+ "\n",
+ "The audio response is written to `./scratch` and should contain an audio clip that is a response to the input audio."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "8f41b939",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import glob\n",
+ "import os\n",
+ "\n",
+ "output_path = os.path.join(os.getcwd(), \"scratch\")\n",
+ "file_type = \"*.wav\"\n",
+ "files_path = os.path.join(output_path, file_type)\n",
+ "files = glob.glob(files_path)\n",
+ "\n",
+ "IPython.display.Audio(files[0])"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}