diff --git a/langchain/chains/conversation/memory.py b/langchain/chains/conversation/memory.py
index 0a686ddee0..9311acc1b9 100644
--- a/langchain/chains/conversation/memory.py
+++ b/langchain/chains/conversation/memory.py
@@ -22,6 +22,8 @@ def _get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
 class ConversationBufferMemory(Memory, BaseModel):
     """Buffer for storing conversation memory."""
 
+    ai_prefix: str = "AI"
+    """Prefix to use for AI generated responses."""
     buffer: str = ""
     memory_key: str = "history"  #: :meta private:
 
@@ -43,7 +45,7 @@ class ConversationBufferMemory(Memory, BaseModel):
         if len(outputs) != 1:
             raise ValueError(f"One output key expected, got {outputs.keys()}")
         human = "Human: " + inputs[prompt_input_key]
-        ai = "AI: " + outputs[list(outputs.keys())[0]]
+        ai = f"{self.ai_prefix}: " + outputs[list(outputs.keys())[0]]
         self.buffer += "\n" + "\n".join([human, ai])
 
     def clear(self) -> None:
@@ -54,6 +56,8 @@ class ConversationalBufferWindowMemory(Memory, BaseModel):
     """Buffer for storing conversation memory."""
 
+    ai_prefix: str = "AI"
+    """Prefix to use for AI generated responses."""
     buffer: List[str] = Field(default_factory=list)
     memory_key: str = "history"  #: :meta private:
     k: int = 5
 
@@ -76,7 +80,7 @@ class ConversationalBufferWindowMemory(Memory, BaseModel):
         if len(outputs) != 1:
             raise ValueError(f"One output key expected, got {outputs.keys()}")
         human = "Human: " + inputs[prompt_input_key]
-        ai = "AI: " + outputs[list(outputs.keys())[0]]
+        ai = f"{self.ai_prefix}: " + outputs[list(outputs.keys())[0]]
         self.buffer.append("\n".join([human, ai]))
 
     def clear(self) -> None:
@@ -88,6 +92,8 @@ class ConversationSummaryMemory(Memory, BaseModel):
     """Conversation summarizer to memory."""
 
     buffer: str = ""
+    ai_prefix: str = "AI"
+    """Prefix to use for AI generated responses."""
     llm: BaseLLM
     prompt: BasePromptTemplate = SUMMARY_PROMPT
     memory_key: str = "history"  #: :meta private:
@@ -122,7 +128,7 @@ class ConversationSummaryMemory(Memory, BaseModel):
         if len(outputs) != 1:
             raise ValueError(f"One output key expected, got {outputs.keys()}")
         human = f"Human: {inputs[prompt_input_key]}"
-        ai = f"AI: {list(outputs.values())[0]}"
+        ai = f"{self.ai_prefix}: {list(outputs.values())[0]}"
         new_lines = "\n".join([human, ai])
         chain = LLMChain(llm=self.llm, prompt=self.prompt)
         self.buffer = chain.predict(summary=self.buffer, new_lines=new_lines)
@@ -130,3 +136,88 @@ class ConversationSummaryMemory(Memory, BaseModel):
     def clear(self) -> None:
         """Clear memory contents."""
         self.buffer = ""
+
+
+class ConversationSummaryBufferMemory(Memory, BaseModel):
+    """Buffer with summarizer for storing conversation memory."""
+
+    buffer: List[str] = Field(default_factory=list)
+    max_token_limit: int = 2000
+    moving_summary_buffer: str = ""
+    llm: BaseLLM
+    prompt: BasePromptTemplate = SUMMARY_PROMPT
+    memory_key: str = "history"
+    ai_prefix: str = "AI"
+    """Prefix to use for AI generated responses."""
+
+    @property
+    def memory_variables(self) -> List[str]:
+        """Will always return list of memory variables.
+
+        :meta private:
+        """
+        return [self.memory_key]
+
+    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+        """Return history buffer."""
+        if self.moving_summary_buffer == "":
+            return {self.memory_key: "\n".join(self.buffer)}
+        memory_val = self.moving_summary_buffer + "\n" + "\n".join(self.buffer)
+        return {self.memory_key: memory_val}
+
+    @root_validator()
+    def validate_prompt_input_variables(cls, values: Dict) -> Dict:
+        """Validate that prompt input variables are consistent."""
+        prompt_variables = values["prompt"].input_variables
+        expected_keys = {"summary", "new_lines"}
+        if expected_keys != set(prompt_variables):
+            raise ValueError(
+                "Got unexpected prompt input variables. The prompt expects "
+                f"{prompt_variables}, but it should have {expected_keys}."
+            )
+        return values
+
+    def get_num_tokens_list(self, arr: List[str]) -> List[int]:
+        """Get list of number of tokens in each string in the input array."""
+        try:
+            import tiktoken
+        except ImportError:
+            raise ValueError(
+                "Could not import tiktoken python package. "
+                "This is needed in order to calculate the number of tokens. "
+                "Please install it with `pip install tiktoken`."
+            )
+        # create a GPT-2 BPE encoder instance (the encoding GPT-3 models build on)
+        enc = tiktoken.get_encoding("gpt2")
+
+        # encode each string in the list using the GPT-2 encoder
+        tokenized_text = enc.encode_ordinary_batch(arr)
+
+        # calculate the number of tokens for each encoded text in the list
+        return [len(x) for x in tokenized_text]
+
+    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
+        """Save context from this conversation to buffer."""
+        prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables)
+        if len(outputs) != 1:
+            raise ValueError(f"One output key expected, got {outputs.keys()}")
+        human = f"Human: {inputs[prompt_input_key]}"
+        ai = f"{self.ai_prefix}: {list(outputs.values())[0]}"
+        new_lines = "\n".join([human, ai])
+        self.buffer.append(new_lines)
+        # Prune buffer if it exceeds max token limit
+        curr_buffer_length = sum(self.get_num_tokens_list(self.buffer))
+        if curr_buffer_length > self.max_token_limit:
+            pruned_memory = []
+            while curr_buffer_length > self.max_token_limit:
+                pruned_memory.append(self.buffer.pop(0))
+                curr_buffer_length = sum(self.get_num_tokens_list(self.buffer))
+            chain = LLMChain(llm=self.llm, prompt=self.prompt)
+            self.moving_summary_buffer = chain.predict(
+                summary=self.moving_summary_buffer, new_lines=("\n".join(pruned_memory))
+            )
+
+    def clear(self) -> None:
+        """Clear memory contents."""
+        self.buffer = []
+        self.moving_summary_buffer = ""
diff --git a/tests/integration_tests/chains/test_memory.py b/tests/integration_tests/chains/test_memory.py
new file mode 100644
index 0000000000..20e723fe20
--- /dev/null
+++ b/tests/integration_tests/chains/test_memory.py
@@ -0,0 +1,31 @@
+"""Test memory functionality."""
+from langchain.chains.conversation.memory import ConversationSummaryBufferMemory
+from tests.unit_tests.llms.fake_llm import FakeLLM
+
+
+def test_summary_buffer_memory_no_buffer_yet() -> None:
+    """Test ConversationSummaryBufferMemory when no inputs put in buffer yet."""
+    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
+    output = memory.load_memory_variables({})
+    assert output == {"baz": ""}
+
+
+def test_summary_buffer_memory_buffer_only() -> None:
+    """Test ConversationSummaryBufferMemory when only buffer."""
+    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
+    memory.save_context({"input": "bar"}, {"output": "foo"})
+    assert memory.buffer == ["Human: bar\nAI: foo"]
+    output = memory.load_memory_variables({})
+    assert output == {"baz": "Human: bar\nAI: foo"}
+
+
+def test_summary_buffer_memory_summary() -> None:
+    """Test ConversationSummaryBufferMemory when the buffer is pruned and summarized."""
+    memory = ConversationSummaryBufferMemory(
+        llm=FakeLLM(), memory_key="baz", max_token_limit=13
+    )
+    memory.save_context({"input": "bar"}, {"output": "foo"})
+    memory.save_context({"input": "bar1"}, {"output": "foo1"})
+    assert memory.buffer == ["Human: bar1\nAI: foo1"]
+    output = memory.load_memory_variables({})
+    assert output == {"baz": "foo\nHuman: bar1\nAI: foo1"}
diff --git a/tests/unit_tests/chains/test_conversation.py b/tests/unit_tests/chains/test_conversation.py
index fd7eb55fb1..ce80fd6f56 100644
--- a/tests/unit_tests/chains/test_conversation.py
+++ b/tests/unit_tests/chains/test_conversation.py
@@ -12,6 +12,13 @@ from langchain.prompts.prompt import PromptTemplate
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
 
+def test_memory_ai_prefix() -> None:
+    """Test that ai_prefix in the memory component works."""
+    memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
+    memory.save_context({"input": "bar"}, {"output": "foo"})
+    assert memory.buffer == "\nHuman: bar\nAssistant: foo"
+
+
 def test_conversation_chain_works() -> None:
     """Test that conversation chain works in basic setting."""
     llm = FakeLLM()
@@ -42,6 +49,7 @@ def test_conversation_chain_errors_bad_variable() -> None:
     "memory",
     [
         ConversationBufferMemory(memory_key="baz"),
+        ConversationalBufferWindowMemory(memory_key="baz"),
         ConversationSummaryMemory(llm=FakeLLM(), memory_key="baz"),
     ],
 )
@@ -81,7 +89,7 @@ def test_clearing_conversation_memory(memory: Memory) -> None:
     """Test clearing the conversation memory."""
     # This is a good input because the input is not the same as baz.
     good_inputs = {"foo": "bar", "baz": "foo"}
-    # This is a good output because these is one variable.
+    # This is a good output because there is one variable.
     good_outputs = {"bar": "foo"}
     memory.save_context(good_inputs, good_outputs)
     memory.clear()
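
For reviewers, the pruning strategy that the new save_context implements can be read in isolation. Below is a minimal sketch under stated assumptions: count_tokens (a whitespace approximation standing in for tiktoken's GPT-2 BPE) and summarize (standing in for the LLMChain.predict call with SUMMARY_PROMPT) are hypothetical helpers, not part of this diff.

from typing import List, Tuple


def count_tokens(text: str) -> int:
    # Hypothetical stand-in: the diff counts tokens with tiktoken, not whitespace.
    return len(text.split())


def summarize(summary: str, new_lines: str) -> str:
    # Hypothetical stand-in for LLMChain(llm=..., prompt=SUMMARY_PROMPT).predict(...).
    return (summary + "\n" + new_lines).strip()


def prune(buffer: List[str], summary: str, max_token_limit: int) -> Tuple[List[str], str]:
    # Fold the oldest exchanges into the summary until the buffer fits the limit.
    curr_length = sum(count_tokens(x) for x in buffer)
    if curr_length <= max_token_limit:
        return buffer, summary
    pruned: List[str] = []
    while curr_length > max_token_limit:
        pruned.append(buffer.pop(0))  # oldest exchange goes first
        curr_length = sum(count_tokens(x) for x in buffer)
    return buffer, summarize(summary, "\n".join(pruned))


exchanges = ["Human: hi\nAI: hello there friend", "Human: bye\nAI: goodbye"]
buffer, summary = prune(exchanges, "", max_token_limit=5)
print(buffer)   # ['Human: bye\nAI: goodbye']
print(summary)  # the older exchange, folded into the moving summary

Recent turns stay verbatim while only the overflow is compressed, which is why the class keeps both buffer and moving_summary_buffer.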
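The token accounting in get_num_tokens_list relies on tiktoken's GPT-2 BPE; get_encoding and encode_ordinary_batch are the same tiktoken calls the diff makes. This standalone snippet (assuming tiktoken is installed via pip install tiktoken) shows the per-string counts:

import tiktoken

# Load the GPT-2 BPE encoding; GPT-3-era models build on a compatible vocabulary.
enc = tiktoken.get_encoding("gpt2")
texts = ["Human: bar\nAI: foo", "Human: bar1\nAI: foo1"]
# encode_ordinary_batch skips special-token handling and returns one token list per string.
counts = [len(tokens) for tokens in enc.encode_ordinary_batch(texts)]
print(counts, sum(counts))  # the sum is what save_context compares against max_token_limit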
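Finally, a sketch of how the new memory class might be exercised end to end. ConversationChain and OpenAI are existing langchain classes, but this particular wiring and the 40-token limit are illustrative assumptions, not taken from this diff:

from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationSummaryBufferMemory
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)  # assumes OPENAI_API_KEY is set in the environment
memory = ConversationSummaryBufferMemory(llm=llm, max_token_limit=40)
chain = ConversationChain(llm=llm, memory=memory)

chain.predict(input="Hi, I'm planning a small vegetable garden.")
chain.predict(input="Which crops handle partial shade well?")

# Once the raw exchanges exceed ~40 tokens, the oldest turns are condensed into
# moving_summary_buffer and only the most recent turns remain verbatim.
print(memory.moving_summary_buffer)
print(memory.buffer)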