Add concurrency support for run_on_dataset (#8841)

Long-term, it would be better to use the lower-level batch() method(s), but it may take me a bit longer to clean up. This unblocks things in the meantime, though it may fail when the evaluated chain raises a `NotImplementedError` for the corresponding async method.
William FH 2023-08-07 09:24:48 -07:00 committed by GitHub
parent fc2f450f2d
commit 91be7eee66
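
For context on the `NotImplementedError` caveat in the commit message: in the `Chain` base class of this era, the default async `_acall` raises `NotImplementedError`, so a chain that only overrides the synchronous `_call` cannot run on the new concurrent path. A minimal sketch of such a chain (`EchoChain` is hypothetical, not part of this commit):

```python
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain


class EchoChain(Chain):
    """Hypothetical chain that implements only the synchronous code path."""

    @property
    def input_keys(self) -> List[str]:
        return ["input"]

    @property
    def output_keys(self) -> List[str]:
        return ["output"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        return {"output": inputs["input"]}

    # No _acall override: the inherited default raises NotImplementedError,
    # which is what surfaces when run_on_dataset takes the async path.
```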


@@ -1278,6 +1278,27 @@ async def arun_on_dataset(
     }


+def _handle_coroutine(coro: Coroutine) -> Any:
+    """
+    Handles a coroutine from a sync context.
+
+    Args:
+        coro (asyncio.coroutine): The coroutine to be handled.
+
+    Returns:
+        any: The result of the executed coroutine.
+    """
+    # Check if there's a running event loop
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError:  # No event loop
+        return asyncio.run(coro)
+    if loop.is_running():
+        return loop.create_task(coro)
+    else:
+        return asyncio.run(coro)
+
+
 def run_on_dataset(
     client: Client,
     dataset_name: str,
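
For reference, here is the same sync-context dispatch pattern as a self-contained sketch (`handle_coroutine` and `_double` are illustrative names, not part of the commit):

```python
import asyncio
from typing import Any, Coroutine


async def _double(x: int) -> int:
    # Illustrative coroutine standing in for _arun_on_examples.
    await asyncio.sleep(0)
    return x * 2


def handle_coroutine(coro: Coroutine) -> Any:
    # Same dispatch as _handle_coroutine above: run the coroutine to
    # completion when no loop is running; otherwise schedule it as a
    # Task on the active loop.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:  # No event loop in this thread
        return asyncio.run(coro)
    if loop.is_running():
        return loop.create_task(coro)
    return asyncio.run(coro)


if __name__ == "__main__":
    print(handle_coroutine(_double(21)))  # -> 42; no loop is running here
```

Note that on the already-running-loop branch, both the patch and this sketch hand back an un-awaited `asyncio.Task` rather than the results; a caller inside an event loop has to await it.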
@@ -1285,6 +1306,7 @@ def run_on_dataset(
     *,
     evaluation: Optional[RunEvalConfig] = None,
     num_repetitions: int = 1,
+    concurrency_level: int = 5,
     project_name: Optional[str] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
@@ -1303,6 +1325,7 @@ def run_on_dataset(
             independent calls on each example without carrying over state.
         evaluation: Configuration for evaluators to run on the
             results of the chain
+        concurrency_level: The number of async tasks to run concurrently.
         num_repetitions: Number of times to run the model on each example.
             This is useful when testing success rates or generating confidence
             intervals.
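
A usage sketch for the new parameter, assuming the `langchain.smith` import path of this era; the dataset name is a placeholder:

```python
from langchain.chat_models import ChatOpenAI
from langchain.smith import RunEvalConfig, run_on_dataset
from langsmith import Client

client = Client()

# "my-eval-dataset" is a placeholder; use an existing LangSmith dataset.
results = run_on_dataset(
    client,
    dataset_name="my-eval-dataset",
    llm_or_chain_factory=ChatOpenAI(temperature=0),
    evaluation=RunEvalConfig(evaluators=["qa"]),
    concurrency_level=5,  # run up to 5 examples concurrently
)
```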
@@ -1403,18 +1426,35 @@ def run_on_dataset(
     llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
         client, dataset_name, llm_or_chain_factory, project_name
     )
-    results = _run_on_examples(
-        client,
-        examples,
-        llm_or_chain_factory,
-        num_repetitions=num_repetitions,
-        project_name=project_name,
-        verbose=verbose,
-        tags=tags,
-        evaluation=evaluation,
-        input_mapper=input_mapper,
-        data_type=dataset.data_type,
-    )
+    if concurrency_level in (0, 1):
+        results = _run_on_examples(
+            client,
+            examples,
+            llm_or_chain_factory,
+            num_repetitions=num_repetitions,
+            project_name=project_name,
+            verbose=verbose,
+            tags=tags,
+            evaluation=evaluation,
+            input_mapper=input_mapper,
+            data_type=dataset.data_type,
+        )
+    else:
+        # TODO: Use runnables and the batch method
+        coro = _arun_on_examples(
+            client,
+            examples,
+            llm_or_chain_factory,
+            concurrency_level=concurrency_level,
+            num_repetitions=num_repetitions,
+            project_name=project_name,
+            verbose=verbose,
+            tags=tags,
+            evaluation=evaluation,
+            input_mapper=input_mapper,
+            data_type=dataset.data_type,
+        )
+        results = _handle_coroutine(coro)
     return {
         "project_name": project_name,
         "results": results,