Refactor elevenlabs tool

This commit is contained in:
Bagatur 2023-09-12 23:01:00 -07:00
parent 97122fb577
commit 79a567d885

View File

@ -1,20 +1,28 @@
import tempfile import tempfile
from typing import TYPE_CHECKING, Dict, Optional, Union from enum import Enum
from typing import Any, Dict, Optional, Union
from langchain.callbacks.manager import CallbackManagerForToolRun from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.pydantic_v1 import root_validator from langchain.pydantic_v1 import root_validator
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.tools.eleven_labs.models import ElevenLabsModel
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
if TYPE_CHECKING:
def _import_elevenlabs() -> Any:
    """Import and return the ``elevenlabs`` module.

    The import is done lazily, inside this function, so the package is only
    required when it is actually used.

    Returns:
        The imported ``elevenlabs`` module object.

    Raises:
        ImportError: If ``elevenlabs`` cannot be imported. The original
            import failure is attached as the cause.
    """
    try:
        import elevenlabs
    except ImportError as e:
        raise ImportError(
            "Cannot import elevenlabs, please install `pip install elevenlabs`."
        ) from e
    return elevenlabs
class ElevenLabsModel(str, Enum):
    """Models available for Eleven Labs Text2Speech."""

    # Members subclass ``str``, so each one compares equal to its raw
    # model identifier string.
    MULTI_LINGUAL = "eleven_multilingual_v1"
    MONO_LINGUAL = "eleven_monolingual_v1"
class ElevenLabsText2SpeechTool(BaseTool): class ElevenLabsText2SpeechTool(BaseTool):
@ -41,24 +49,24 @@ class ElevenLabsText2SpeechTool(BaseTool):
return values return values
def _text2speech(self, text: str) -> str:
    """Generate speech for ``text`` and save it to a temporary file.

    Args:
        text: Text passed to ``elevenlabs.generate``.

    Returns:
        Path of the temporary file the generated speech was written to.

    Raises:
        ImportError: If the ``elevenlabs`` package cannot be imported.
    """
    # Resolve the module at call time: no module-level ``elevenlabs`` name is
    # bound at runtime, so a bare reference would raise NameError.
    elevenlabs = _import_elevenlabs()
    speech = elevenlabs.generate(text=text, model=self.model)
    # delete=False keeps the file on disk after the handle is closed so the
    # returned path stays usable by the caller.
    with tempfile.NamedTemporaryFile(mode="bx", suffix=".wav", delete=False) as f:
        f.write(speech)
    return f.name
def _run(
    self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
) -> str:
    """Use the tool.

    Generates speech for ``query`` and writes it to a temporary file.

    Args:
        query: Text passed to ``elevenlabs.generate``.
        run_manager: Optional callback manager; not used in this method.

    Returns:
        Path of the temporary file the generated speech was written to.

    Raises:
        ImportError: If the ``elevenlabs`` package cannot be imported.
        RuntimeError: If generating the speech or writing the file fails.
    """
    # Imported outside the ``try`` so a missing dependency surfaces as an
    # ImportError instead of being wrapped in the RuntimeError below.
    elevenlabs = _import_elevenlabs()
    try:
        speech = elevenlabs.generate(text=query, model=self.model)
        # delete=False keeps the file on disk after the handle is closed so
        # the returned path stays usable by the caller.
        with tempfile.NamedTemporaryFile(
            mode="bx", suffix=".wav", delete=False
        ) as f:
            f.write(speech)
        return f.name
    except Exception as e:
        # Chain explicitly so the underlying failure is kept as the cause.
        raise RuntimeError(
            f"Error while running ElevenLabsText2SpeechTool: {e}"
        ) from e
def play(self, speech_file: str) -> None: def play(self, speech_file: str) -> None:
"""Play the text as speech.""" """Play the text as speech."""
elevenlabs = _import_elevenlabs()
with open(speech_file, mode="rb") as f: with open(speech_file, mode="rb") as f:
speech = f.read() speech = f.read()
@ -67,5 +75,6 @@ class ElevenLabsText2SpeechTool(BaseTool):
def stream_speech(self, query: str) -> None:
    """Stream the text as speech as it is generated.

    Play the text in your speakers.

    Args:
        query: Text passed to ``elevenlabs.generate`` with ``stream=True``.

    Raises:
        ImportError: If the ``elevenlabs`` package cannot be imported.
    """
    elevenlabs = _import_elevenlabs()
    # stream=True asks generate() for a stream rather than a finished result;
    # that stream is handed straight to elevenlabs.stream().
    speech_stream = elevenlabs.generate(text=query, model=self.model, stream=True)
    elevenlabs.stream(speech_stream)