forked from Archives/langchain
Samantha/add conversation chain (#166)
Add MemoryChain and ConversationChain as chains that take a docstore in addition to the prompt, and use the docstore to stuff context into the prompt. This can be used to have an ongoing conversation with a chatbot. Probably needs a bit of refactoring for code quality Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>harrison/flexible_model_args
parent
4334ffa6f9
commit
a408ed3ea3
@ -0,0 +1,283 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "ae046bff",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
|
||||
"\n",
|
||||
"Current conversation:\n",
|
||||
"\n",
|
||||
"Human: Hi there!\n",
|
||||
"AI:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"' Hello! How are you today?'"
|
||||
]
|
||||
},
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain import OpenAI, ConversationChain\n",
|
||||
"from langchain.chains.conversation.memory import ConversationSummaryMemory\n",
|
||||
"\n",
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"conversation = ConversationChain(llm=llm, verbose=True)\n",
|
||||
"\n",
|
||||
"conversation.predict(input=\"Hi there!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "d8e2a6ff",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
|
||||
"\n",
|
||||
"Current conversation:\n",
|
||||
"\n",
|
||||
"Human: Hi there!\n",
|
||||
"AI: Hello! How are you today?\n",
|
||||
"Human: I'm doing well! Just having a conversation with an AI.\n",
|
||||
"AI:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\" That's great! What would you like to talk about?\""
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"conversation.predict(input=\"I'm doing well! Just having a conversation with an AI.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "15eda316",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
|
||||
"\n",
|
||||
"Current conversation:\n",
|
||||
"\n",
|
||||
"Human: Hi there!\n",
|
||||
"AI: Hello! How are you today?\n",
|
||||
"Human: I'm doing well! Just having a conversation with an AI.\n",
|
||||
"AI: That's great! What would you like to talk about?\n",
|
||||
"Human: Tell me about yourself.\n",
|
||||
"AI:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"' I am an AI created to provide information and support to humans. I enjoy learning and exploring new things.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"conversation.predict(input=\"Tell me about yourself.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "b7274f2c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
|
||||
"\n",
|
||||
"Current conversation:\n",
|
||||
"\n",
|
||||
"Human: Hi, what's up?\n",
|
||||
"AI:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"\\n\\nI'm doing well, thank you for asking. I'm currently working on a project that I'm really excited about.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"conversation_with_summary = ConversationChain(llm=llm, memory=ConversationSummaryMemory(llm=OpenAI()), verbose=True)\n",
|
||||
"conversation_with_summary.predict(input=\"Hi, what's up?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "a6b6b88f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
|
||||
"\n",
|
||||
"Current conversation:\n",
|
||||
"\n",
|
||||
"The human greets the AI and asks how it is doing. The AI responds that it is doing well and is currently working on a project that it is excited about.\n",
|
||||
"Human: Tell me more about it!\n",
|
||||
"AI:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"\\n\\nI'm working on a project that involves helping people to better understand and use artificial intelligence. I'm really excited about it because I think it has the potential to make a big difference in people's lives.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"conversation_with_summary.predict(input=\"Tell me more about it!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "dad869fe",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
|
||||
"\n",
|
||||
"Current conversation:\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"The human greets the AI and asks how it is doing. The AI responds that it is doing well and is currently working on a project that it is excited about - a project that involves helping people to better understand and use artificial intelligence.\n",
|
||||
"Human: Very cool -- what is the scope of the project?\n",
|
||||
"AI:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'\\n\\nThe project is still in the early stages, but the goal is to create a resource that will help people to understand artificial intelligence and how to use it effectively.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"conversation_with_summary.predict(input=\"Very cool -- what is the scope of the project?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0eb11bd0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -0,0 +1 @@
|
||||
"""Chain that carries on a conversation from a prompt plus history."""
|
@ -0,0 +1,61 @@
|
||||
"""Chain that carries on a conversation and calls an LLM."""
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.chains.base import Memory
|
||||
from langchain.chains.conversation.memory import ConversationBufferMemory
|
||||
from langchain.chains.conversation.prompt import PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
|
||||
|
||||
class ConversationChain(LLMChain, BaseModel):
|
||||
"""Chain to have a conversation and load context from memory.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import ConversationChain, OpenAI
|
||||
conversation = ConversationChain(llm=OpenAI())
|
||||
"""
|
||||
|
||||
memory: Memory = Field(default_factory=ConversationBufferMemory)
|
||||
"""Default memory store."""
|
||||
prompt: BasePromptTemplate = PROMPT
|
||||
"""Default conversation prompt to use."""
|
||||
|
||||
input_key: str = "input" #: :meta private:
|
||||
output_key: str = "response" #: :meta private:
|
||||
buffer: str = "" #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Use this since so some prompt vars come from history."""
|
||||
return [self.input_key]
|
||||
|
||||
@root_validator()
|
||||
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
|
||||
"""Validate that prompt input variables are consistent."""
|
||||
memory_keys = values["memory"].dynamic_keys
|
||||
input_key = values["input_key"]
|
||||
if input_key in memory_keys:
|
||||
raise ValueError(
|
||||
f"The input key {input_key} was also found in the memory keys "
|
||||
f"({memory_keys}) - please provide keys that don't overlap."
|
||||
)
|
||||
prompt_variables = values["prompt"].input_variables
|
||||
expected_keys = memory_keys + [input_key]
|
||||
if set(expected_keys) != set(prompt_variables):
|
||||
raise ValueError(
|
||||
"Got unexpected prompt input variables. The prompt expects "
|
||||
f"{prompt_variables}, but got {memory_keys} as inputs from "
|
||||
f"memory, and {input_key} as the normal input key."
|
||||
)
|
||||
return values
|
@ -0,0 +1,86 @@
|
||||
"""Memory modules for conversation prompts."""
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.chains.base import Memory
|
||||
from langchain.chains.conversation.prompt import SUMMARY_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
|
||||
|
||||
class ConversationBufferMemory(Memory, BaseModel):
|
||||
"""Buffer for storing conversation memory."""
|
||||
|
||||
buffer: str = ""
|
||||
dynamic_key: str = "history" #: :meta private:
|
||||
|
||||
@property
|
||||
def dynamic_keys(self) -> List[str]:
|
||||
"""Will always return list of dynamic keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.dynamic_key]
|
||||
|
||||
def _load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Return history buffer."""
|
||||
return {self.dynamic_key: self.buffer}
|
||||
|
||||
def _save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
prompt_input_keys = list(set(inputs).difference(self.dynamic_keys))
|
||||
if len(prompt_input_keys) != 1:
|
||||
raise ValueError(f"One input key expected got {prompt_input_keys}")
|
||||
if len(outputs) != 1:
|
||||
raise ValueError(f"One output key expected, got {outputs.keys()}")
|
||||
human = "Human: " + inputs[prompt_input_keys[0]]
|
||||
ai = "AI: " + outputs[list(outputs.keys())[0]]
|
||||
self.buffer += "\n" + "\n".join([human, ai])
|
||||
|
||||
|
||||
class ConversationSummaryMemory(Memory, BaseModel):
|
||||
"""Conversation summarizer to memory."""
|
||||
|
||||
buffer: str = ""
|
||||
llm: LLM
|
||||
prompt: BasePromptTemplate = SUMMARY_PROMPT
|
||||
dynamic_key: str = "history" #: :meta private:
|
||||
|
||||
@property
|
||||
def dynamic_keys(self) -> List[str]:
|
||||
"""Will always return list of dynamic keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.dynamic_key]
|
||||
|
||||
def _load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Return history buffer."""
|
||||
return {self.dynamic_key: self.buffer}
|
||||
|
||||
@root_validator()
|
||||
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
|
||||
"""Validate that prompt input variables are consistent."""
|
||||
prompt_variables = values["prompt"].input_variables
|
||||
expected_keys = {"summary", "new_lines"}
|
||||
if expected_keys != set(prompt_variables):
|
||||
raise ValueError(
|
||||
"Got unexpected prompt input variables. The prompt expects "
|
||||
f"{prompt_variables}, but it should have {expected_keys}."
|
||||
)
|
||||
return values
|
||||
|
||||
def _save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
prompt_input_keys = list(set(inputs).difference(self.dynamic_keys))
|
||||
if len(prompt_input_keys) != 1:
|
||||
raise ValueError(f"One input key expected got {prompt_input_keys}")
|
||||
if len(outputs) != 1:
|
||||
raise ValueError(f"One output key expected, got {outputs.keys()}")
|
||||
human = "Human: " + inputs[prompt_input_keys[0]]
|
||||
ai = "AI: " + list(outputs.values())[0]
|
||||
new_lines = "\n".join([human, ai])
|
||||
chain = LLMChain(llm=self.llm, prompt=self.prompt)
|
||||
self.buffer = chain.predict(summary=self.buffer, new_lines=new_lines)
|
@ -0,0 +1,37 @@
|
||||
# flake8: noqa
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
_DEFAULT_TEMPLATE = """The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.
|
||||
|
||||
Current conversation:
|
||||
{history}
|
||||
Human: {input}
|
||||
AI:"""
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["history", "input"], template=_DEFAULT_TEMPLATE
|
||||
)
|
||||
|
||||
_DEFAULT_SUMMARIZER_TEMPLATE = """Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary.
|
||||
|
||||
EXAMPLE
|
||||
Current summary:
|
||||
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good.
|
||||
|
||||
New lines of conversation:
|
||||
Human: Why do you think artificial intelligence is a force for good?
|
||||
AI: Because artificial intelligence will help humans reach their full potential.
|
||||
|
||||
New summary:
|
||||
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help humans reach their full potential.
|
||||
END OF EXAMPLE
|
||||
|
||||
Current summary:
|
||||
{summary}
|
||||
|
||||
New lines of conversation:
|
||||
{new_lines}
|
||||
|
||||
New summary:"""
|
||||
SUMMARY_PROMPT = PromptTemplate(
|
||||
input_variables=["summary", "new_lines"], template=_DEFAULT_SUMMARIZER_TEMPLATE
|
||||
)
|
@ -1,4 +1,5 @@
|
||||
"""Wrappers on top of docstores."""
|
||||
from langchain.docstore.in_memory import InMemoryDocstore
|
||||
from langchain.docstore.wikipedia import Wikipedia
|
||||
|
||||
__all__ = ["Wikipedia"]
|
||||
__all__ = ["InMemoryDocstore", "Wikipedia"]
|
||||
|
@ -0,0 +1,68 @@
|
||||
"""Test conversation chain and memory."""
|
||||
import pytest
|
||||
|
||||
from langchain.chains.base import Memory
|
||||
from langchain.chains.conversation.base import ConversationChain
|
||||
from langchain.chains.conversation.memory import (
|
||||
ConversationBufferMemory,
|
||||
ConversationSummaryMemory,
|
||||
)
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
def test_conversation_chain_works() -> None:
|
||||
"""Test that conversation chain works in basic setting."""
|
||||
llm = FakeLLM()
|
||||
prompt = PromptTemplate(input_variables=["foo", "bar"], template="{foo} {bar}")
|
||||
memory = ConversationBufferMemory(dynamic_key="foo")
|
||||
chain = ConversationChain(llm=llm, prompt=prompt, memory=memory, input_key="bar")
|
||||
chain.run("foo")
|
||||
|
||||
|
||||
def test_conversation_chain_errors_bad_prompt() -> None:
|
||||
"""Test that conversation chain works in basic setting."""
|
||||
llm = FakeLLM()
|
||||
prompt = PromptTemplate(input_variables=[], template="nothing here")
|
||||
with pytest.raises(ValueError):
|
||||
ConversationChain(llm=llm, prompt=prompt)
|
||||
|
||||
|
||||
def test_conversation_chain_errors_bad_variable() -> None:
|
||||
"""Test that conversation chain works in basic setting."""
|
||||
llm = FakeLLM()
|
||||
prompt = PromptTemplate(input_variables=["foo"], template="{foo}")
|
||||
memory = ConversationBufferMemory(dynamic_key="foo")
|
||||
with pytest.raises(ValueError):
|
||||
ConversationChain(llm=llm, prompt=prompt, memory=memory, input_key="foo")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"memory",
|
||||
[
|
||||
ConversationBufferMemory(dynamic_key="baz"),
|
||||
ConversationSummaryMemory(llm=FakeLLM(), dynamic_key="baz"),
|
||||
],
|
||||
)
|
||||
def test_conversation_memory(memory: Memory) -> None:
|
||||
"""Test basic conversation memory functionality."""
|
||||
# This is a good input because the input is not the same as baz.
|
||||
good_inputs = {"foo": "bar", "baz": "foo"}
|
||||
# This is a good output because these is one variable.
|
||||
good_outputs = {"bar": "foo"}
|
||||
memory._save_context(good_inputs, good_outputs)
|
||||
# This is a bad input because there are two variables that aren't the same as baz.
|
||||
bad_inputs = {"foo": "bar", "foo1": "bar"}
|
||||
with pytest.raises(ValueError):
|
||||
memory._save_context(bad_inputs, good_outputs)
|
||||
# This is a bad input because the only variable is the same as baz.
|
||||
bad_inputs = {"baz": "bar"}
|
||||
with pytest.raises(ValueError):
|
||||
memory._save_context(bad_inputs, good_outputs)
|
||||
# This is a bad output because it is empty.
|
||||
with pytest.raises(ValueError):
|
||||
memory._save_context(good_inputs, {})
|
||||
# This is a bad output because there are two keys.
|
||||
bad_outputs = {"foo": "bar", "foo1": "bar"}
|
||||
with pytest.raises(ValueError):
|
||||
memory._save_context(good_inputs, bad_outputs)
|
Loading…
Reference in New Issue