|
|
|
@ -5,6 +5,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, root_validator
|
|
|
|
|
|
|
|
|
|
from langchain.callbacks.manager import (
|
|
|
|
|
AsyncCallbackManagerForLLMRun,
|
|
|
|
|
CallbackManagerForLLMRun,
|
|
|
|
|
)
|
|
|
|
|
from langchain.chat_models.base import BaseChatModel
|
|
|
|
|
from langchain.schema import (
|
|
|
|
|
AIMessage,
|
|
|
|
@ -216,7 +220,10 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
def _generate(
|
|
|
|
|
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
|
|
|
|
self,
|
|
|
|
|
messages: List[BaseMessage],
|
|
|
|
|
stop: Optional[List[str]] = None,
|
|
|
|
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
|
|
|
) -> ChatResult:
|
|
|
|
|
prompt = _messages_to_prompt_dict(messages)
|
|
|
|
|
|
|
|
|
@ -232,7 +239,10 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
|
|
|
|
|
return _response_to_result(response, stop)
|
|
|
|
|
|
|
|
|
|
async def _agenerate(
|
|
|
|
|
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
|
|
|
|
self,
|
|
|
|
|
messages: List[BaseMessage],
|
|
|
|
|
stop: Optional[List[str]] = None,
|
|
|
|
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
|
|
|
) -> ChatResult:
|
|
|
|
|
prompt = _messages_to_prompt_dict(messages)
|
|
|
|
|
|
|
|
|
|