From 59853fc87628dfeedbe68a354a09134153ff0566 Mon Sep 17 00:00:00 2001
From: Ankush Gola <9536492+agola11@users.noreply.github.com>
Date: Thu, 11 May 2023 15:33:52 -0700
Subject: [PATCH] add invocation params as extra params in llm callbacks
 (#4506)

# Your PR Title (What it does)

Fixes # (issue)

## Before submitting

## Who can review?

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:
---
 langchain/llms/base.py | 41 ++++++++++++++++++++++-------------------
 1 file changed, 22 insertions(+), 19 deletions(-)

diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index fefda42e..4168030f 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -149,6 +149,14 @@ class BaseLLM(BaseLanguageModel, ABC):
                 "Argument 'prompts' is expected to be of type List[str], received"
                 f" argument of type {type(prompts)}."
             )
+        params = self.dict()
+        params["stop"] = stop
+        (
+            existing_prompts,
+            llm_string,
+            missing_prompt_idxs,
+            missing_prompts,
+        ) = get_prompts(params, prompts)
         disregard_cache = self.cache is not None and not self.cache
         callback_manager = CallbackManager.configure(
             callbacks, self.callbacks, self.verbose
@@ -163,7 +171,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
             run_manager = callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, prompts
+                {"name": self.__class__.__name__}, prompts, invocation_params=params
             )
             try:
                 output = (
@@ -176,17 +184,11 @@ class BaseLLM(BaseLanguageModel, ABC):
                 raise e
             run_manager.on_llm_end(output)
             return output
-        params = self.dict()
-        params["stop"] = stop
-        (
-            existing_prompts,
-            llm_string,
-            missing_prompt_idxs,
-            missing_prompts,
-        ) = get_prompts(params, prompts)
         if len(missing_prompts) > 0:
             run_manager = callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, missing_prompts
+                {"name": self.__class__.__name__},
+                missing_prompts,
+                invocation_params=params,
             )
             try:
                 new_results = (
@@ -213,6 +215,14 @@ class BaseLLM(BaseLanguageModel, ABC):
         callbacks: Callbacks = None,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
+        params = self.dict()
+        params["stop"] = stop
+        (
+            existing_prompts,
+            llm_string,
+            missing_prompt_idxs,
+            missing_prompts,
+        ) = get_prompts(params, prompts)
         disregard_cache = self.cache is not None and not self.cache
         callback_manager = AsyncCallbackManager.configure(
             callbacks, self.callbacks, self.verbose
@@ -227,7 +237,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
             run_manager = await callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, prompts
+                {"name": self.__class__.__name__}, prompts, invocation_params=params
             )
             try:
                 output = (
@@ -240,18 +250,11 @@ class BaseLLM(BaseLanguageModel, ABC):
                 raise e
             await run_manager.on_llm_end(output, verbose=self.verbose)
             return output
-        params = self.dict()
-        params["stop"] = stop
-        (
-            existing_prompts,
-            llm_string,
-            missing_prompt_idxs,
-            missing_prompts,
-        ) = get_prompts(params, prompts)
         if len(missing_prompts) > 0:
             run_manager = await callback_manager.on_llm_start(
                 {"name": self.__class__.__name__},
                 missing_prompts,
+                invocation_params=params,
             )
             try:
                 new_results = (
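
A minimal sketch (not part of this patch) of how a custom callback handler could consume the new `invocation_params` keyword that `BaseLLM.generate`/`agenerate` now forward to `on_llm_start`. The handler name and print formatting are illustrative; it assumes the handler's `on_llm_start` hook receives extra keyword arguments via `**kwargs`:

```python
from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler


class InvocationParamsLogger(BaseCallbackHandler):
    """Print the serialized LLM params (`self.dict()` plus `stop`) per call."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        # `invocation_params` arrives through **kwargs alongside run metadata.
        params = kwargs.get("invocation_params", {})
        print(f"{serialized.get('name')} invoked with params: {params}")
```

A handler like this could then be passed as `callbacks=[InvocationParamsLogger()]` when constructing the LLM or calling `generate`.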