From 97e7dc15023f256090b1fe27a30c54d5e489701b Mon Sep 17 00:00:00 2001 From: Jonas Nelle Date: Thu, 11 May 2023 13:24:50 -0400 Subject: [PATCH] Make BaseStringMessagePromptTemplate.from_template return type generic (#4523) # Make BaseStringMessagePromptTemplate.from_template return type generic I use mypy to check type on my code that uses langchain. Currently after I load a prompt and convert it to a system prompt I have to explicitly cast it which is quite ugly (and not necessary): ``` prompt_template = load_prompt("prompt.yaml") system_prompt_template = cast( SystemMessagePromptTemplate, SystemMessagePromptTemplate.from_template(prompt_template.template), ) ``` With this PR, the code would simply be: ``` prompt_template = load_prompt("prompt.yaml") system_prompt_template = SystemMessagePromptTemplate.from_template(prompt_template.template) ``` Given how much langchain uses inheritance, I think this type hinting could be applied in a bunch more places, e.g. load_prompt also return a `FewShotPromptTemplate` or a `PromptTemplate` but without typing the type checkers aren't able to infer that. Let me know if you agree and I can take a look at implementing that as well. @hwchase17 - project lead DataLoaders - @eyurtsev --- langchain/prompts/chat.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/langchain/prompts/chat.py b/langchain/prompts/chat.py index 251b8e9274..9096b08d34 100644 --- a/langchain/prompts/chat.py +++ b/langchain/prompts/chat.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Callable, List, Sequence, Tuple, Type, Union +from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union from pydantic import BaseModel, Field @@ -58,12 +58,19 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): return [self.variable_name] +MessagePromptTemplateT = TypeVar( + "MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate" +) + + class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): prompt: StringPromptTemplate additional_kwargs: dict = Field(default_factory=dict) @classmethod - def from_template(cls, template: str, **kwargs: Any) -> BaseMessagePromptTemplate: + def from_template( + cls: Type[MessagePromptTemplateT], template: str, **kwargs: Any + ) -> MessagePromptTemplateT: prompt = PromptTemplate.from_template(template) return cls(prompt=prompt, **kwargs)