mirror of https://github.com/hwchase17/langchain
community[minor]: add hugging face text-to-speech inference API (#18880)
Description: I implemented a tool to use Hugging Face text-to-speech inference API. Issue: n/a Dependencies: n/a Twitter handle: No Twitter, but do have [LinkedIn](https://www.linkedin.com/in/robby-horvath/) lol. --------- Co-authored-by: Robby <h0rv@users.noreply.github.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>pull/19762/head
parent
73eb3f8fd9
commit
f7e8a382cc
@ -0,0 +1,7 @@
|
|||||||
|
from langchain_community.tools.audio.huggingface_text_to_speech_inference import (
|
||||||
|
HuggingFaceTextToSpeechModelInference,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"HuggingFaceTextToSpeechModelInference",
|
||||||
|
]
|
@ -0,0 +1,118 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Callable, Literal, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||||
|
from langchain_core.pydantic_v1 import SecretStr
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceTextToSpeechModelInference(BaseTool):
|
||||||
|
"""HuggingFace Text-to-Speech Model Inference.
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
- Environment variable ``HUGGINGFACE_API_KEY`` must be set,
|
||||||
|
or passed as a named parameter to the constructor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "openai_text_to_speech"
|
||||||
|
"""Name of the tool."""
|
||||||
|
description: str = "A wrapper around OpenAI Text-to-Speech API. "
|
||||||
|
"""Description of the tool."""
|
||||||
|
|
||||||
|
model: str
|
||||||
|
"""Model name."""
|
||||||
|
file_extension: str
|
||||||
|
"""File extension of the output audio file."""
|
||||||
|
destination_dir: str
|
||||||
|
"""Directory to save the output audio file."""
|
||||||
|
file_namer: Callable[[], str]
|
||||||
|
"""Function to generate unique file names."""
|
||||||
|
|
||||||
|
api_url: str
|
||||||
|
huggingface_api_key: SecretStr
|
||||||
|
|
||||||
|
_HUGGINGFACE_API_KEY_ENV_NAME = "HUGGINGFACE_API_KEY"
|
||||||
|
_HUGGINGFACE_API_URL_ROOT = "https://api-inference.huggingface.co/models"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
file_extension: str,
|
||||||
|
*,
|
||||||
|
destination_dir: str = "./tts",
|
||||||
|
file_naming_func: Literal["uuid", "timestamp"] = "uuid",
|
||||||
|
huggingface_api_key: Optional[SecretStr] = None,
|
||||||
|
) -> None:
|
||||||
|
if not huggingface_api_key:
|
||||||
|
huggingface_api_key = SecretStr(
|
||||||
|
os.getenv(self._HUGGINGFACE_API_KEY_ENV_NAME, "")
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
not huggingface_api_key
|
||||||
|
or not huggingface_api_key.get_secret_value()
|
||||||
|
or huggingface_api_key.get_secret_value() == ""
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"'{self._HUGGINGFACE_API_KEY_ENV_NAME}' must be or set or passed"
|
||||||
|
)
|
||||||
|
|
||||||
|
if file_naming_func == "uuid":
|
||||||
|
file_namer = lambda: str(uuid.uuid4()) # noqa: E731
|
||||||
|
elif file_naming_func == "timestamp":
|
||||||
|
file_namer = lambda: str(int(datetime.now().timestamp())) # noqa: E731
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid value for 'file_naming_func': {file_naming_func}"
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
model=model,
|
||||||
|
file_extension=file_extension,
|
||||||
|
api_url=f"{self._HUGGINGFACE_API_URL_ROOT}/{model}",
|
||||||
|
destination_dir=destination_dir,
|
||||||
|
file_namer=file_namer,
|
||||||
|
huggingface_api_key=huggingface_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||||
|
) -> str:
|
||||||
|
response = requests.post(
|
||||||
|
self.api_url,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {self.huggingface_api_key.get_secret_value()}"
|
||||||
|
},
|
||||||
|
json={"inputs": query},
|
||||||
|
)
|
||||||
|
audio_bytes = response.content
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.makedirs(self.destination_dir, exist_ok=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating directory '{self.destination_dir}': {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
output_file = os.path.join(
|
||||||
|
self.destination_dir,
|
||||||
|
f"{str(self.file_namer())}.{self.file_extension}",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(output_file, mode="xb") as f:
|
||||||
|
f.write(audio_bytes)
|
||||||
|
except FileExistsError:
|
||||||
|
raise ValueError("Output name must be unique")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error occurred while creating file: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
return output_file
|
@ -0,0 +1,87 @@
|
|||||||
|
"""Test Audio Tools."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import Mock, mock_open, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.pydantic_v1 import SecretStr
|
||||||
|
|
||||||
|
from langchain_community.tools.audio import HuggingFaceTextToSpeechModelInference
|
||||||
|
|
||||||
|
AUDIO_FORMAT_EXT = "wav"
|
||||||
|
|
||||||
|
|
||||||
|
def test_huggingface_tts_constructor() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
os.environ.pop("HUGGINGFACE_API_KEY", None)
|
||||||
|
HuggingFaceTextToSpeechModelInference(
|
||||||
|
model="test/model",
|
||||||
|
file_extension=AUDIO_FORMAT_EXT,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
HuggingFaceTextToSpeechModelInference(
|
||||||
|
model="test/model",
|
||||||
|
file_extension=AUDIO_FORMAT_EXT,
|
||||||
|
huggingface_api_key=SecretStr(""),
|
||||||
|
)
|
||||||
|
|
||||||
|
HuggingFaceTextToSpeechModelInference(
|
||||||
|
model="test/model",
|
||||||
|
file_extension=AUDIO_FORMAT_EXT,
|
||||||
|
huggingface_api_key=SecretStr("foo"),
|
||||||
|
)
|
||||||
|
|
||||||
|
os.environ["HUGGINGFACE_API_KEY"] = "foo"
|
||||||
|
HuggingFaceTextToSpeechModelInference(
|
||||||
|
model="test/model",
|
||||||
|
file_extension=AUDIO_FORMAT_EXT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_huggingface_tts_run_with_requests_mock() -> None:
|
||||||
|
os.environ["HUGGINGFACE_API_KEY"] = "foo"
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir, patch(
|
||||||
|
"uuid.uuid4"
|
||||||
|
) as mock_uuid, patch("requests.post") as mock_inference, patch(
|
||||||
|
"builtins.open", mock_open()
|
||||||
|
) as mock_file:
|
||||||
|
input_query = "Dummy input"
|
||||||
|
|
||||||
|
mock_uuid_value = uuid.UUID("00000000-0000-0000-0000-000000000000")
|
||||||
|
mock_uuid.return_value = mock_uuid_value
|
||||||
|
|
||||||
|
expected_output_file_base_name = os.path.join(tmp_dir, str(mock_uuid_value))
|
||||||
|
expected_output_file = f"{expected_output_file_base_name}.{AUDIO_FORMAT_EXT}"
|
||||||
|
|
||||||
|
test_audio_content = b"test_audio_bytes"
|
||||||
|
|
||||||
|
tts = HuggingFaceTextToSpeechModelInference(
|
||||||
|
model="test/model",
|
||||||
|
file_extension=AUDIO_FORMAT_EXT,
|
||||||
|
destination_dir=tmp_dir,
|
||||||
|
file_naming_func="uuid",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the requests.post response
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.content = test_audio_content
|
||||||
|
mock_inference.return_value = mock_response
|
||||||
|
|
||||||
|
output_path = tts._run(input_query)
|
||||||
|
|
||||||
|
assert output_path == expected_output_file
|
||||||
|
|
||||||
|
mock_inference.assert_called_once_with(
|
||||||
|
tts.api_url,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {tts.huggingface_api_key.get_secret_value()}"
|
||||||
|
},
|
||||||
|
json={"inputs": input_query},
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_file.assert_called_once_with(expected_output_file, mode="xb")
|
||||||
|
mock_file.return_value.write.assert_called_once_with(test_audio_content)
|
Loading…
Reference in New Issue