From 0db05b6725c7ba72dc40c2f686d8ba6a5d24e640 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Tue, 3 Jan 2023 08:03:50 -0800
Subject: [PATCH] Harrison/add human prefix (#520)

Co-authored-by: Andrew Huang
---
 .../conversational_customization.ipynb        | 133 +++++++++++++++++-
 langchain/chains/conversation/memory.py       |  12 +-
 tests/unit_tests/chains/test_conversation.py  |   7 +
 3 files changed, 144 insertions(+), 8 deletions(-)

diff --git a/docs/modules/memory/examples/conversational_customization.ipynb b/docs/modules/memory/examples/conversational_customization.ipynb
index 3c7276a6d4..7efca50b95 100644
--- a/docs/modules/memory/examples/conversational_customization.ipynb
+++ b/docs/modules/memory/examples/conversational_customization.ipynb
@@ -7,9 +7,7 @@ "source": [
     "# Conversational Memory Customization\n",
     "\n",
-    "This notebook walks through a few ways to customize conversational memory.\n",
-    "\n",
-    "The main way to do so is by changing the AI prefix in the conversation summary. By default, this is set to \"AI\", but you can set this to be anything you want. Note that if you change this, you should also change the prompt used in the chain to reflect this naming change. Let's walk through an example of that in the example below."
+    "This notebook walks through a few ways to customize conversational memory."
    ]
   },
   {
    "cell_type": "code",
@@ -27,6 +25,16 @@
     "llm = OpenAI(temperature=0)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "fe3cd3e9",
+   "metadata": {},
+   "source": [
+    "## AI Prefix\n",
+    "\n",
+    "The first way to customize conversational memory is by changing the AI prefix in the conversation summary. By default, this is set to \"AI\", but you can set it to anything you want. Note that if you change this, you should also change the prompt used in the chain to reflect this naming change. Let's walk through an example below."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 2,
@@ -229,10 +237,127 @@
     "conversation.predict(input=\"What's the weather?\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "0517ccf8",
+   "metadata": {},
+   "source": [
+    "## Human Prefix\n",
+    "\n",
+    "The next way to customize conversational memory is by changing the Human prefix in the conversation summary. By default, this is set to \"Human\", but you can set it to anything you want. Note that if you change this, you should also change the prompt used in the chain to reflect this naming change. Let's walk through an example below."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "6357a461",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Now we can override it and set it to \"Friend\"\n",
+    "from langchain.prompts.prompt import PromptTemplate\n",
+    "\n",
+    "template = \"\"\"The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
+    "\n",
+    "Current conversation:\n",
+    "{history}\n",
+    "Friend: {input}\n",
+    "AI:\"\"\"\n",
+    "PROMPT = PromptTemplate(\n",
+    "    input_variables=[\"history\", \"input\"], template=template\n",
+    ")\n",
+    "conversation = ConversationChain(\n",
+    "    prompt=PROMPT,\n",
+    "    llm=llm, \n",
+    "    verbose=True, \n",
+    "    memory=ConversationBufferMemory(human_prefix=\"Friend\")\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "969b6f54",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
+      "Prompt after formatting:\n",
+      "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
+      "\n",
+      "Current conversation:\n",
+      "\n",
+      "Friend: Hi there!\n",
+      "AI:\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished ConversationChain chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "\" Hi there! It's nice to meet you. How can I help you today?\""
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "conversation.predict(input=\"Hi there!\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "d5ea82bb",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
+      "Prompt after formatting:\n",
+      "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
+      "\n",
+      "Current conversation:\n",
+      "\n",
+      "Friend: Hi there!\n",
+      "AI: Hi there! It's nice to meet you. How can I help you today?\n",
+      "Friend: What's the weather?\n",
+      "AI:\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished ConversationChain chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "' The weather right now is sunny and warm with a temperature of 75 degrees Fahrenheit. The forecast for the rest of the day is mostly sunny with a high of 82 degrees.'"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "conversation.predict(input=\"What's the weather?\")"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "1023b6ef",
+   "id": "ce7f79ab",
    "metadata": {},
    "outputs": [],
    "source": []
diff --git a/langchain/chains/conversation/memory.py b/langchain/chains/conversation/memory.py
index 12b5616411..9dfbd25b02 100644
--- a/langchain/chains/conversation/memory.py
+++ b/langchain/chains/conversation/memory.py
@@ -22,6 +22,7 @@ def _get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
 class ConversationBufferMemory(Memory, BaseModel):
     """Buffer for storing conversation memory."""
 
+    human_prefix: str = "Human"
     ai_prefix: str = "AI"
     """Prefix to use for AI generated responses."""
     buffer: str = ""
@@ -53,7 +54,7 @@ class ConversationBufferMemory(Memory, BaseModel):
             output_key = list(outputs.keys())[0]
         else:
             output_key = self.output_key
-        human = "Human: " + inputs[prompt_input_key]
+        human = f"{self.human_prefix}: " + inputs[prompt_input_key]
         ai = f"{self.ai_prefix}: " + outputs[output_key]
         self.buffer += "\n" + "\n".join([human, ai])
 
@@ -65,6 +66,7 @@
 class ConversationBufferWindowMemory(Memory, BaseModel):
     """Buffer for storing conversation memory."""
 
+    human_prefix: str = "Human"
     ai_prefix: str = "AI"
     """Prefix to use for AI generated responses."""
     buffer: List[str] = Field(default_factory=list)
@@ -97,7 +99,7 @@ class ConversationBufferWindowMemory(Memory, BaseModel):
             output_key = list(outputs.keys())[0]
         else:
             output_key = self.output_key
-        human = "Human: " + inputs[prompt_input_key]
+        human = f"{self.human_prefix}: " + inputs[prompt_input_key]
         ai = f"{self.ai_prefix}: " + outputs[output_key]
         self.buffer.append("\n".join([human, ai]))
 
@@ -114,6 +116,7 @@ class ConversationSummaryMemory(Memory, BaseModel):
     """Conversation summarizer to memory."""
 
     buffer: str = ""
+    human_prefix: str = "Human"
     ai_prefix: str = "AI"
     """Prefix to use for AI generated responses."""
     llm: BaseLLM
@@ -158,7 +161,7 @@ class ConversationSummaryMemory(Memory, BaseModel):
             output_key = list(outputs.keys())[0]
         else:
             output_key = self.output_key
-        human = f"Human: {inputs[prompt_input_key]}"
+        human = f"{self.human_prefix}: {inputs[prompt_input_key]}"
         ai = f"{self.ai_prefix}: {outputs[output_key]}"
         new_lines = "\n".join([human, ai])
         chain = LLMChain(llm=self.llm, prompt=self.prompt)
@@ -178,6 +181,7 @@ class ConversationSummaryBufferMemory(Memory, BaseModel):
     llm: BaseLLM
     prompt: BasePromptTemplate = SUMMARY_PROMPT
     memory_key: str = "history"
+    human_prefix: str = "Human"
     ai_prefix: str = "AI"
     """Prefix to use for AI generated responses."""
     output_key: Optional[str] = None
@@ -226,7 +230,7 @@ class ConversationSummaryBufferMemory(Memory, BaseModel):
             output_key = list(outputs.keys())[0]
         else:
             output_key = self.output_key
-        human = f"Human: {inputs[prompt_input_key]}"
+        human = f"{self.human_prefix}: {inputs[prompt_input_key]}"
         ai = f"{self.ai_prefix}: {outputs[output_key]}"
         new_lines = "\n".join([human, ai])
         self.buffer.append(new_lines)
diff --git a/tests/unit_tests/chains/test_conversation.py b/tests/unit_tests/chains/test_conversation.py
index 08eb0f0ef2..66374aeb07 100644
--- a/tests/unit_tests/chains/test_conversation.py
+++ b/tests/unit_tests/chains/test_conversation.py
@@ -19,6 +19,13 @@ def test_memory_ai_prefix() -> None:
     assert memory.buffer == "\nHuman: bar\nAssistant: foo"
"\nHuman: bar\nAssistant: foo" +def test_memory_human_prefix() -> None: + """Test that human_prefix in the memory component works.""" + memory = ConversationBufferMemory(memory_key="foo", human_prefix="Friend") + memory.save_context({"input": "bar"}, {"output": "foo"}) + assert memory.buffer == "\nFriend: bar\nAI: foo" + + def test_conversation_chain_works() -> None: """Test that conversation chain works in basic setting.""" llm = FakeLLM()