[core] prompt changes (#15186)

change it to pass all variables through all the way in invoke
pull/15196/head
Harrison Chase 6 months ago committed by GitHub
parent ccf9c8e0be
commit 4ad77f777e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -73,20 +73,19 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
)
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
try:
input_dict = {key: inner_input[key] for key in self.input_variables}
except TypeError as e:
if not isinstance(inner_input, dict):
raise TypeError(
f"Expected mapping type as input to {self.__class__.__name__}. "
f"Received {type(inner_input)}."
) from e
except KeyError as e:
)
missing = set(self.input_variables).difference(inner_input)
if missing:
raise KeyError(
f"Input to {self.__class__.__name__} is missing variable {e}. "
f"Input to {self.__class__.__name__} is missing variables {missing}. "
f" Expected: {self.input_variables}"
f" Received: {list(inner_input.keys())}"
) from e
return self.format_prompt(**input_dict)
)
return self.format_prompt(**inner_input)
def invoke(
self, input: Dict, config: Optional[RunnableConfig] = None
@ -100,7 +99,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
@abstractmethod
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
"""Create Prompt Value."""
@root_validator()
def validate_variable_names(cls, values: Dict) -> Dict:

@ -133,7 +133,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
Returns:
List of input variable names.
"""
return [self.variable_name]
return [self.variable_name] if not self.optional else []
MessagePromptTemplateT = TypeVar(
@ -611,12 +611,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
elif isinstance(
message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
):
rel_params = {
k: v
for k, v in kwargs.items()
if k in message_template.input_variables
}
message = message_template.format_messages(**rel_params)
message = message_template.format_messages(**kwargs)
result.extend(message)
else:
raise ValueError(f"Unexpected input: {message_template}")

@ -43,6 +43,8 @@ class LogEntry(TypedDict):
streamed_output_str: List[str]
"""List of LLM tokens streamed by this run, if applicable."""
streamed_output: List[Any]
"""List of output chunks streamed by this run, if available."""
final_output: Optional[Any]
"""Final output of this run.
Only available after the run has finished successfully."""
@ -242,6 +244,7 @@ class LogStreamCallbackHandler(BaseTracer):
tags=run.tags or [],
metadata=(run.extra or {}).get("metadata", {}),
start_time=run.start_time.isoformat(timespec="milliseconds"),
streamed_output=[],
streamed_output_str=[],
final_output=None,
end_time=None,
@ -298,6 +301,13 @@ class LogStreamCallbackHandler(BaseTracer):
"op": "add",
"path": f"/logs/{index}/streamed_output_str/-",
"value": token,
}
},
{
"op": "add",
"path": f"/logs/{index}/streamed_output/-",
"value": chunk.message
if isinstance(chunk, ChatGenerationChunk)
else token,
},
)
)

@ -1,22 +1,11 @@
"""Utilities for formatting strings."""
from string import Formatter
from typing import Any, List, Mapping, Sequence, Union
from typing import Any, List, Mapping, Sequence
class StrictFormatter(Formatter):
"""A subclass of formatter that checks for extra keys."""
def check_unused_args(
self,
used_args: Sequence[Union[int, str]],
args: Sequence,
kwargs: Mapping[str, Any],
) -> None:
"""Check to see if extra parameters are passed."""
extra = set(kwargs).difference(used_args)
if extra:
raise KeyError(extra)
def vformat(
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
) -> str:

@ -96,26 +96,6 @@ def test_prompt_missing_input_variables() -> None:
).input_variables == ["foo"]
def test_prompt_extra_input_variables() -> None:
"""Test error is raised when there are too many input variables."""
template = "This is a {foo} test."
input_variables = ["foo", "bar"]
with pytest.raises(ValueError):
FewShotPromptTemplate(
input_variables=input_variables,
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=input_variables,
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
).input_variables == ["foo"]
def test_few_shot_functionality() -> None:
"""Test that few shot works with examples."""
prefix = "This is a test about {content}."

@ -53,19 +53,6 @@ def test_prompt_empty_input_variable() -> None:
PromptTemplate(input_variables=[""], template="{}", validate_template=True)
def test_prompt_extra_input_variables() -> None:
"""Test error is raised when there are too many input variables."""
template = "This is a {foo} test."
input_variables = ["foo", "bar"]
with pytest.raises(ValueError):
PromptTemplate(
input_variables=input_variables, template=template, validate_template=True
)
assert PromptTemplate(
input_variables=input_variables, template=template
).input_variables == ["foo"]
def test_prompt_wrong_input_variables() -> None:
"""Test error is raised when name of input variable is wrong."""
template = "This is a {foo} test."

@ -2054,6 +2054,7 @@ async def test_prompt_with_llm(
"metadata": {},
"name": "ChatPromptTemplate",
"start_time": "2023-01-01T00:00:00.000",
"streamed_output": [],
"streamed_output_str": [],
"tags": ["seq:step:1"],
"type": "prompt",
@ -2087,6 +2088,7 @@ async def test_prompt_with_llm(
"metadata": {},
"name": "FakeListLLM",
"start_time": "2023-01-01T00:00:00.000",
"streamed_output": [],
"streamed_output_str": [],
"tags": ["seq:step:2"],
"type": "llm",

@ -18,8 +18,9 @@ def test_does_not_allow_args() -> None:
formatter.format(template, "good")
def test_does_not_allow_extra_kwargs() -> None:
"""Test formatting does not allow extra keyword arguments."""
def test_allows_extra_kwargs() -> None:
"""Test formatting allows extra keyword arguments."""
template = "This is a {foo} test."
with pytest.raises(KeyError):
formatter.format(template, foo="good", bar="oops")
output = formatter.format(template, foo="good", bar="oops")
expected_output = "This is a good test."
assert output == expected_output

Loading…
Cancel
Save