diff --git a/libs/langchain/langchain/schema/messages.py b/libs/langchain/langchain/schema/messages.py index c926bcc8f6..1722602be3 100644 --- a/libs/langchain/langchain/schema/messages.py +++ b/libs/langchain/langchain/schema/messages.py @@ -1,12 +1,15 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Dict, List, Sequence +from typing import TYPE_CHECKING, Any, Dict, List, Sequence from pydantic import Field from langchain.load.serializable import Serializable +if TYPE_CHECKING: + from langchain.prompts.chat import ChatPromptTemplate + def get_buffer_string( messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI" @@ -77,6 +80,12 @@ class BaseMessage(Serializable): """Whether this class is LangChain serializable.""" return True + def __add__(self, other: Any) -> ChatPromptTemplate: + from langchain.prompts.chat import ChatPromptTemplate + + prompt = ChatPromptTemplate(messages=[self]) + return prompt + other + class BaseMessageChunk(BaseMessage): def _merge_kwargs_dict( @@ -102,7 +111,7 @@ class BaseMessageChunk(BaseMessage): ) return merged - def __add__(self, other: Any) -> BaseMessageChunk: + def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore if isinstance(other, BaseMessageChunk): # If both are (subclasses of) BaseMessageChunk, # concat into a single BaseMessageChunk