community[patch]: Refactor OpenAIWhisperParserLocal (#15150)

This PR fixes an issue in OpenAIWhisperParserLocal where requesting CUDA
on a machine without an available GPU left the device unset and raised an
AttributeError (#15143)

Changes:

- Refactored CUDA availability logic: initialization now checks whether
CUDA is actually available and falls back to the CPU when it is not, so
the parser runs without manual intervention (see the usage sketch after
this list).
- Parameterized batch size and chunk length: batch_size and chunk_length
are now configurable constructor parameters, offering greater flexibility
and optimization options based on the specific requirements of the use
case.
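
For context, here is a minimal usage sketch of the refactored parser. It
is not part of the diff, and the directory, glob, and parameter values
are illustrative:

```python
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers.audio import (
    OpenAIWhisperParserLocal,
)

# Requesting a GPU ("0") now degrades gracefully to CPU when CUDA is absent.
parser = OpenAIWhisperParserLocal(device="0", batch_size=4, chunk_length=20)

# Transcribe every mp3 under ./audio into Documents.
loader = GenericLoader.from_filesystem("./audio", glob="*.mp3", parser=parser)
docs = loader.load()
```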

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>

@@ -64,7 +64,7 @@ class OpenAIWhisperParser(BaseBlobParser):
                 file_obj.name = f"part_{split_number}.mp3"
             # Transcribe
-            print(f"Transcribing part {split_number+1}!")
+            print(f"Transcribing part {split_number + 1}!")
             attempts = 0
             while attempts < 3:
                 try:
@@ -116,6 +116,8 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
         self,
         device: str = "0",
         lang_model: Optional[str] = None,
+        batch_size: int = 8,
+        chunk_length: int = 30,
         forced_decoder_ids: Optional[Tuple[Dict]] = None,
     ):
         """Initialize the parser.
@@ -126,6 +128,10 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
                 Defaults to None.
             forced_decoder_ids: id states for decoder in a multilanguage model.
                 Defaults to None.
+            batch_size: batch size used for decoding.
+                Defaults to 8.
+            chunk_length: chunk length used during inference.
+                Defaults to 30s.
         """
         try:
             from transformers import pipeline
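
The two hunks above add and document the new constructor parameters. A
short sketch of where each one lands (values are illustrative, not part
of the diff):

```python
from langchain_community.document_loaders.parsers.audio import (
    OpenAIWhisperParserLocal,
)

parser = OpenAIWhisperParserLocal(
    device="cpu",     # skip the GPU probe entirely
    batch_size=16,    # stored on self and passed to the pipeline at decode time
    chunk_length=30,  # forwarded as chunk_length_s when building the pipeline
)
```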
@@ -141,47 +147,37 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
             "torch package not found, please install it with " "`pip install torch`"
         )
-        # set device, cpu by default check if there is a GPU available
+        # Determine the device to use
         if device == "cpu":
             self.device = "cpu"
-            if lang_model is not None:
-                self.lang_model = lang_model
-                print("WARNING! Model override. Using model: ", self.lang_model)
-            else:
-                # unless overridden, use the small base model on cpu
-                self.lang_model = "openai/whisper-base"
         else:
-            if torch.cuda.is_available():
-                self.device = "cuda:0"
-                # check GPU memory and select automatically the model
-                mem = torch.cuda.get_device_properties(self.device).total_memory / (
-                    1024**2
-                )
-                if mem < 5000:
-                    rec_model = "openai/whisper-base"
-                elif mem < 7000:
-                    rec_model = "openai/whisper-small"
-                elif mem < 12000:
-                    rec_model = "openai/whisper-medium"
-                else:
-                    rec_model = "openai/whisper-large"
-                # check if model is overridden
-                if lang_model is not None:
-                    self.lang_model = lang_model
-                    print("WARNING! Model override. Might not fit in your GPU")
-                else:
-                    self.lang_model = rec_model
-            else:
-                "cpu"
+            self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        if self.device == "cpu":
+            default_model = "openai/whisper-base"
+            self.lang_model = lang_model if lang_model else default_model
+        else:
+            # Set the language model based on the device and available memory
+            mem = torch.cuda.get_device_properties(self.device).total_memory / (1024**2)
+            if mem < 5000:
+                rec_model = "openai/whisper-base"
+            elif mem < 7000:
+                rec_model = "openai/whisper-small"
+            elif mem < 12000:
+                rec_model = "openai/whisper-medium"
+            else:
+                rec_model = "openai/whisper-large"
+            self.lang_model = lang_model if lang_model else rec_model
         print("Using the following model: ", self.lang_model)
+        self.batch_size = batch_size
         # load model for inference
         self.pipe = pipeline(
             "automatic-speech-recognition",
             model=self.lang_model,
-            chunk_length_s=30,
+            chunk_length_s=chunk_length,
             device=self.device,
         )
         if forced_decoder_ids is not None:
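
Standalone, the fallback pattern this hunk introduces looks roughly like
the sketch below, assuming torch and transformers are installed; the
model name and chunk length are illustrative:

```python
import torch
from transformers import pipeline

# Fall back to CPU instead of leaving the device unset (the removed
# `else: "cpu"` branch was a no-op string expression, which is what
# triggered the AttributeError downstream).
device = "cuda:0" if torch.cuda.is_available() else "cpu"

pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base",
    chunk_length_s=30,
    device=device,
)
```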
@@ -224,7 +220,7 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
             y, sr = librosa.load(file_obj, sr=16000)
-            prediction = self.pipe(y.copy(), batch_size=8)["text"]
+            prediction = self.pipe(y.copy(), batch_size=self.batch_size)["text"]
             yield Document(
                 page_content=prediction,
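
And the decoding step with the now-configurable batch size, continuing
the sketch above (the file path is illustrative; Whisper checkpoints
expect 16 kHz audio):

```python
import librosa

# Load and resample the audio to the 16 kHz rate Whisper expects.
y, sr = librosa.load("sample.mp3", sr=16000)

# batch_size was previously hard-coded to 8; it is now set per parser instance.
prediction = pipe(y.copy(), batch_size=8)["text"]
print(prediction)
```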
