From 900ad106d312b6876f9a03530f60d32d50cf629e Mon Sep 17 00:00:00 2001
From: Davis Chase <130488702+dev2049@users.noreply.github.com>
Date: Mon, 1 May 2023 16:19:31 -0700
Subject: [PATCH] Update google palm model signatures (#3920)

Signatures out of date after callback refactors
---
 langchain/chat_models/google_palm.py | 14 ++++++++++++--
 langchain/llms/google_palm.py        | 14 ++++++++++++--
 2 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/langchain/chat_models/google_palm.py b/langchain/chat_models/google_palm.py
index ee0acd3c..4431918e 100644
--- a/langchain/chat_models/google_palm.py
+++ b/langchain/chat_models/google_palm.py
@@ -5,6 +5,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from pydantic import BaseModel, root_validator
 
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import (
     AIMessage,
@@ -216,7 +220,10 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
         return values
 
     def _generate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)
 
@@ -232,7 +239,10 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
         return _response_to_result(response, stop)
 
     async def _agenerate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)
 
diff --git a/langchain/llms/google_palm.py b/langchain/llms/google_palm.py
index 2b71535b..ef418532 100644
--- a/langchain/llms/google_palm.py
+++ b/langchain/llms/google_palm.py
@@ -5,6 +5,10 @@ from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel, root_validator
 
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.llms import BaseLLM
 from langchain.schema import Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
@@ -74,7 +78,10 @@ class GooglePalm(BaseLLM, BaseModel):
         return values
 
     def _generate(
-        self, prompts: List[str], stop: Optional[List[str]] = None
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
     ) -> LLMResult:
         generations = []
         for prompt in prompts:
@@ -99,7 +106,10 @@ class GooglePalm(BaseLLM, BaseModel):
         return LLMResult(generations=generations)
 
     async def _agenerate(
-        self, prompts: List[str], stop: Optional[List[str]] = None
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     ) -> LLMResult:
         raise NotImplementedError()
 