Update google palm model signatures (#3920)

The method signatures were out of date after the callback-manager refactors; this commit adds the new `run_manager` parameter to the `_generate`/`_agenerate` signatures of the Google PaLM chat and LLM models.
fix_agent_callbacks
Davis Chase 1 year ago committed by GitHub
parent 145ff23fb1
commit 900ad106d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,6 +5,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
from pydantic import BaseModel, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
AIMessage,
@ -216,7 +220,10 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
return values
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> ChatResult:
prompt = _messages_to_prompt_dict(messages)
@ -232,7 +239,10 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
return _response_to_result(response, stop)
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> ChatResult:
prompt = _messages_to_prompt_dict(messages)

@ -5,6 +5,10 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env
@ -74,7 +78,10 @@ class GooglePalm(BaseLLM, BaseModel):
return values
def _generate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> LLMResult:
generations = []
for prompt in prompts:
@ -99,7 +106,10 @@ class GooglePalm(BaseLLM, BaseModel):
return LLMResult(generations=generations)
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> LLMResult:
raise NotImplementedError()

Loading…
Cancel
Save