Samantha/add conversation chain (#166)

Add MemoryChain and ConversationChain as chains that take a docstore in
addition to the prompt, and use the docstore to stuff context into the
prompt. This can be used to have an ongoing conversation with a chatbot.

Probably needs a bit of refactoring for code quality

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Samantha Whitmore 2 years ago committed by GitHub
parent 4334ffa6f9
commit a408ed3ea3
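In short, the new surface looks like this (a minimal sketch assembled from the notebook below; temperature=0 and verbose=True are simply the settings the notebook uses):

from langchain import ConversationChain, OpenAI
from langchain.chains.conversation.memory import ConversationSummaryMemory

llm = OpenAI(temperature=0)

# Default memory is ConversationBufferMemory, which replays the raw transcript.
conversation = ConversationChain(llm=llm, verbose=True)
conversation.predict(input="Hi there!")

# Swap in ConversationSummaryMemory to summarize the history with an LLM instead.
conversation_with_summary = ConversationChain(
    llm=llm, memory=ConversationSummaryMemory(llm=OpenAI()), verbose=True
)
conversation_with_summary.predict(input="Hi, what's up?")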

@@ -0,0 +1,283 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "ae046bff",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
"\n",
"Current conversation:\n",
"\n",
"Human: Hi there!\n",
"AI:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' Hello! How are you today?'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain import OpenAI, ConversationChain\n",
"from langchain.chains.conversation.memory import ConversationSummaryMemory\n",
"\n",
"llm = OpenAI(temperature=0)\n",
"conversation = ConversationChain(llm=llm, verbose=True)\n",
"\n",
"conversation.predict(input=\"Hi there!\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d8e2a6ff",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
"\n",
"Current conversation:\n",
"\n",
"Human: Hi there!\n",
"AI: Hello! How are you today?\n",
"Human: I'm doing well! Just having a conversation with an AI.\n",
"AI:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"\" That's great! What would you like to talk about?\""
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"conversation.predict(input=\"I'm doing well! Just having a conversation with an AI.\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "15eda316",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
"\n",
"Current conversation:\n",
"\n",
"Human: Hi there!\n",
"AI: Hello! How are you today?\n",
"Human: I'm doing well! Just having a conversation with an AI.\n",
"AI: That's great! What would you like to talk about?\n",
"Human: Tell me about yourself.\n",
"AI:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' I am an AI created to provide information and support to humans. I enjoy learning and exploring new things.'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"conversation.predict(input=\"Tell me about yourself.\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b7274f2c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
"\n",
"Current conversation:\n",
"\n",
"Human: Hi, what's up?\n",
"AI:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"\"\\n\\nI'm doing well, thank you for asking. I'm currently working on a project that I'm really excited about.\""
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"conversation_with_summary = ConversationChain(llm=llm, memory=ConversationSummaryMemory(llm=OpenAI()), verbose=True)\n",
"conversation_with_summary.predict(input=\"Hi, what's up?\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a6b6b88f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
"\n",
"Current conversation:\n",
"\n",
"The human greets the AI and asks how it is doing. The AI responds that it is doing well and is currently working on a project that it is excited about.\n",
"Human: Tell me more about it!\n",
"AI:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"\"\\n\\nI'm working on a project that involves helping people to better understand and use artificial intelligence. I'm really excited about it because I think it has the potential to make a big difference in people's lives.\""
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"conversation_with_summary.predict(input=\"Tell me more about it!\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "dad869fe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
"\n",
"Current conversation:\n",
"\n",
"\n",
"The human greets the AI and asks how it is doing. The AI responds that it is doing well and is currently working on a project that it is excited about - a project that involves helping people to better understand and use artificial intelligence.\n",
"Human: Very cool -- what is the scope of the project?\n",
"AI:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nThe project is still in the early stages, but the goal is to create a resource that will help people to understand artificial intelligence and how to use it effectively.'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"conversation_with_summary.predict(input=\"Very cool -- what is the scope of the project?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0eb11bd0",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -83,7 +83,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.10.4"
}
},
"nbformat": 4,

@@ -7,6 +7,7 @@ with open(Path(__file__).absolute().parents[0] / "VERSION") as _f:
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.chains import (
ConversationChain,
LLMChain,
LLMMathChain,
PythonChain,
@@ -14,7 +15,7 @@ from langchain.chains import (
SQLDatabaseChain,
VectorDBQA,
)
-from langchain.docstore import Wikipedia
+from langchain.docstore import InMemoryDocstore, Wikipedia
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
from langchain.prompts import (
BasePromptTemplate,
@@ -46,4 +47,6 @@ __all__ = [
"MRKLChain",
"VectorDBQA",
"ElasticVectorSearch",
"InMemoryDocstore",
"ConversationChain",
]

@@ -1,4 +1,5 @@
"""Chains are easily reusable components which can be linked together."""
from langchain.chains.conversation.base import ConversationChain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.python import PythonChain
@@ -16,4 +17,5 @@ __all__ = [
"VectorDBQA",
"SequentialChain",
"SimpleSequentialChain",
"ConversationChain",
]

@@ -1,13 +1,38 @@
"""Base interface that all chains should implement."""
from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
-from pydantic import BaseModel
+from pydantic import BaseModel, Extra
class Memory(BaseModel, ABC):
"""Base interface for memory in chains."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
@abstractmethod
def dynamic_keys(self) -> List[str]:
"""Input keys this memory class will load dynamically."""
@abstractmethod
def _load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return key-value pairs given the text input to the chain."""
@abstractmethod
def _save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save the context of this model run to memory."""
class Chain(BaseModel, ABC):
"""Base interface that all chains should implement."""
memory: Optional[Memory] = None
verbose: bool = False
"""Whether to print out response text."""
@@ -51,6 +76,9 @@ class Chain(BaseModel, ABC):
chain will be returned. Defaults to False.
"""
if self.memory is not None:
external_context = self.memory._load_dynamic_keys(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
if self.verbose:
print("\n\n\033[1m> Entering new chain...\033[0m")
@@ -58,6 +86,8 @@
if self.verbose:
print("\n\033[1m> Finished chain.\033[0m")
self._validate_outputs(outputs)
if self.memory is not None:
self.memory._save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:

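To make the Memory contract above concrete, here is a hypothetical minimal implementation (illustration only, not part of this diff): it loads one fixed value under a single dynamic key and persists nothing.

from typing import Any, Dict, List

from langchain.chains.base import Memory


class StaticMemory(Memory):
    """Hypothetical memory that always loads the same context string."""

    contents: str = ""
    dynamic_key: str = "context"

    @property
    def dynamic_keys(self) -> List[str]:
        return [self.dynamic_key]

    def _load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        # Inject the fixed string under the dynamic key, ignoring other inputs.
        return {self.dynamic_key: self.contents}

    def _save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        pass  # nothing to persist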
@@ -0,0 +1 @@
"""Chain that carries on a conversation from a prompt plus history."""

@@ -0,0 +1,61 @@
"""Chain that carries on a conversation and calls an LLM."""
from typing import Dict, List
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.chains.base import Memory
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.chains.conversation.prompt import PROMPT
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
class ConversationChain(LLMChain, BaseModel):
"""Chain to have a conversation and load context from memory.
Example:
.. code-block:: python
from langchain import ConversationChain, OpenAI
conversation = ConversationChain(llm=OpenAI())
"""
memory: Memory = Field(default_factory=ConversationBufferMemory)
"""Default memory store."""
prompt: BasePromptTemplate = PROMPT
"""Default conversation prompt to use."""
input_key: str = "input" #: :meta private:
output_key: str = "response" #: :meta private:
buffer: str = "" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Use this since so some prompt vars come from history."""
return [self.input_key]
@root_validator()
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
"""Validate that prompt input variables are consistent."""
memory_keys = values["memory"].dynamic_keys
input_key = values["input_key"]
if input_key in memory_keys:
raise ValueError(
f"The input key {input_key} was also found in the memory keys "
f"({memory_keys}) - please provide keys that don't overlap."
)
prompt_variables = values["prompt"].input_variables
expected_keys = memory_keys + [input_key]
if set(expected_keys) != set(prompt_variables):
raise ValueError(
"Got unexpected prompt input variables. The prompt expects "
f"{prompt_variables}, but got {memory_keys} as inputs from "
f"memory, and {input_key} as the normal input key."
)
return values

@@ -0,0 +1,86 @@
"""Memory modules for conversation prompts."""
from typing import Any, Dict, List
from pydantic import BaseModel, root_validator
from langchain.chains.base import Memory
from langchain.chains.conversation.prompt import SUMMARY_PROMPT
from langchain.chains.llm import LLMChain
from langchain.llms.base import LLM
from langchain.prompts.base import BasePromptTemplate
class ConversationBufferMemory(Memory, BaseModel):
"""Buffer for storing conversation memory."""
buffer: str = ""
dynamic_key: str = "history" #: :meta private:
@property
def dynamic_keys(self) -> List[str]:
"""Will always return list of dynamic keys.
:meta private:
"""
return [self.dynamic_key]
def _load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return history buffer."""
return {self.dynamic_key: self.buffer}
def _save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
prompt_input_keys = list(set(inputs).difference(self.dynamic_keys))
if len(prompt_input_keys) != 1:
raise ValueError(f"One input key expected got {prompt_input_keys}")
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
human = "Human: " + inputs[prompt_input_keys[0]]
ai = "AI: " + outputs[list(outputs.keys())[0]]
self.buffer += "\n" + "\n".join([human, ai])
class ConversationSummaryMemory(Memory, BaseModel):
"""Conversation summarizer to memory."""
buffer: str = ""
llm: LLM
prompt: BasePromptTemplate = SUMMARY_PROMPT
dynamic_key: str = "history" #: :meta private:
@property
def dynamic_keys(self) -> List[str]:
"""Will always return list of dynamic keys.
:meta private:
"""
return [self.dynamic_key]
def _load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return history buffer."""
return {self.dynamic_key: self.buffer}
@root_validator()
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
"""Validate that prompt input variables are consistent."""
prompt_variables = values["prompt"].input_variables
expected_keys = {"summary", "new_lines"}
if expected_keys != set(prompt_variables):
raise ValueError(
"Got unexpected prompt input variables. The prompt expects "
f"{prompt_variables}, but it should have {expected_keys}."
)
return values
def _save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
prompt_input_keys = list(set(inputs).difference(self.dynamic_keys))
if len(prompt_input_keys) != 1:
raise ValueError(f"One input key expected got {prompt_input_keys}")
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
human = "Human: " + inputs[prompt_input_keys[0]]
ai = "AI: " + list(outputs.values())[0]
new_lines = "\n".join([human, ai])
chain = LLMChain(llm=self.llm, prompt=self.prompt)
self.buffer = chain.predict(summary=self.buffer, new_lines=new_lines)
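As a quick illustration of what the buffer memory accumulates (hypothetical values; ConversationChain calls _save_context itself, it is only invoked directly here to show the format):

memory = ConversationBufferMemory()
memory._save_context({"input": "Hi there!"}, {"response": "Hello!"})
print(memory.buffer)
# prints a leading blank line, then:
# Human: Hi there!
# AI: Hello!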

@@ -0,0 +1,37 @@
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate
_DEFAULT_TEMPLATE = """The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.
Current conversation:
{history}
Human: {input}
AI:"""
PROMPT = PromptTemplate(
input_variables=["history", "input"], template=_DEFAULT_TEMPLATE
)
_DEFAULT_SUMMARIZER_TEMPLATE = """Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary.
EXAMPLE
Current summary:
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good.
New lines of conversation:
Human: Why do you think artificial intelligence is a force for good?
AI: Because artificial intelligence will help humans reach their full potential.
New summary:
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help humans reach their full potential.
END OF EXAMPLE
Current summary:
{summary}
New lines of conversation:
{new_lines}
New summary:"""
SUMMARY_PROMPT = PromptTemplate(
input_variables=["summary", "new_lines"], template=_DEFAULT_SUMMARIZER_TEMPLATE
)
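For illustration, the default prompt can also be rendered directly with hypothetical values:

print(PROMPT.format(history="Human: Hi there!\nAI: Hello!", input="Tell me a joke"))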

@@ -15,7 +15,7 @@ class LLMChain(Chain, BaseModel):
Example:
.. code-block:: python
-from langchain import LLMChain, OpenAI, Prompt
+from langchain import LLMChain, OpenAI, PromptTemplate
prompt_template = "Tell me a {adjective} joke"
prompt = PromptTemplate(
input_variables=["adjective"], template=prompt_template

@@ -1,4 +1,5 @@
"""Wrappers on top of docstores."""
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.docstore.wikipedia import Wikipedia
__all__ = ["Wikipedia"]
__all__ = ["InMemoryDocstore", "Wikipedia"]

@@ -0,0 +1,68 @@
"""Test conversation chain and memory."""
import pytest
from langchain.chains.base import Memory
from langchain.chains.conversation.base import ConversationChain
from langchain.chains.conversation.memory import (
ConversationBufferMemory,
ConversationSummaryMemory,
)
from langchain.prompts.prompt import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM
def test_conversation_chain_works() -> None:
"""Test that conversation chain works in basic setting."""
llm = FakeLLM()
prompt = PromptTemplate(input_variables=["foo", "bar"], template="{foo} {bar}")
memory = ConversationBufferMemory(dynamic_key="foo")
chain = ConversationChain(llm=llm, prompt=prompt, memory=memory, input_key="bar")
chain.run("foo")
def test_conversation_chain_errors_bad_prompt() -> None:
"""Test that conversation chain works in basic setting."""
llm = FakeLLM()
prompt = PromptTemplate(input_variables=[], template="nothing here")
with pytest.raises(ValueError):
ConversationChain(llm=llm, prompt=prompt)
def test_conversation_chain_errors_bad_variable() -> None:
"""Test that conversation chain works in basic setting."""
llm = FakeLLM()
prompt = PromptTemplate(input_variables=["foo"], template="{foo}")
memory = ConversationBufferMemory(dynamic_key="foo")
with pytest.raises(ValueError):
ConversationChain(llm=llm, prompt=prompt, memory=memory, input_key="foo")
@pytest.mark.parametrize(
"memory",
[
ConversationBufferMemory(dynamic_key="baz"),
ConversationSummaryMemory(llm=FakeLLM(), dynamic_key="baz"),
],
)
def test_conversation_memory(memory: Memory) -> None:
"""Test basic conversation memory functionality."""
# This is a good input because the input is not the same as baz.
good_inputs = {"foo": "bar", "baz": "foo"}
# This is a good output because there is one variable.
good_outputs = {"bar": "foo"}
memory._save_context(good_inputs, good_outputs)
# This is a bad input because there are two variables that aren't the same as baz.
bad_inputs = {"foo": "bar", "foo1": "bar"}
with pytest.raises(ValueError):
memory._save_context(bad_inputs, good_outputs)
# This is a bad input because the only variable is the same as baz.
bad_inputs = {"baz": "bar"}
with pytest.raises(ValueError):
memory._save_context(bad_inputs, good_outputs)
# This is a bad output because it is empty.
with pytest.raises(ValueError):
memory._save_context(good_inputs, {})
# This is a bad output because there are two keys.
bad_outputs = {"foo": "bar", "foo1": "bar"}
with pytest.raises(ValueError):
memory._save_context(good_inputs, bad_outputs)