diff --git a/langchain/prompts/chat.py b/langchain/prompts/chat.py index 251b8e9274..9096b08d34 100644 --- a/langchain/prompts/chat.py +++ b/langchain/prompts/chat.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Callable, List, Sequence, Tuple, Type, Union +from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union from pydantic import BaseModel, Field @@ -58,12 +58,19 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): return [self.variable_name] +MessagePromptTemplateT = TypeVar( + "MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate" +) + + class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): prompt: StringPromptTemplate additional_kwargs: dict = Field(default_factory=dict) @classmethod - def from_template(cls, template: str, **kwargs: Any) -> BaseMessagePromptTemplate: + def from_template( + cls: Type[MessagePromptTemplateT], template: str, **kwargs: Any + ) -> MessagePromptTemplateT: prompt = PromptTemplate.from_template(template) return cls(prompt=prompt, **kwargs)