Implemented appending arbitrary messages (#5293)

# Implemented appending arbitrary messages to the base chat message
history, the in-memory and cosmos ones.

<!--
Thank you for contributing to LangChain! Your PR will appear in our next
release under the title you set. Please make sure it highlights your
valuable contribution.

Replace this with a description of the change, the issue it fixes (if
applicable), and relevant context. List any dependencies required for
this change.

After you're done, someone will review your PR. They may suggest
improvements. If no one reviews your PR within a few days, feel free to
@-mention the same people again, as notifications can get lost.
-->

As discussed this is the alternative way instead of #4480, with a
add_message method added that takes a BaseMessage as input, so that the
user can control what is in the base message like kwargs.

<!-- Remove if not applicable -->

Fixes # (issue)

## Before submitting

<!-- If you're adding a new integration, include an integration test and
an example notebook showing its use! -->

## Who can review?

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:

@hwchase17

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
searx_updates
Eduard van Valkenburg 1 year ago committed by GitHub
parent d6fb25c439
commit ccb6238de1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,10 +3,8 @@ import logging
from typing import List from typing import List
from langchain.schema import ( from langchain.schema import (
AIMessage,
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage,
_message_to_dict, _message_to_dict,
messages_from_dict, messages_from_dict,
) )
@ -143,13 +141,7 @@ class CassandraChatMessageHistory(BaseChatMessageHistory):
return messages return messages
def add_user_message(self, message: str) -> None: def add_message(self, message: BaseMessage) -> None:
self.append(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))
def append(self, message: BaseMessage) -> None:
"""Append the message to the record in Cassandra""" """Append the message to the record in Cassandra"""
import uuid import uuid

@ -6,10 +6,8 @@ from types import TracebackType
from typing import TYPE_CHECKING, Any, List, Optional, Type from typing import TYPE_CHECKING, Any, List, Optional, Type
from langchain.schema import ( from langchain.schema import (
AIMessage,
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage,
messages_from_dict, messages_from_dict,
messages_to_dict, messages_to_dict,
) )
@ -145,18 +143,13 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
if "messages" in item and len(item["messages"]) > 0: if "messages" in item and len(item["messages"]) > 0:
self.messages = messages_from_dict(item["messages"]) self.messages = messages_from_dict(item["messages"])
def add_user_message(self, message: str) -> None: def add_message(self, message: BaseMessage) -> None:
"""Add a user message to the memory.""" """Add a self-created message to the store"""
self.upsert_messages(HumanMessage(content=message)) self.messages.append(message)
self.upsert_messages()
def add_ai_message(self, message: str) -> None:
"""Add a AI message to the memory."""
self.upsert_messages(AIMessage(content=message))
def upsert_messages(self, new_message: Optional[BaseMessage] = None) -> None: def upsert_messages(self) -> None:
"""Update the cosmosdb item.""" """Update the cosmosdb item."""
if new_message:
self.messages.append(new_message)
if not self._container: if not self._container:
raise ValueError("Container not initialized") raise ValueError("Container not initialized")
self._container.upsert_item( self._container.upsert_item(

@ -2,10 +2,8 @@ import logging
from typing import List from typing import List
from langchain.schema import ( from langchain.schema import (
AIMessage,
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage,
_message_to_dict, _message_to_dict,
messages_from_dict, messages_from_dict,
messages_to_dict, messages_to_dict,
@ -53,13 +51,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
messages = messages_from_dict(items) messages = messages_from_dict(items)
return messages return messages
def add_user_message(self, message: str) -> None: def add_message(self, message: BaseMessage) -> None:
self.append(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))
def append(self, message: BaseMessage) -> None:
"""Append the message to the record in DynamoDB""" """Append the message to the record in DynamoDB"""
from botocore.exceptions import ClientError from botocore.exceptions import ClientError

@ -4,10 +4,8 @@ from pathlib import Path
from typing import List from typing import List
from langchain.schema import ( from langchain.schema import (
AIMessage,
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage,
messages_from_dict, messages_from_dict,
messages_to_dict, messages_to_dict,
) )
@ -36,13 +34,7 @@ class FileChatMessageHistory(BaseChatMessageHistory):
messages = messages_from_dict(items) messages = messages_from_dict(items)
return messages return messages
def add_user_message(self, message: str) -> None: def add_message(self, message: BaseMessage) -> None:
self.append(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))
def append(self, message: BaseMessage) -> None:
"""Append the message to the record in the local file""" """Append the message to the record in the local file"""
messages = messages_to_dict(self.messages) messages = messages_to_dict(self.messages)
messages.append(messages_to_dict([message])[0]) messages.append(messages_to_dict([message])[0])

@ -5,10 +5,8 @@ import logging
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, Optional
from langchain.schema import ( from langchain.schema import (
AIMessage,
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage,
messages_from_dict, messages_from_dict,
messages_to_dict, messages_to_dict,
) )
@ -81,18 +79,12 @@ class FirestoreChatMessageHistory(BaseChatMessageHistory):
if "messages" in data and len(data["messages"]) > 0: if "messages" in data and len(data["messages"]) > 0:
self.messages = messages_from_dict(data["messages"]) self.messages = messages_from_dict(data["messages"])
def add_user_message(self, message: str) -> None: def add_message(self, message: BaseMessage) -> None:
"""Add a user message to the memory.""" self.messages.append(message)
self.upsert_messages(HumanMessage(content=message)) self.upsert_messages()
def add_ai_message(self, message: str) -> None:
"""Add a AI message to the memory."""
self.upsert_messages(AIMessage(content=message))
def upsert_messages(self, new_message: Optional[BaseMessage] = None) -> None: def upsert_messages(self, new_message: Optional[BaseMessage] = None) -> None:
"""Update the Firestore document.""" """Update the Firestore document."""
if new_message:
self.messages.append(new_message)
if not self._document: if not self._document:
raise ValueError("Document not initialized") raise ValueError("Document not initialized")
self._document.set( self._document.set(

@ -3,21 +3,17 @@ from typing import List
from pydantic import BaseModel from pydantic import BaseModel
from langchain.schema import ( from langchain.schema import (
AIMessage,
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage,
) )
class ChatMessageHistory(BaseChatMessageHistory, BaseModel): class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
messages: List[BaseMessage] = [] messages: List[BaseMessage] = []
def add_user_message(self, message: str) -> None: def add_message(self, message: BaseMessage) -> None:
self.messages.append(HumanMessage(content=message)) """Add a self-created message to the store"""
self.messages.append(message)
def add_ai_message(self, message: str) -> None:
self.messages.append(AIMessage(content=message))
def clear(self) -> None: def clear(self) -> None:
self.messages = [] self.messages = []

@ -5,10 +5,8 @@ from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
from langchain.schema import ( from langchain.schema import (
AIMessage,
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage,
_message_to_dict, _message_to_dict,
messages_from_dict, messages_from_dict,
) )
@ -143,23 +141,7 @@ class MomentoChatMessageHistory(BaseChatMessageHistory):
else: else:
raise Exception(f"Unexpected response: {fetch_response}") raise Exception(f"Unexpected response: {fetch_response}")
def add_user_message(self, message: str) -> None: def add_message(self, message: BaseMessage) -> None:
"""Store a user message in the cache.
Args:
message (str): The message to store.
"""
self.__add_message(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
"""Store an AI message in the cache.
Args:
message (str): The message to store.
"""
self.__add_message(AIMessage(content=message))
def __add_message(self, message: BaseMessage) -> None:
"""Store a message in the cache. """Store a message in the cache.
Args: Args:

@ -3,10 +3,8 @@ import logging
from typing import List from typing import List
from langchain.schema import ( from langchain.schema import (
AIMessage,
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage,
_message_to_dict, _message_to_dict,
messages_from_dict, messages_from_dict,
) )
@ -68,13 +66,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
messages = messages_from_dict(items) messages = messages_from_dict(items)
return messages return messages
def add_user_message(self, message: str) -> None: def add_message(self, message: BaseMessage) -> None:
self.append(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))
def append(self, message: BaseMessage) -> None:
"""Append the message to the record in MongoDB""" """Append the message to the record in MongoDB"""
from pymongo import errors from pymongo import errors

@ -3,10 +3,8 @@ import logging
from typing import List from typing import List
from langchain.schema import ( from langchain.schema import (
AIMessage,
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage,
_message_to_dict, _message_to_dict,
messages_from_dict, messages_from_dict,
) )
@ -55,13 +53,7 @@ class PostgresChatMessageHistory(BaseChatMessageHistory):
messages = messages_from_dict(items) messages = messages_from_dict(items)
return messages return messages
def add_user_message(self, message: str) -> None: def add_message(self, message: BaseMessage) -> None:
self.append(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))
def append(self, message: BaseMessage) -> None:
"""Append the message to the record in PostgreSQL""" """Append the message to the record in PostgreSQL"""
from psycopg import sql from psycopg import sql

@ -3,10 +3,8 @@ import logging
from typing import List, Optional from typing import List, Optional
from langchain.schema import ( from langchain.schema import (
AIMessage,
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage,
_message_to_dict, _message_to_dict,
messages_from_dict, messages_from_dict,
) )
@ -52,13 +50,7 @@ class RedisChatMessageHistory(BaseChatMessageHistory):
messages = messages_from_dict(items) messages = messages_from_dict(items)
return messages return messages
def add_user_message(self, message: str) -> None: def add_message(self, message: BaseMessage) -> None:
self.append(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))
def append(self, message: BaseMessage) -> None:
"""Append the message to the record in Redis""" """Append the message to the record in Redis"""
self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message))) self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message)))
if self.ttl: if self.ttl:

@ -7,10 +7,8 @@ from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from langchain.schema import ( from langchain.schema import (
AIMessage,
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage,
_message_to_dict, _message_to_dict,
messages_from_dict, messages_from_dict,
) )
@ -61,13 +59,7 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
messages = messages_from_dict(items) messages = messages_from_dict(items)
return messages return messages
def add_user_message(self, message: str) -> None: def add_message(self, message: BaseMessage) -> None:
self.append(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))
def append(self, message: BaseMessage) -> None:
"""Append the message to the record in db""" """Append the message to the record in db"""
with self.Session() as session: with self.Session() as session:
jsonstr = json.dumps(_message_to_dict(message)) jsonstr = json.dumps(_message_to_dict(message))

@ -116,13 +116,7 @@ class ZepChatMessageHistory(BaseChatMessageHistory):
return None return None
return zep_memory return zep_memory
def add_user_message(self, message: str) -> None: def add_message(self, message: BaseMessage) -> None:
self.append(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))
def append(self, message: BaseMessage) -> None:
"""Append the message to the Zep memory history""" """Append the message to the Zep memory history"""
from zep_python import Memory, Message from zep_python import Memory, Message

@ -234,18 +234,11 @@ class BaseChatMessageHistory(ABC):
messages = json.loads(f.read()) messages = json.loads(f.read())
return messages_from_dict(messages) return messages_from_dict(messages)
def add_user_message(self, message: str): def add_message(self, message: BaseMessage) -> None:
message_ = HumanMessage(content=message) messages = self.messages.append(_message_to_dict(message))
messages = self.messages.append(_message_to_dict(_message))
with open(os.path.join(storage_path, session_id), 'w') as f: with open(os.path.join(storage_path, session_id), 'w') as f:
json.dump(f, messages) json.dump(f, messages)
def add_ai_message(self, message: str):
message_ = AIMessage(content=message)
messages = self.messages.append(_message_to_dict(_message))
with open(os.path.join(storage_path, session_id), 'w') as f:
json.dump(f, messages)
def clear(self): def clear(self):
with open(os.path.join(storage_path, session_id), 'w') as f: with open(os.path.join(storage_path, session_id), 'w') as f:
f.write("[]") f.write("[]")
@ -253,13 +246,17 @@ class BaseChatMessageHistory(ABC):
messages: List[BaseMessage] messages: List[BaseMessage]
@abstractmethod
def add_user_message(self, message: str) -> None: def add_user_message(self, message: str) -> None:
"""Add a user message to the store""" """Add a user message to the store"""
self.add_message(HumanMessage(content=message))
@abstractmethod
def add_ai_message(self, message: str) -> None: def add_ai_message(self, message: str) -> None:
"""Add an AI message to the store""" """Add an AI message to the store"""
self.add_message(AIMessage(content=message))
def add_message(self, message: BaseMessage) -> None:
"""Add a self-created message to the store"""
raise NotImplementedError
@abstractmethod @abstractmethod
def clear(self) -> None: def clear(self) -> None:

@ -60,7 +60,7 @@ def test_add_ai_message(mocker: MockerFixture, zep_chat: ZepChatMessageHistory)
@pytest.mark.requires("zep_python") @pytest.mark.requires("zep_python")
def test_append(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None: def test_append(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
zep_chat.append(AIMessage(content="test message")) zep_chat.add_message(AIMessage(content="test message"))
zep_chat.zep_client.add_memory.assert_called_once() # type: ignore zep_chat.zep_client.add_memory.assert_called_once() # type: ignore

Loading…
Cancel
Save