You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/community/tests/unit_tests/utilities/test_nvidia_riva_asr.py

178 lines
4.9 KiB
Python

"""Unit tests to verify function of the Riva ASR implementation."""
from typing import TYPE_CHECKING, Any, Generator
from unittest.mock import patch
import pytest
from langchain_community.utilities.nvidia_riva import (
AudioStream,
RivaASR,
RivaAudioEncoding,
)
if TYPE_CHECKING:
import riva.client
import riva.client.proto.riva_asr_pb2 as rasr
# Raw audio chunks fed into the sample stream. The b"_" chunk is the marker
# that streaming_recognize_mock treats as the end of an utterance.
AUDIO_DATA_MOCK = [
    b"This",
    b"is",
    b"a",
    b"test.",
    b"_",
    b"Hello.",
    b"World",
]

# The chunks joined with spaces and split on the "_" marker, giving one
# (still whitespace-padded) text segment per utterance.
AUDIO_TEXT_MOCK = b" ".join(AUDIO_DATA_MOCK).decode().strip().split("_")

# Service address pieces; kept separate so test_get_service can compare the
# bare URI and the SSL flag independently of the full URL below.
SVC_URI = "not-a-url.asdf:9999"
SVC_USE_SSL = True

# Keyword arguments passed to RivaASR by the ``asr`` fixture.
CONFIG = {
    "audio_channel_count": 9,
    "profanity_filter": False,
    "enable_automatic_punctuation": False,
    # Scheme follows SVC_USE_SSL: https when SSL is on, http otherwise.
    "url": f"{'https' if SVC_USE_SSL else 'http'}://{SVC_URI}",
    "ssl_cert": "/dev/null",
    "encoding": RivaAudioEncoding.ALAW,
    "language_code": "not-a-language",
    "sample_rate_hertz": 5,
}
def response_generator(
    transcript: str = "",
    empty: bool = False,
    final: bool = False,
    alternatives: bool = True,
) -> "rasr.StreamingRecognizeResponse":
    """Build one fake streaming recognition response.

    Args:
        transcript: Text for the single alternative; it is stripped of
            surrounding whitespace before being stored.
        empty: When true, return a response with no results at all and
            ignore every other argument.
        final: Value for the result's ``is_final`` flag.
        alternatives: When false, the single result carries an empty
            alternatives list instead of the transcript.

    Returns:
        A ``StreamingRecognizeResponse`` holding zero or one result.
    """
    # pylint: disable-next=import-outside-toplevel
    import riva.client.proto.riva_asr_pb2 as rasr

    if empty:
        return rasr.StreamingRecognizeResponse()

    # The transcript is only touched when an alternative is actually wanted.
    if alternatives:
        candidates = [
            rasr.SpeechRecognitionAlternative(transcript=transcript.strip())
        ]
    else:
        candidates = []

    result = rasr.StreamingRecognitionResult(
        is_final=final,
        alternatives=candidates,
    )
    return rasr.StreamingRecognizeResponse(results=[result])
def streaming_recognize_mock(
    generator: Generator["rasr.StreamingRecognizeRequest", None, None], **_: Any
) -> Generator["rasr.StreamingRecognizeResponse", None, None]:
    """Stand in for the Riva streaming call by echoing audio bytes as text.

    The first two responses are degenerate ones (no results, then a result
    with no alternatives). After that, one response is yielded per incoming
    request, carrying the transcript accumulated so far. A request whose
    decoded content is ``"_"`` is reported as final and resets the
    accumulated transcript. A closing final response is always yielded once
    the request generator is exhausted.

    Args:
        generator: Requests whose ``audio_content`` bytes are decoded to text.
        **_: Extra keyword arguments, accepted and ignored.

    Yields:
        Fake streaming recognition responses.
    """
    yield response_generator(empty=True)
    yield response_generator(alternatives=False)

    pending = ""
    for request in generator:
        chunk = request.audio_content.decode()
        is_final = chunk == "_"

        # The marker itself contributes no text, only the trailing space.
        word = "" if is_final else chunk
        pending = pending + word + " "

        yield response_generator(final=is_final, transcript=pending)

        if is_final:
            pending = ""

    yield response_generator(final=True, transcript=pending)
def riva_asr_stub_init_patch(
    self: "riva.client.proto.riva_asr_pb2_grpc.RivaSpeechRecognitionStub", _: Any
) -> None:
    """Replacement ``__init__`` for the Riva speech recognition stub.

    Ignores the second positional argument and only installs
    ``streaming_recognize_mock`` as the instance's ``StreamingRecognize``
    attribute.
    """
    self.StreamingRecognize = streaming_recognize_mock
@pytest.fixture
def asr() -> RivaASR:
    """Provide a RivaASR instance constructed from the module-level CONFIG."""
    configured = RivaASR(**CONFIG)
    return configured
@pytest.fixture
def stream() -> AudioStream:
    """Provide a closed AudioStream pre-loaded with every AUDIO_DATA_MOCK chunk."""
    audio = AudioStream()
    for chunk in AUDIO_DATA_MOCK:
        audio.put(chunk)
    # Close after loading so the stream is handed out already closed.
    audio.close()
    return audio
@pytest.mark.requires("riva.client")
def test_init(asr: RivaASR) -> None:
    """Every CONFIG entry must be readable back as an attribute of the ASR."""
    for field_name in CONFIG:
        wanted = CONFIG[field_name]
        # A missing attribute reads as None and therefore fails the comparison.
        actual = getattr(asr, field_name, None)
        assert actual == wanted
@pytest.mark.requires("riva.client")
def test_init_defaults() -> None:
    """Constructing RivaASR without any arguments must not raise."""
    RivaASR()
@pytest.mark.requires("riva.client")
def test_config(asr: RivaASR) -> None:
    """Check that ``asr.config`` equals the streaming config built from CONFIG."""
    # pylint: disable-next=import-outside-toplevel
    import riva.client.proto.riva_asr_pb2 as rasr

    # Inner recognition settings: CONFIG values plus a fixed max_alternatives.
    recognition = rasr.RecognitionConfig(
        encoding=CONFIG["encoding"],
        sample_rate_hertz=CONFIG["sample_rate_hertz"],
        audio_channel_count=CONFIG["audio_channel_count"],
        max_alternatives=1,
        profanity_filter=CONFIG["profanity_filter"],
        enable_automatic_punctuation=CONFIG["enable_automatic_punctuation"],
        language_code=CONFIG["language_code"],
    )
    expected = rasr.StreamingRecognitionConfig(
        interim_results=True,
        config=recognition,
    )

    assert asr.config == expected
@pytest.mark.requires("riva.client")
def test_get_service(asr: RivaASR) -> None:
    """The service's auth must carry the configured cert, SSL flag and URI."""
    auth = asr._get_service().auth

    assert str(auth.ssl_cert) == CONFIG["ssl_cert"]
    assert auth.use_ssl == SVC_USE_SSL
    # The URI is compared against the scheme-less address, not CONFIG["url"].
    assert auth.uri == SVC_URI
@pytest.mark.requires("riva.client")
@patch(
    "riva.client.proto.riva_asr_pb2_grpc.RivaSpeechRecognitionStub.__init__",
    riva_asr_stub_init_patch,
)
def test_invoke(asr: RivaASR, stream: AudioStream) -> None:
    """Invoking the ASR on the sample stream must return the joined transcript.

    The stub's ``__init__`` is patched so the streaming call is served by the
    local mock rather than a real Riva service.
    """
    # Each utterance segment is trimmed, then the segments are space-joined.
    segments = (segment.strip() for segment in AUDIO_TEXT_MOCK)
    expected = " ".join(segments).strip()

    got = asr.invoke(stream)

    assert got == expected