langchain/libs/community/langchain_community/document_loaders/parsers/audio.py
andrijdavid d196646811
community[patch]: Refactor OpenAIWhisperParserLocal (#15150)
This PR addresses an issue in OpenAIWhisperParserLocal where requesting
CUDA when it is unavailable leads to an AttributeError (#15143).

Changes:

- Refactored Logic for CUDA Availability: The initialization now
includes a check for CUDA availability. If CUDA is not available, the
code falls back to using the CPU. This ensures seamless operation
without manual intervention.
- Parameterizing Batch Size and Chunk Length: The batch_size and
chunk_length parameters are now configurable, offering greater flexibility
and optimization options based on the specific requirements of the use
case.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
2024-01-15 12:29:14 -08:00

307 lines
10 KiB
Python

import logging
import time
from typing import Dict, Iterator, Optional, Tuple
from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseBlobParser
from langchain_community.document_loaders.blob_loaders import Blob
from langchain_community.utils.openai import is_openai_v1
# Module-level logger for the audio parsers defined in this module.
logger = logging.getLogger(__name__)
class OpenAIWhisperParser(BaseBlobParser):
    """Transcribe and parse audio files.

    Audio transcription is with the OpenAI Whisper API (model ``whisper-1``).
    The audio is split into 20-minute chunks so that each uploaded segment
    stays within the Whisper API's 25MB upload limit; one ``Document`` is
    yielded per successfully transcribed chunk.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Initialize the parser.

        Args:
            api_key: OpenAI API key. When ``None``, the OpenAI library's own
                default credential lookup is used (for the v1 client this is
                the ``OPENAI_API_KEY`` environment variable).
        """
        self.api_key = api_key

    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
        """Lazily transcribe the audio file referenced by ``blob``.

        Each chunk is attempted up to 3 times. A chunk that still fails is
        skipped (a message is printed) and parsing continues with the next one.

        Yields:
            One ``Document`` per transcribed chunk, with the transcript as
            ``page_content`` and metadata ``{"source": blob.source,
            "chunk": <zero-based chunk index>}``.

        Raises:
            ImportError: if the ``openai`` or ``pydub`` package is missing.
        """
        import io

        try:
            import openai
        except ImportError:
            raise ImportError(
                "openai package not found, please install it with "
                "`pip install openai`"
            )
        try:
            from pydub import AudioSegment
        except ImportError:
            raise ImportError(
                "pydub package not found, please install it with " "`pip install pydub`"
            )

        # Resolve the client flavour once rather than on every attempt.
        use_v1_client = is_openai_v1()
        if use_v1_client:
            # api_key optional, defaults to `os.environ['OPENAI_API_KEY']`
            client = openai.OpenAI(api_key=self.api_key)
        else:
            # Set the API key if provided
            if self.api_key:
                openai.api_key = self.api_key

        # Audio file from disk
        audio = AudioSegment.from_file(blob.path)

        # Define the duration of each chunk in minutes
        # Need to meet 25MB size limit for Whisper API
        chunk_duration = 20
        chunk_duration_ms = chunk_duration * 60 * 1000
        max_attempts = 3
        retry_delay_s = 5

        # Split the audio into chunk_duration_ms chunks
        for split_number, i in enumerate(range(0, len(audio), chunk_duration_ms)):
            # Audio chunk
            chunk = audio[i : i + chunk_duration_ms]
            exported = chunk.export(format="mp3")
            try:
                file_obj = io.BytesIO(exported.read())
            finally:
                # Release the temporary file produced by export().
                exported.close()
            if blob.source is not None:
                file_obj.name = blob.source + f"_part_{split_number}.mp3"
            else:
                file_obj.name = f"part_{split_number}.mp3"

            # Transcribe
            print(f"Transcribing part {split_number + 1}!")
            for attempt in range(1, max_attempts + 1):
                # Rewind before every attempt: a failed upload may have
                # consumed the buffer, and a retry must send the whole chunk.
                file_obj.seek(0)
                try:
                    if use_v1_client:
                        transcript = client.audio.transcriptions.create(
                            model="whisper-1", file=file_obj
                        )
                    else:
                        transcript = openai.Audio.transcribe("whisper-1", file_obj)
                    break
                except Exception as e:
                    print(f"Attempt {attempt} failed. Exception: {str(e)}")
                    # Only wait when another attempt will actually follow.
                    if attempt < max_attempts:
                        time.sleep(retry_delay_s)
            else:
                print(f"Failed to transcribe after {max_attempts} attempts.")
                continue

            yield Document(
                page_content=transcript.text,
                metadata={"source": blob.source, "chunk": split_number},
            )
class OpenAIWhisperParserLocal(BaseBlobParser):
    """Transcribe and parse audio files with OpenAI Whisper model.

    Audio transcription with OpenAI Whisper model locally from transformers.

    Parameters:
    device - device to use
        NOTE: By default uses the gpu if available,
        if you want to use cpu, please set device = "cpu"
    lang_model - whisper model to use, for example "openai/whisper-medium"
    forced_decoder_ids - id states for decoder in multilanguage model,
        usage example:
        from transformers import WhisperProcessor
        processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
        forced_decoder_ids = processor.get_decoder_prompt_ids(language="french",
          task="transcribe")
        forced_decoder_ids = processor.get_decoder_prompt_ids(language="french",
        task="translate")
    """

    def __init__(
        self,
        device: str = "0",
        lang_model: Optional[str] = None,
        batch_size: int = 8,
        chunk_length: int = 30,
        forced_decoder_ids: Optional[Tuple[Dict]] = None,
    ):
        """Initialize the parser.

        Args:
            device: device to use. ``"cpu"`` forces CPU. Any other value
                selects ``"cuda:0"`` when CUDA is available and falls back to
                CPU otherwise (the specific value is not otherwise used).
            lang_model: whisper model to use, for example "openai/whisper-medium".
                Defaults to None, meaning a model is picked automatically:
                "openai/whisper-base" on CPU, or a size based on total GPU
                memory on CUDA.
            forced_decoder_ids: id states for decoder in a multilanguage model.
                Defaults to None.
            batch_size: batch size used for decoding
                Defaults to 8.
            chunk_length: chunk length used during inference.
                Defaults to 30s.

        Raises:
            ImportError: if ``transformers`` or ``torch`` is not installed.
        """
        try:
            from transformers import pipeline
        except ImportError:
            raise ImportError(
                "transformers package not found, please install it with "
                "`pip install transformers`"
            )
        try:
            import torch
        except ImportError:
            raise ImportError(
                "torch package not found, please install it with " "`pip install torch`"
            )

        # Determine the device to use; fall back to CPU when CUDA is absent so
        # that requesting a GPU on a CPU-only machine does not crash.
        if device == "cpu":
            self.device = "cpu"
        else:
            self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        if self.device == "cpu":
            default_model = "openai/whisper-base"
            self.lang_model = lang_model if lang_model else default_model
        else:
            # Set the language model based on the device and available memory
            mem = torch.cuda.get_device_properties(self.device).total_memory / (1024**2)
            if mem < 5000:
                rec_model = "openai/whisper-base"
            elif mem < 7000:
                rec_model = "openai/whisper-small"
            elif mem < 12000:
                rec_model = "openai/whisper-medium"
            else:
                rec_model = "openai/whisper-large"
            self.lang_model = lang_model if lang_model else rec_model

        print("Using the following model: ", self.lang_model)

        self.batch_size = batch_size

        # load model for inference
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=self.lang_model,
            chunk_length_s=chunk_length,
            device=self.device,
        )
        if forced_decoder_ids is not None:
            try:
                self.pipe.model.config.forced_decoder_ids = forced_decoder_ids
            except Exception as exception_text:
                # The caller explicitly asked for these ids, so surface the
                # failure at warning level instead of hiding it at info.
                logger.warning(
                    "Unable to set forced_decoder_ids parameter for whisper model. "
                    f"Text of exception: {exception_text}. "
                    "Therefore whisper model will use default mode for decoder."
                )

    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
        """Lazily transcribe the whole audio file referenced by ``blob``.

        The audio is re-encoded to mp3 in memory, resampled to 16 kHz with
        librosa, and passed to the transformers pipeline in a single call.

        Yields:
            A single ``Document`` with the transcript as ``page_content`` and
            metadata ``{"source": blob.source}``.

        Raises:
            ImportError: if ``pydub`` or ``librosa`` is not installed.
        """
        import io

        try:
            from pydub import AudioSegment
        except ImportError:
            raise ImportError(
                "pydub package not found, please install it with `pip install pydub`"
            )

        try:
            import librosa
        except ImportError:
            raise ImportError(
                "librosa package not found, please install it with "
                "`pip install librosa`"
            )

        # Audio file from disk
        audio = AudioSegment.from_file(blob.path)

        exported = audio.export(format="mp3")
        try:
            file_obj = io.BytesIO(exported.read())
        finally:
            # Release the temporary file produced by export().
            exported.close()

        # Transcribe
        print(f"Transcribing part {blob.path}!")

        y, sr = librosa.load(file_obj, sr=16000)

        prediction = self.pipe(y.copy(), batch_size=self.batch_size)["text"]

        yield Document(
            page_content=prediction,
            metadata={"source": blob.source},
        )
class YandexSTTParser(BaseBlobParser):
    """Transcribe and parse audio files.

    Audio transcription is done with the Yandex SpeechKit speech-to-text
    service (the ``yandex-speechkit`` package)."""

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        iam_token: Optional[str] = None,
        model: str = "general",
        language: str = "auto",
    ):
        """Initialize the parser.

        Args:
            api_key: API key for a service account
            with the `ai.speechkit-stt.user` role.
            iam_token: IAM token for a service account
            with the `ai.speechkit-stt.user` role.
            model: Recognition model name.
              Defaults to general.
            language: The language in ISO 639-1 format.
              Defaults to automatic language recognition.
        Either `api_key` or `iam_token` must be provided, but not both.

        Raises:
            ValueError: if both or neither of `api_key` and `iam_token` are given.
        """
        # Exactly one credential must be supplied (both None or both set fails).
        if (api_key is None) == (iam_token is None):
            raise ValueError(
                "Either 'api_key' or 'iam_token' must be provided, but not both."
            )
        self.api_key = api_key
        self.iam_token = iam_token
        self.model = model
        self.language = language

    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
        """Lazily transcribe the audio file referenced by ``blob``.

        Yields:
            One ``Document`` per recognition result returned by SpeechKit, with
            the result's ``normalized_text`` as ``page_content`` and metadata
            ``{"source": blob.source}``.

        Raises:
            ImportError: if ``yandex-speechkit`` or ``pydub`` is not installed.
        """
        try:
            from speechkit import configure_credentials, creds, model_repository
            from speechkit.stt import AudioProcessingType
        except ImportError:
            raise ImportError(
                "yandex-speechkit package not found, please install it with "
                "`pip install yandex-speechkit`"
            )
        try:
            from pydub import AudioSegment
        except ImportError:
            raise ImportError(
                "pydub package not found, please install it with " "`pip install pydub`"
            )

        # Pick the credential with the same `is None` test the constructor
        # validated with, so an empty-string api_key is not swapped for
        # iam_token=None.
        if self.api_key is not None:
            credentials = creds.YandexCredentials(api_key=self.api_key)
        else:
            credentials = creds.YandexCredentials(iam_token=self.iam_token)
        configure_credentials(yandex_credentials=credentials)

        audio = AudioSegment.from_file(blob.path)

        model = model_repository.recognition_model()

        model.model = self.model
        model.language = self.language
        model.audio_processing_type = AudioProcessingType.Full

        result = model.transcribe(audio)

        for res in result:
            yield Document(
                page_content=res.normalized_text,
                metadata={"source": blob.source},
            )