Fix loading of ImagePromptTemplate (#16868)

We didn't override the namespace of the ImagePromptTemplate, so it is
listed as being in langchain.schema

This updates the mapping to let the loader deserialize.

Alternatively, we could make a slight breaking change and update the
namespace of the ImagePromptTemplate since we haven't broadly
publicized/documented it yet..
This commit is contained in:
William FH 2024-02-01 17:54:04 -08:00 committed by GitHub
parent 6fc2835255
commit 131c043864
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 141 additions and 2 deletions

View File

@ -115,6 +115,12 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
"chat",
"SystemMessagePromptTemplate",
),
("langchain", "prompts", "image", "ImagePromptTemplate"): (
"langchain_core",
"prompts",
"image",
"ImagePromptTemplate",
),
("langchain", "schema", "agent", "AgentActionMessageLog"): (
"langchain_core",
"agents",
@ -510,6 +516,12 @@ _OG_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
"system",
"SystemMessage",
),
("langchain", "schema", "prompt_template", "ImagePromptTemplate"): (
"langchain_core",
"prompts",
"image",
"ImagePromptTemplate",
),
}
# Needed for backwards compatibility for a few versions where we serialized

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, List
from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
@ -30,6 +30,11 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
"""Return the prompt type key."""
return "image-prompt"
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "image"]
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
return ImagePromptValue(image_url=self.format(**kwargs))

View File

@ -0,0 +1,109 @@
import json
from langchain_core.load import dump, loads
from langchain_core.prompts import ChatPromptTemplate
def test_image_prompt_template_deserializable() -> None:
"""Test that the image prompt template is serializable."""
loads(
dump.dumps(
ChatPromptTemplate.from_messages(
[("system", [{"type": "image", "image_url": "{img}"}])]
)
)
)
def test_image_prompt_template_deserializable_old() -> None:
"""Test that the image prompt template is serializable."""
loads(
json.dumps(
{
"lc": 1,
"type": "constructor",
"id": ["langchain", "prompts", "chat", "ChatPromptTemplate"],
"kwargs": {
"messages": [
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"chat",
"SystemMessagePromptTemplate",
],
"kwargs": {
"prompt": [
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate",
],
"kwargs": {
"template": "Foo",
"input_variables": [],
"template_format": "f-string",
"partial_variables": {},
},
}
]
},
},
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"chat",
"HumanMessagePromptTemplate",
],
"kwargs": {
"prompt": [
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"image",
"ImagePromptTemplate",
],
"kwargs": {
"template": {
"url": "data:image/png;base64,{img}"
},
"input_variables": ["img"],
},
},
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate",
],
"kwargs": {
"template": "{input}",
"input_variables": ["input"],
"template_format": "f-string",
"partial_variables": {},
},
},
]
},
},
],
"input_variables": ["img", "input"],
},
}
)
)

View File

@ -40,8 +40,21 @@ def import_all_modules(package_name: str) -> dict:
def test_serializable_mapping() -> None:
# This should have had a different namespace, as it was never
# exported from the langchain module, but we keep for whoever has
# already serialized it.
to_skip = {
("langchain", "prompts", "image", "ImagePromptTemplate"): (
"langchain_core",
"prompts",
"image",
"ImagePromptTemplate",
),
}
serializable_modules = import_all_modules("langchain")
missing = set(SERIALIZABLE_MAPPING).difference(serializable_modules)
missing = set(SERIALIZABLE_MAPPING).difference(
set(serializable_modules).union(to_skip)
)
assert missing == set()
extra = set(serializable_modules).difference(SERIALIZABLE_MAPPING)
assert extra == set()