Make BaseStringMessagePromptTemplate.from_template return type generic (#4523)

# Make BaseStringMessagePromptTemplate.from_template return type generic

I use mypy to check type on my code that uses langchain. Currently after
I load a prompt and convert it to a system prompt I have to explicitly
cast it which is quite ugly (and not necessary):
```
prompt_template = load_prompt("prompt.yaml")
system_prompt_template = cast(
    SystemMessagePromptTemplate,
    SystemMessagePromptTemplate.from_template(prompt_template.template),
)
```

With this PR, the code would simply be: 
```
prompt_template = load_prompt("prompt.yaml")
system_prompt_template = SystemMessagePromptTemplate.from_template(prompt_template.template)
```

Given how much langchain uses inheritance, I think this type hinting
could be applied in a bunch more places, e.g. load_prompt also return a
`FewShotPromptTemplate` or a `PromptTemplate` but without typing the
type checkers aren't able to infer that. Let me know if you agree and I
can take a look at implementing that as well.

        @hwchase17 - project lead

        DataLoaders
        - @eyurtsev
pull/4499/head^2
Jonas Nelle 1 year ago committed by GitHub
parent 446b60d803
commit 97e7dc1502
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, List, Sequence, Tuple, Type, Union
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
from pydantic import BaseModel, Field
@ -58,12 +58,19 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
return [self.variable_name]
MessagePromptTemplateT = TypeVar(
"MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
)
class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
prompt: StringPromptTemplate
additional_kwargs: dict = Field(default_factory=dict)
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> BaseMessagePromptTemplate:
def from_template(
cls: Type[MessagePromptTemplateT], template: str, **kwargs: Any
) -> MessagePromptTemplateT:
prompt = PromptTemplate.from_template(template)
return cls(prompt=prompt, **kwargs)

Loading…
Cancel
Save