mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
core[patch]: extract input variables for path
and detail
keys in order to format an ImagePromptTemplate
(#22613)
- Description: Add support for `path` and `detail` keys in `ImagePromptTemplate`. Previously, only variables associated with the `url` key were considered. This PR allows for the inclusion of a local image path and a detail parameter as input to the format method. - Issues: - fixes #20820 - related to #22024 - Dependencies: None - Twitter handle: @DeschampsTho5 --------- Co-authored-by: tdeschamps <tdeschamps@kameleoon.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com> Co-authored-by: Eugene Yurtsev <eugene@langchain.dev>
This commit is contained in:
parent
a4798802ef
commit
39b19cf764
@ -473,6 +473,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
)
|
||||
elif isinstance(tmpl, dict) and "image_url" in tmpl:
|
||||
img_template = cast(_ImageTemplateParam, tmpl)["image_url"]
|
||||
input_variables = []
|
||||
if isinstance(img_template, str):
|
||||
vars = get_template_variables(img_template, "f-string")
|
||||
if vars:
|
||||
@ -483,20 +484,19 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
f"\nFrom: {tmpl}"
|
||||
)
|
||||
input_variables = [vars[0]]
|
||||
else:
|
||||
input_variables = None
|
||||
img_template = {"url": img_template}
|
||||
img_template_obj = ImagePromptTemplate(
|
||||
input_variables=input_variables, template=img_template
|
||||
)
|
||||
elif isinstance(img_template, dict):
|
||||
img_template = dict(img_template)
|
||||
if "url" in img_template:
|
||||
input_variables = get_template_variables(
|
||||
img_template["url"], "f-string"
|
||||
)
|
||||
else:
|
||||
input_variables = None
|
||||
for key in ["url", "path", "detail"]:
|
||||
if key in img_template:
|
||||
input_variables.extend(
|
||||
get_template_variables(
|
||||
img_template[key], "f-string"
|
||||
)
|
||||
)
|
||||
img_template_obj = ImagePromptTemplate(
|
||||
input_variables=input_variables, template=img_template
|
||||
)
|
||||
|
@ -1,3 +1,5 @@
|
||||
import base64
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Union
|
||||
|
||||
@ -559,6 +561,7 @@ async def test_chat_tmpl_from_messages_multipart_text_with_template() -> None:
|
||||
|
||||
|
||||
async def test_chat_tmpl_from_messages_multipart_image() -> None:
|
||||
"""Test multipart image URL formatting."""
|
||||
base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
|
||||
other_base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
@ -641,6 +644,65 @@ async def test_chat_tmpl_from_messages_multipart_image() -> None:
|
||||
assert messages == expected
|
||||
|
||||
|
||||
async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None:
|
||||
"""Verify that we can pass `path` for an image as a variable."""
|
||||
in_mem = "base64mem"
|
||||
in_file_data = "base64file01"
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=True, suffix=".jpg") as temp_file:
|
||||
temp_file.write(base64.b64decode(in_file_data))
|
||||
temp_file.flush()
|
||||
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are an AI assistant named {name}."),
|
||||
(
|
||||
"human",
|
||||
[
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": "data:image/jpeg;base64,{in_mem}",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"path": "{file_path}"},
|
||||
},
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
expected = [
|
||||
SystemMessage(content="You are an AI assistant named R2D2."),
|
||||
HumanMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{in_mem}"},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{in_file_data}"},
|
||||
},
|
||||
]
|
||||
),
|
||||
]
|
||||
messages = template.format_messages(
|
||||
name="R2D2",
|
||||
in_mem=in_mem,
|
||||
file_path=temp_file.name,
|
||||
)
|
||||
assert messages == expected
|
||||
|
||||
messages = await template.aformat_messages(
|
||||
name="R2D2",
|
||||
in_mem=in_mem,
|
||||
file_path=temp_file.name,
|
||||
)
|
||||
assert messages == expected
|
||||
|
||||
|
||||
def test_messages_placeholder() -> None:
|
||||
prompt = MessagesPlaceholder("history")
|
||||
with pytest.raises(KeyError):
|
||||
|
Loading…
Reference in New Issue
Block a user