Add concurrency support for run_on_dataset (#8841)

Long-term, it would be better to use the lower-level batch() method(s), but
that may take me a bit longer to clean up. This unblocks in the meantime,
though it may fail when the evaluated chain raises a
`NotImplementedError` for a corresponding async method.
pull/8893/head
William FH 1 year ago committed by GitHub
parent fc2f450f2d
commit 91be7eee66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1278,6 +1278,27 @@ async def arun_on_dataset(
}
def _handle_coroutine(coro: Coroutine) -> Any:
    """Run a coroutine to completion from a synchronous context.

    Args:
        coro: The coroutine object to execute.

    Returns:
        The value the coroutine returns (never an un-awaited task).

    Raises:
        Exception: Whatever exception the coroutine itself raises is
            propagated to the caller.
    """
    try:
        # get_running_loop() raises RuntimeError when this thread has no
        # running loop; unlike get_event_loop() it never creates one.
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)

    # A loop is already running in this thread (e.g. a notebook kernel).
    # asyncio.run() would raise here, and loop.create_task() would hand back
    # a pending Task instead of the result, so drive the coroutine on a
    # separate thread with its own event loop and block until it finishes.
    # NOTE(review): the calling thread's loop is blocked while waiting, and
    # the coroutine must not depend on objects bound to that loop.
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(asyncio.run, coro).result()
def run_on_dataset(
client: Client,
dataset_name: str,
@ -1285,6 +1306,7 @@ def run_on_dataset(
*,
evaluation: Optional[RunEvalConfig] = None,
num_repetitions: int = 1,
concurrency_level: int = 5,
project_name: Optional[str] = None,
verbose: bool = False,
tags: Optional[List[str]] = None,
@ -1303,6 +1325,7 @@ def run_on_dataset(
independent calls on each example without carrying over state.
evaluation: Configuration for evaluators to run on the
results of the chain
concurrency_level: The number of async tasks to run concurrently.
num_repetitions: Number of times to run the model on each example.
This is useful when testing success rates or generating confidence
intervals.
@ -1403,18 +1426,35 @@ def run_on_dataset(
llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
client, dataset_name, llm_or_chain_factory, project_name
)
results = _run_on_examples(
client,
examples,
llm_or_chain_factory,
num_repetitions=num_repetitions,
project_name=project_name,
verbose=verbose,
tags=tags,
evaluation=evaluation,
input_mapper=input_mapper,
data_type=dataset.data_type,
)
if concurrency_level in (0, 1):
results = _run_on_examples(
client,
examples,
llm_or_chain_factory,
num_repetitions=num_repetitions,
project_name=project_name,
verbose=verbose,
tags=tags,
evaluation=evaluation,
input_mapper=input_mapper,
data_type=dataset.data_type,
)
else:
# TODO: Use runnables and the batch method
coro = _arun_on_examples(
client,
examples,
llm_or_chain_factory,
concurrency_level=concurrency_level,
num_repetitions=num_repetitions,
project_name=project_name,
verbose=verbose,
tags=tags,
evaluation=evaluation,
input_mapper=input_mapper,
data_type=dataset.data_type,
)
results = _handle_coroutine(coro)
return {
"project_name": project_name,
"results": results,

Loading…
Cancel
Save