mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Refactor elevenlabs tool
This commit is contained in:
parent
97122fb577
commit
79a567d885
@ -1,20 +1,28 @@
|
|||||||
import tempfile
|
import tempfile
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Union
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForToolRun
|
from langchain.callbacks.manager import CallbackManagerForToolRun
|
||||||
from langchain.pydantic_v1 import root_validator
|
from langchain.pydantic_v1 import root_validator
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
from langchain.tools.eleven_labs.models import ElevenLabsModel
|
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
|
def _import_elevenlabs() -> Any:
|
||||||
try:
|
try:
|
||||||
import elevenlabs
|
import elevenlabs
|
||||||
|
except ImportError as e:
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"elevenlabs is not installed. " "Run `pip install elevenlabs` to install."
|
"Cannot import elevenlabs, please install `pip install elevenlabs`."
|
||||||
)
|
) from e
|
||||||
|
return elevenlabs
|
||||||
|
|
||||||
|
|
||||||
|
class ElevenLabsModel(str, Enum):
|
||||||
|
"""Models available for Eleven Labs Text2Speech."""
|
||||||
|
|
||||||
|
MULTI_LINGUAL = "eleven_multilingual_v1"
|
||||||
|
MONO_LINGUAL = "eleven_monolingual_v1"
|
||||||
|
|
||||||
|
|
||||||
class ElevenLabsText2SpeechTool(BaseTool):
|
class ElevenLabsText2SpeechTool(BaseTool):
|
||||||
@ -41,24 +49,24 @@ class ElevenLabsText2SpeechTool(BaseTool):
|
|||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def _text2speech(self, text: str) -> str:
|
|
||||||
speech = elevenlabs.generate(text=text, model=self.model)
|
|
||||||
with tempfile.NamedTemporaryFile(mode="bx", suffix=".wav", delete=False) as f:
|
|
||||||
f.write(speech)
|
|
||||||
return f.name
|
|
||||||
|
|
||||||
def _run(
|
def _run(
|
||||||
self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
|
self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
|
elevenlabs = _import_elevenlabs()
|
||||||
try:
|
try:
|
||||||
speech_file = self._text2speech(query)
|
speech = elevenlabs.generate(text=query, model=self.model)
|
||||||
return speech_file
|
with tempfile.NamedTemporaryFile(
|
||||||
|
mode="bx", suffix=".wav", delete=False
|
||||||
|
) as f:
|
||||||
|
f.write(speech)
|
||||||
|
return f.name
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Error while running ElevenLabsText2SpeechTool: {e}")
|
raise RuntimeError(f"Error while running ElevenLabsText2SpeechTool: {e}")
|
||||||
|
|
||||||
def play(self, speech_file: str) -> None:
|
def play(self, speech_file: str) -> None:
|
||||||
"""Play the text as speech."""
|
"""Play the text as speech."""
|
||||||
|
elevenlabs = _import_elevenlabs()
|
||||||
with open(speech_file, mode="rb") as f:
|
with open(speech_file, mode="rb") as f:
|
||||||
speech = f.read()
|
speech = f.read()
|
||||||
|
|
||||||
@ -67,5 +75,6 @@ class ElevenLabsText2SpeechTool(BaseTool):
|
|||||||
def stream_speech(self, query: str) -> None:
|
def stream_speech(self, query: str) -> None:
|
||||||
"""Stream the text as speech as it is generated.
|
"""Stream the text as speech as it is generated.
|
||||||
Play the text in your speakers."""
|
Play the text in your speakers."""
|
||||||
|
elevenlabs = _import_elevenlabs()
|
||||||
speech_stream = elevenlabs.generate(text=query, model=self.model, stream=True)
|
speech_stream = elevenlabs.generate(text=query, model=self.model, stream=True)
|
||||||
elevenlabs.stream(speech_stream)
|
elevenlabs.stream(speech_stream)
|
||||||
|
Loading…
Reference in New Issue
Block a user