From e0a13e93550e06489846f68997348ed3949f9d0a Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 10 Apr 2023 21:15:42 -0700 Subject: [PATCH] Harrison/postgres (#2691) Co-authored-by: Ankit Jain --- .../postgres_chat_message_history.ipynb | 62 +++++++++++++ langchain/memory/__init__.py | 2 + .../memory/chat_message_histories/__init__.py | 2 + .../memory/chat_message_histories/postgres.py | 86 +++++++++++++++++++ 4 files changed, 152 insertions(+) create mode 100644 docs/modules/memory/examples/postgres_chat_message_history.ipynb create mode 100644 langchain/memory/chat_message_histories/postgres.py diff --git a/docs/modules/memory/examples/postgres_chat_message_history.ipynb b/docs/modules/memory/examples/postgres_chat_message_history.ipynb new file mode 100644 index 00000000..be3705f2 --- /dev/null +++ b/docs/modules/memory/examples/postgres_chat_message_history.ipynb @@ -0,0 +1,62 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "91c6a7ef", + "metadata": {}, + "source": [ + "# Postgres Chat Message History\n", + "\n", + "This notebook goes over how to use Postgres to store chat message history." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d15e3302", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.memory import PostgresChatMessageHistory\n", + "\n", + "history = PostgresChatMessageHistory(connection_string=\"postgresql://postgres:mypassword@localhost/chat_history\", session_id=\"foo\")\n", + "\n", + "history.add_user_message(\"hi!\")\n", + "\n", + "history.add_ai_message(\"whats up?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64fc465e", + "metadata": {}, + "outputs": [], + "source": [ + "history.messages" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/memory/__init__.py b/langchain/memory/__init__.py index aa56cc6c..b085a336 100644 --- a/langchain/memory/__init__.py +++ b/langchain/memory/__init__.py @@ -5,6 +5,7 @@ from langchain.memory.buffer import ( from langchain.memory.buffer_window import ConversationBufferWindowMemory from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory +from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory from langchain.memory.combined import CombinedMemory from langchain.memory.entity import ( @@ -36,4 +37,5 @@ __all__ = [ "ConversationTokenBufferMemory", "RedisChatMessageHistory", "DynamoDBChatMessageHistory", + "PostgresChatMessageHistory", ] diff --git a/langchain/memory/chat_message_histories/__init__.py b/langchain/memory/chat_message_histories/__init__.py index 199844ef..ee8a6222 100644 --- a/langchain/memory/chat_message_histories/__init__.py +++ b/langchain/memory/chat_message_histories/__init__.py @@ -1,7 +1,9 @@ from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory +from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory __all__ = [ "DynamoDBChatMessageHistory", "RedisChatMessageHistory", + "PostgresChatMessageHistory", ] diff --git a/langchain/memory/chat_message_histories/postgres.py b/langchain/memory/chat_message_histories/postgres.py new file mode 100644 index 00000000..ddca8444 --- /dev/null +++ b/langchain/memory/chat_message_histories/postgres.py @@ -0,0 +1,86 @@ +import json +import logging +from typing import List + +from langchain.schema import ( + AIMessage, + BaseChatMessageHistory, + BaseMessage, + HumanMessage, + _message_to_dict, + messages_from_dict, +) + +logger = logging.getLogger(__name__) + +DEFAULT_CONNECTION_STRING = "postgresql://postgres:mypassword@localhost/chat_history" + + +class PostgresChatMessageHistory(BaseChatMessageHistory): + def __init__( + self, + session_id: str, + connection_string: str = DEFAULT_CONNECTION_STRING, + table_name: str = "message_store", + ): + import psycopg + from psycopg.rows import dict_row + + try: + self.connection = psycopg.connect(connection_string) + self.cursor = self.connection.cursor(row_factory=dict_row) + except psycopg.OperationalError as error: + logger.error(error) + + self.session_id = session_id + self.table_name = table_name + + self._create_table_if_not_exists() + + def _create_table_if_not_exists(self) -> None: + create_table_query = f"""CREATE TABLE IF NOT EXISTS {self.table_name} ( + id SERIAL PRIMARY KEY, + session_id TEXT NOT NULL, + message JSONB NOT NULL + );""" + self.cursor.execute(create_table_query) + self.connection.commit() + + @property + def messages(self) -> List[BaseMessage]: # type: ignore + """Retrieve the messages from PostgreSQL""" + query = f"SELECT message FROM {self.table_name} WHERE session_id = %s;" + self.cursor.execute(query, (self.session_id,)) + items = [record["message"] for record in self.cursor.fetchall()] + messages = messages_from_dict(items) + return messages + + def add_user_message(self, message: str) -> None: + self.append(HumanMessage(content=message)) + + def add_ai_message(self, message: str) -> None: + self.append(AIMessage(content=message)) + + def append(self, message: BaseMessage) -> None: + """Append the message to the record in PostgreSQL""" + from psycopg import sql + + query = sql.SQL("INSERT INTO {} (session_id, message) VALUES (%s, %s);").format( + sql.Identifier(self.table_name) + ) + self.cursor.execute( + query, (self.session_id, json.dumps(_message_to_dict(message))) + ) + self.connection.commit() + + def clear(self) -> None: + """Clear session memory from PostgreSQL""" + query = f"DELETE FROM {self.table_name} WHERE session_id = %s;" + self.cursor.execute(query, (self.session_id,)) + self.connection.commit() + + def __del__(self) -> None: + if self.cursor: + self.cursor.close() + if self.connection: + self.connection.close()