mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
community[patch]: Skip OpenAIWhisperParser
extremely small audio chunks to avoid api error (#11450)
**Description** This PR addresses a rare issue in `OpenAIWhisperParser` that causes it to crash when processing an audio file with a duration very close to the class's chunk size threshold of 20 minutes. **Issue** #11449 **Dependencies** None **Tag maintainer** @agola11 @eyurtsev **Twitter handle** leonardodiegues --------- Co-authored-by: Leonardo Diegues <leonardo.diegues@grupofolha.com.br> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
46505742eb
commit
b15fccbb99
@ -13,10 +13,22 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class OpenAIWhisperParser(BaseBlobParser):
    """Transcribe and parse audio files.

    Audio transcription is with OpenAI Whisper model.

    Args:
        api_key: OpenAI API key
        chunk_duration_threshold: minimum duration of a chunk in seconds
            NOTE: According to the OpenAI API, the chunk duration should be at
            least 0.1 seconds. If the chunk duration is less than or equal to
            the threshold, it will be skipped.
    """

    def __init__(
        self, api_key: Optional[str] = None, *, chunk_duration_threshold: float = 0.1
    ):
        self.api_key = api_key
        # Chunks whose duration (in seconds) is <= this value are skipped
        # when splitting the audio, to avoid OpenAI API errors on
        # near-empty trailing chunks (issue #11449). Keyword-only so the
        # original positional `api_key`-only call sites keep working.
        self.chunk_duration_threshold = chunk_duration_threshold
||||||
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
||||||
"""Lazily parse the blob."""
|
"""Lazily parse the blob."""
|
||||||
@ -57,6 +69,9 @@ class OpenAIWhisperParser(BaseBlobParser):
|
|||||||
for split_number, i in enumerate(range(0, len(audio), chunk_duration_ms)):
|
for split_number, i in enumerate(range(0, len(audio), chunk_duration_ms)):
|
||||||
# Audio chunk
|
# Audio chunk
|
||||||
chunk = audio[i : i + chunk_duration_ms]
|
chunk = audio[i : i + chunk_duration_ms]
|
||||||
|
# Skip chunks that are too short to transcribe
|
||||||
|
if chunk.duration_seconds <= self.chunk_duration_threshold:
|
||||||
|
continue
|
||||||
file_obj = io.BytesIO(chunk.export(format="mp3").read())
|
file_obj = io.BytesIO(chunk.export(format="mp3").read())
|
||||||
if blob.source is not None:
|
if blob.source is not None:
|
||||||
file_obj.name = blob.source + f"_part_{split_number}.mp3"
|
file_obj.name = blob.source + f"_part_{split_number}.mp3"
|
||||||
|
Loading…
Reference in New Issue
Block a user