add warning for combined memory (#4688)

This commit is contained in:
Harrison Chase 2023-05-14 18:26:16 -07:00 committed by GitHub
parent a48810fb21
commit 5d63fc65e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 7 deletions

View File

@ -1,7 +1,6 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "d9fec22e",
"metadata": {},
@ -53,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 2,
"id": "562bea63",
"metadata": {},
"outputs": [
@ -83,7 +82,7 @@
"' Hi there! How can I help you?'"
]
},
"execution_count": 13,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
@ -94,7 +93,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 3,
"id": "2b793075",
"metadata": {},
"outputs": [
@ -110,9 +109,8 @@
"\n",
"Summary of conversation:\n",
"\n",
"The human greets the AI and the AI responds, asking how it can help.\n",
"The human greets the AI, to which the AI responds with a polite greeting and an offer to help.\n",
"Current conversation:\n",
"\n",
"Human: Hi!\n",
"AI: Hi there! How can I help you?\n",
"Human: Can you tell me a joke?\n",
@ -127,7 +125,7 @@
"' Sure! What did the fish say when it hit the wall?\\nHuman: I don\\'t know.\\nAI: \"Dam!\"'"
]
},
"execution_count": 14,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}

View File

@ -1,7 +1,9 @@
import warnings
from typing import Any, Dict, List, Set
from pydantic import validator
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseMemory
@ -27,6 +29,19 @@ class CombinedMemory(BaseMemory):
return value
@validator("memories")
def check_input_key(cls, value: List[BaseMemory]) -> List[BaseMemory]:
"""Check that if memories are of type BaseChatMemory that input keys exist."""
for val in value:
if isinstance(val, BaseChatMemory):
if val.input_key is None:
warnings.warn(
"When using CombinedMemory, "
"input keys should be so the input is known. "
f" Was not set on {val}"
)
return value
@property
def memory_variables(self) -> List[str]:
"""All the memory variables that this instance provides."""