diff --git a/libs/langchain/langchain/document_loaders/parsers/audio.py b/libs/langchain/langchain/document_loaders/parsers/audio.py
index 787bf7a279..fe394570a0 100644
--- a/libs/langchain/langchain/document_loaders/parsers/audio.py
+++ b/libs/langchain/langchain/document_loaders/parsers/audio.py
@@ -1,10 +1,13 @@
+import logging
 import time
-from typing import Iterator, Optional
+from typing import Iterator, List, Optional, Tuple
 
 from langchain.document_loaders.base import BaseBlobParser
 from langchain.document_loaders.blob_loaders import Blob
 from langchain.schema import Document
 
+logger = logging.getLogger(__name__)
+
 
 class OpenAIWhisperParser(BaseBlobParser):
     """Transcribe and parse audio files.
@@ -77,12 +80,31 @@ class OpenAIWhisperParser(BaseBlobParser):
 class OpenAIWhisperParserLocal(BaseBlobParser):
     """Transcribe and parse audio files.
 
-    Audio transcription is with OpenAI Whisper model locally from transformers
-    NOTE: By default uses the gpu if available, if you want to use cpu,
-    please set device = "cpu"
+    Audio transcription with the OpenAI Whisper model, run locally via
+    the transformers library.
+
+    Parameters:
+    device - device to use.
+        NOTE: by default the GPU is used if available;
+        to run on the CPU, set device = "cpu"
+    lang_model - whisper model to use, for example "openai/whisper-medium"
+    forced_decoder_ids - token ids that force the decoder to a given
+        language/task in the multilingual model, usage example:
+        from transformers import WhisperProcessor
+        processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
+        forced_decoder_ids = processor.get_decoder_prompt_ids(
+            language="french", task="transcribe")
+        forced_decoder_ids = processor.get_decoder_prompt_ids(
+            language="french", task="translate")
+
     """
 
-    def __init__(self, device: str = "0", lang_model: Optional[str] = None):
+    def __init__(
+        self,
+        device: str = "0",
+        lang_model: Optional[str] = None,
+        forced_decoder_ids: Optional[List[Tuple[int, int]]] = None,
+    ):
         try:
             from transformers import pipeline
         except ImportError:
@@ -136,10 +158,19 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
         # load model for inference
         self.pipe = pipeline(
             "automatic-speech-recognition",
-            model="openai/whisper-medium",
+            model=self.lang_model,
             chunk_length_s=30,
             device=self.device,
         )
+        if forced_decoder_ids is not None:
+            try:
+                self.pipe.model.config.forced_decoder_ids = forced_decoder_ids
+            except Exception as exception:
+                logger.info(
+                    "Unable to set forced_decoder_ids for the whisper model: "
+                    f"{exception}. "
+                    "The model will fall back to its default decoder settings."
+                )
 
     def lazy_parse(self, blob: Blob) -> Iterator[Document]:
         """Lazily parse the blob."""
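
For context, a minimal usage sketch of the new `forced_decoder_ids` parameter (not part of the patch itself): the audio file name `interview.mp3` is hypothetical, and it assumes the local-transcription dependencies (`transformers`, `torch`, `librosa`) are installed.

```python
from transformers import WhisperProcessor

from langchain.document_loaders.blob_loaders import Blob
from langchain.document_loaders.parsers.audio import OpenAIWhisperParserLocal

# Build decoder prompt ids that pin the multilingual model to
# French transcription, mirroring the docstring example above.
processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
forced_decoder_ids = processor.get_decoder_prompt_ids(
    language="french", task="transcribe"
)

parser = OpenAIWhisperParserLocal(
    device="cpu",  # or "0" to use the first GPU if available
    lang_model="openai/whisper-medium",
    forced_decoder_ids=forced_decoder_ids,
)

# "interview.mp3" is a hypothetical local audio file.
for doc in parser.lazy_parse(Blob.from_path("interview.mp3")):
    print(doc.page_content)
```

Passing `lang_model` explicitly keeps the pipeline and the `WhisperProcessor` on the same checkpoint; if it is omitted, the parser falls back to its own default model selection.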