From fe78aff1f242ae25fac84d10b7d257e1f60f4da5 Mon Sep 17 00:00:00 2001 From: idcore <30922976+idcore@users.noreply.github.com> Date: Mon, 7 Aug 2023 23:17:58 +0300 Subject: [PATCH] Add new parameter forced_decoder_ids to OpenAIWhisperParserLocal + small bug fix (#8793) - Description: new parameter forced_decoder_ids for OpenAIWhisperParserLocal to force input language, and enable optional translate mode. Usage example: processor = WhisperProcessor.from_pretrained("openai/whisper-medium") forced_decoder_ids = processor.get_decoder_prompt_ids(language="french", task="transcribe") #forced_decoder_ids = processor.get_decoder_prompt_ids(language="french", task="translate") loader = GenericLoader(YoutubeAudioLoader(urls, save_dir), OpenAIWhisperParserLocal(lang_model="openai/whisper-medium",forced_decoder_ids=forced_decoder_ids)) - Issue #8792 - Tag maintainer: @rlancemartin, @eyurtsev --------- Co-authored-by: idcore --- .../document_loaders/parsers/audio.py | 43 ++++++++++++++++--- 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/libs/langchain/langchain/document_loaders/parsers/audio.py b/libs/langchain/langchain/document_loaders/parsers/audio.py index 787bf7a279..fe394570a0 100644 --- a/libs/langchain/langchain/document_loaders/parsers/audio.py +++ b/libs/langchain/langchain/document_loaders/parsers/audio.py @@ -1,10 +1,13 @@ +import logging import time -from typing import Iterator, Optional +from typing import Dict, Iterator, Optional, Tuple from langchain.document_loaders.base import BaseBlobParser from langchain.document_loaders.blob_loaders import Blob from langchain.schema import Document +logger = logging.getLogger(__name__) + class OpenAIWhisperParser(BaseBlobParser): """Transcribe and parse audio files. @@ -77,12 +80,31 @@ class OpenAIWhisperParser(BaseBlobParser): class OpenAIWhisperParserLocal(BaseBlobParser): """Transcribe and parse audio files. 
- Audio transcription is with OpenAI Whisper model locally from transformers - NOTE: By default uses the gpu if available, if you want to use cpu, - please set device = "cpu" + Audio transcription with OpenAI Whisper model locally from transformers + Parameters: + device - device to use + NOTE: By default uses the gpu if available, + if you want to use cpu, please set device = "cpu" + lang_model - whisper model to use, for example "openai/whisper-medium" + forced_decoder_ids - id states for decoder in multilanguage model, + usage example: + from transformers import WhisperProcessor + processor = WhisperProcessor.from_pretrained("openai/whisper-medium") + forced_decoder_ids = processor.get_decoder_prompt_ids(language="french", + task="transcribe") + forced_decoder_ids = processor.get_decoder_prompt_ids(language="french", + task="translate") + + + """ - def __init__(self, device: str = "0", lang_model: Optional[str] = None): + def __init__( + self, + device: str = "0", + lang_model: Optional[str] = None, + forced_decoder_ids: Optional[Tuple[Dict]] = None, + ): try: from transformers import pipeline except ImportError: @@ -136,10 +158,19 @@ class OpenAIWhisperParserLocal(BaseBlobParser): # load model for inference self.pipe = pipeline( "automatic-speech-recognition", - model="openai/whisper-medium", + model=self.lang_model, chunk_length_s=30, device=self.device, ) + if forced_decoder_ids is not None: + try: + self.pipe.model.config.forced_decoder_ids = forced_decoder_ids + except Exception as exception_text: + logger.info( + "Unable to set forced_decoder_ids parameter for whisper model. " + f"Text of exception: {exception_text}. " + "Therefore whisper model will use default mode for decoder" + ) def lazy_parse(self, blob: Blob) -> Iterator[Document]: """Lazily parse the blob."""