@@ -35,6 +35,7 @@ from langchain_core.callbacks.manager import (
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     FunctionMessage,
     HumanMessage,
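For context: `AIMessageChunk` is langchain-core's streaming counterpart of `AIMessage`. Chunks are additive, which is what lets callers merge partial outputs as they arrive. A minimal standalone sketch (not part of this diff):

    from langchain_core.messages import AIMessageChunk

    # Message chunks implement __add__, so partial streaming outputs
    # can be merged into one message.
    merged = AIMessageChunk(content="Hello, ") + AIMessageChunk(content="world")
    assert merged.content == "Hello, world"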
@@ -339,11 +340,23 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
             parts = _convert_to_parts(message.content)
         elif isinstance(message, FunctionMessage):
             role = "user"
+            response: Any
+            if not isinstance(message.content, str):
+                response = message.content
+            else:
+                try:
+                    response = json.loads(message.content)
+                except json.JSONDecodeError:
+                    response = message.content  # leave as str representation
             parts = [
                 glm.Part(
                     function_response=glm.FunctionResponse(
                         name=message.name,
-                        response=message.content,
+                        response=(
+                            {"output": response}
+                            if not isinstance(response, dict)
+                            else response
+                        ),
                     )
                 )
             ]
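The added block normalizes `FunctionMessage` content before building the `glm.FunctionResponse`, which expects a dict-valued `response`: string content is decoded as JSON when possible, and any non-dict result is wrapped under an "output" key. The same logic, isolated into a standalone sketch (`_normalize_function_response` is a hypothetical name, not in the diff):

    import json
    from typing import Any

    def _normalize_function_response(content: Any) -> dict:
        # Accept dicts as-is, try to decode JSON strings,
        # and wrap anything else under an "output" key.
        response: Any = content
        if isinstance(content, str):
            try:
                response = json.loads(content)
            except json.JSONDecodeError:
                response = content  # leave as str representation
        return response if isinstance(response, dict) else {"output": response}

    assert _normalize_function_response('{"temp": 21}') == {"temp": 21}
    assert _normalize_function_response("sunny") == {"output": "sunny"}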
@@ -364,12 +377,16 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
     return messages


-def _parse_response_candidate(response_candidate: glm.Candidate) -> AIMessage:
+def _parse_response_candidate(
+    response_candidate: glm.Candidate, stream: bool
+) -> AIMessage:
     first_part = response_candidate.content.parts[0]
     if first_part.function_call:
         function_call = proto.Message.to_dict(first_part.function_call)
         function_call["arguments"] = json.dumps(function_call.pop("args", {}))
-        return AIMessage(content="", additional_kwargs={"function_call": function_call})
+        return (AIMessageChunk if stream else AIMessage)(
+            content="", additional_kwargs={"function_call": function_call}
+        )
     else:
         parts = response_candidate.content.parts

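Because `AIMessageChunk` subclasses `AIMessage`, the `(AIMessageChunk if stream else AIMessage)` selection keeps the `-> AIMessage` annotation valid while giving streaming callers mergeable chunks. The pattern in isolation (hypothetical helper, for illustration only):

    from langchain_core.messages import AIMessage, AIMessageChunk

    def make_message(content: str, stream: bool) -> AIMessage:
        # Both branches satisfy the annotation: AIMessageChunk
        # is a subclass of AIMessage.
        return (AIMessageChunk if stream else AIMessage)(content=content)

    assert isinstance(make_message("hi", stream=True), AIMessageChunk)
    assert not isinstance(make_message("hi", stream=False), AIMessageChunk)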
@@ -377,11 +394,14 @@ def _parse_response_candidate(response_candidate: glm.Candidate) -> AIMessage:
             content: Union[str, List[Union[str, Dict]]] = parts[0].text
         else:
             content = [proto.Message.to_dict(part) for part in parts]
-        return AIMessage(content=content, additional_kwargs={})
+        return (AIMessageChunk if stream else AIMessage)(
+            content=content, additional_kwargs={}
+        )


 def _response_to_result(
     response: glm.GenerateContentResponse,
+    stream: bool = False,
 ) -> ChatResult:
     """Converts a PaLM API response into a LangChain ChatResult."""
     llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}
@@ -397,8 +417,8 @@ def _response_to_result(
             for safety_rating in candidate.safety_ratings
         ]
         generations.append(
-            ChatGeneration(
-                message=_parse_response_candidate(candidate),
+            (ChatGenerationChunk if stream else ChatGeneration)(
+                message=_parse_response_candidate(candidate, stream=stream),
                 generation_info=generation_info,
             )
         )
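`ChatGenerationChunk` plays the same role for generations that `AIMessageChunk` plays for messages; it also supports `+`, merging the underlying messages and `generation_info`. A small standalone sketch:

    from langchain_core.messages import AIMessageChunk
    from langchain_core.outputs import ChatGenerationChunk

    # Generation chunks merge with +, concatenating their messages.
    a = ChatGenerationChunk(message=AIMessageChunk(content="foo"))
    b = ChatGenerationChunk(message=AIMessageChunk(content="bar"))
    assert (a + b).text == "foobar"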
@@ -411,7 +431,10 @@ def _response_to_result(
             f"Feedback: {response.prompt_feedback}"
         )
         generations = [
-            ChatGeneration(message=AIMessage(content=""), generation_info={})
+            (ChatGenerationChunk if stream else ChatGeneration)(
+                message=(AIMessageChunk if stream else AIMessage)(content=""),
+                generation_info={},
+            )
         ]
     return ChatResult(generations=generations, llm_output=llm_output)

@@ -573,7 +596,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             stream=True,
         )
         for chunk in response:
-            _chat_result = _response_to_result(chunk)
+            _chat_result = _response_to_result(chunk, stream=True)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 run_manager.on_llm_new_token(gen.text)
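With `stream=True` threaded through, `_stream` yields real `ChatGenerationChunk`s, so the public `stream()` API emits `AIMessageChunk`s that can be accumulated. A hedged usage sketch (assumes a valid `GOOGLE_API_KEY` in the environment):

    from langchain_google_genai import ChatGoogleGenerativeAI

    llm = ChatGoogleGenerativeAI(model="gemini-pro")

    # Each item is now an AIMessageChunk; summing rebuilds the full message.
    full = None
    for chunk in llm.stream("Tell me a one-line joke."):
        full = chunk if full is None else full + chunk
    print(full.content)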
@@ -597,7 +620,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             generation_method=chat.send_message_async,
             stream=True,
         ):
-            _chat_result = _response_to_result(chunk)
+            _chat_result = _response_to_result(chunk, stream=True)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text)
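The async path gets the same fix, so `astream()` behaves symmetrically (sketch, same assumptions as above):

    import asyncio

    from langchain_google_genai import ChatGoogleGenerativeAI

    async def main() -> None:
        llm = ChatGoogleGenerativeAI(model="gemini-pro")
        # _astream also tags results as chunks now, so astream()
        # yields AIMessageChunk objects as they arrive.
        async for chunk in llm.astream("Tell me a one-line joke."):
            print(chunk.content, end="", flush=True)

    asyncio.run(main())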