Add revision identifier to run_on_dataset (#16167)

Allow specifying a revision identifier for better project versioning
SN 5 months ago committed by GitHub
parent 5d8c147332
commit 7d444724d7

@@ -657,6 +657,7 @@ async def _arun_llm(
tags: Optional[List[str]] = None,
callbacks: Callbacks = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Union[str, BaseMessage]:
"""Asynchronously run the language model.
@@ -682,7 +683,9 @@ async def _arun_llm(
):
return await llm.ainvoke(
prompt_or_messages,
config=RunnableConfig(callbacks=callbacks, tags=tags or []),
config=RunnableConfig(
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
),
)
else:
raise InputFormatError(
@@ -695,12 +698,18 @@ async def _arun_llm(
try:
prompt = _get_prompt(inputs)
llm_output: Union[str, BaseMessage] = await llm.ainvoke(
prompt, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
prompt,
config=RunnableConfig(
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
),
)
except InputFormatError:
messages = _get_messages(inputs)
llm_output = await llm.ainvoke(
messages, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
messages,
config=RunnableConfig(
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
),
)
return llm_output
@@ -712,6 +721,7 @@ async def _arun_chain(
*,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Union[dict, str]:
"""Run a chain asynchronously on inputs."""
inputs_ = inputs if input_mapper is None else input_mapper(inputs)
@@ -723,10 +733,15 @@ async def _arun_chain(
):
val = next(iter(inputs_.values()))
output = await chain.ainvoke(
val, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
val,
config=RunnableConfig(
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
),
)
else:
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
runnable_config = RunnableConfig(
tags=tags or [], callbacks=callbacks, metadata=metadata or {}
)
output = await chain.ainvoke(inputs_, config=runnable_config)
return output
@@ -762,6 +777,7 @@ async def _arun_llm_or_chain(
tags=config["tags"],
callbacks=config["callbacks"],
input_mapper=input_mapper,
metadata=config.get("metadata"),
)
else:
chain = llm_or_chain_factory()
@@ -771,6 +787,7 @@ async def _arun_llm_or_chain(
tags=config["tags"],
callbacks=config["callbacks"],
input_mapper=input_mapper,
metadata=config.get("metadata"),
)
result = output
except Exception as e:
@@ -793,6 +810,7 @@ def _run_llm(
*,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Union[str, BaseMessage]:
"""
Run the language model on the example.
@@ -819,7 +837,9 @@ def _run_llm(
):
llm_output: Union[str, BaseMessage] = llm.invoke(
prompt_or_messages,
config=RunnableConfig(callbacks=callbacks, tags=tags or []),
config=RunnableConfig(
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
),
)
else:
raise InputFormatError(
@@ -831,12 +851,16 @@ def _run_llm(
try:
llm_prompts = _get_prompt(inputs)
llm_output = llm.invoke(
llm_prompts, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
llm_prompts,
config=RunnableConfig(
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
),
)
except InputFormatError:
llm_messages = _get_messages(inputs)
llm_output = llm.invoke(
llm_messages, config=RunnableConfig(callbacks=callbacks)
llm_messages,
config=RunnableConfig(callbacks=callbacks, metadata=metadata or {}),
)
return llm_output
@@ -848,6 +872,7 @@ def _run_chain(
*,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Union[Dict, str]:
"""Run a chain on inputs."""
inputs_ = inputs if input_mapper is None else input_mapper(inputs)
@@ -859,10 +884,15 @@ def _run_chain(
):
val = next(iter(inputs_.values()))
output = chain.invoke(
val, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
val,
config=RunnableConfig(
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
),
)
else:
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
runnable_config = RunnableConfig(
tags=tags or [], callbacks=callbacks, metadata=metadata or {}
)
output = chain.invoke(inputs_, config=runnable_config)
return output
@@ -899,6 +929,7 @@ def _run_llm_or_chain(
config["callbacks"],
tags=config["tags"],
input_mapper=input_mapper,
metadata=config.get("metadata"),
)
else:
chain = llm_or_chain_factory()
@@ -908,6 +939,7 @@ def _run_llm_or_chain(
config["callbacks"],
tags=config["tags"],
input_mapper=input_mapper,
metadata=config.get("metadata"),
)
result = output
except Exception as e:
@@ -1083,8 +1115,13 @@ class _DatasetRunContainer:
input_mapper: Optional[Callable[[Dict], Any]] = None,
concurrency_level: int = 5,
project_metadata: Optional[Dict[str, Any]] = None,
revision_id: Optional[str] = None,
) -> _DatasetRunContainer:
project_name = project_name or name_generation.random_name()
if revision_id:
if not project_metadata:
project_metadata = {}
project_metadata.update({"revision_id": revision_id})
wrapped_model, project, dataset, examples = _prepare_eval_run(
client,
dataset_name,
@@ -1121,6 +1158,7 @@ class _DatasetRunContainer:
],
tags=tags,
max_concurrency=concurrency_level,
metadata={"revision_id": revision_id} if revision_id else {},
)
for example in examples
]
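For context, the hunk above means each dataset example is invoked with a RunnableConfig that carries the revision identifier in its metadata. A rough sketch of the shape of the resulting per-example config (values are illustrative, not from this commit):

from langchain_core.runnables import RunnableConfig

# Illustrative only: what each per-example config roughly looks like after
# this change, assuming run_on_dataset(..., revision_id="prompt-v2",
# tags=["baseline"], concurrency_level=5) was called. In practice the
# callbacks list holds the tracer and evaluator callbacks that
# _DatasetRunContainer builds.
config = RunnableConfig(
    callbacks=[],
    tags=["baseline"],
    max_concurrency=5,
    metadata={"revision_id": "prompt-v2"},
)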
@@ -1183,6 +1221,7 @@ async def arun_on_dataset(
project_metadata: Optional[Dict[str, Any]] = None,
verbose: bool = False,
tags: Optional[List[str]] = None,
revision_id: Optional[str] = None,
**kwargs: Any,
) -> Dict[str, Any]:
input_mapper = kwargs.pop("input_mapper", None)
@@ -1208,6 +1247,7 @@ async def arun_on_dataset(
input_mapper,
concurrency_level,
project_metadata=project_metadata,
revision_id=revision_id,
)
batch_results = await runnable_utils.gather_with_concurrency(
container.configs[0].get("max_concurrency"),
@@ -1235,6 +1275,7 @@ def run_on_dataset(
project_metadata: Optional[Dict[str, Any]] = None,
verbose: bool = False,
tags: Optional[List[str]] = None,
revision_id: Optional[str] = None,
**kwargs: Any,
) -> Dict[str, Any]:
input_mapper = kwargs.pop("input_mapper", None)
@@ -1260,6 +1301,7 @@ def run_on_dataset(
input_mapper,
concurrency_level,
project_metadata=project_metadata,
revision_id=revision_id,
)
if concurrency_level == 0:
batch_results = [
@@ -1309,6 +1351,8 @@ Args:
log feedback and run traces.
verbose: Whether to print progress.
tags: Tags to add to each run in the project.
revision_id: Optional revision identifier to assign this test run to
track the performance of different versions of your system.
Returns:
A dictionary containing the run's project name and the resulting model outputs.
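Below is a minimal usage sketch of the new parameter (not part of this commit; the dataset name and revision label are placeholders, and LangSmith credentials are assumed to be configured in the environment):

from langchain.smith import run_on_dataset
from langchain_core.runnables import RunnableLambda
from langsmith import Client

client = Client()

# Factory so every dataset example gets a fresh runnable instance.
def chain_factory():
    return RunnableLambda(lambda inputs: {"output": inputs})

results = run_on_dataset(
    client=client,
    dataset_name="my-eval-dataset",   # hypothetical dataset name
    llm_or_chain_factory=chain_factory,
    # New in this change: recorded in the project metadata and in each
    # run's metadata so different system versions can be compared.
    revision_id="prompt-v2",          # hypothetical revision label
)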
