Wfh/allow nonparallel (#10914)

This commit is contained in:
William FH 2023-09-21 20:21:01 -07:00 committed by GitHub
parent bb3e6cb427
commit ee8653f62c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -137,19 +137,6 @@ def _wrap_in_chain_factory(
"(memory=new_memory, ...)\n\n" "(memory=new_memory, ...)\n\n"
f'run_on_dataset("{dataset_name}", chain_constructor, ...)' f'run_on_dataset("{dataset_name}", chain_constructor, ...)'
) )
logger.warning(
"Directly passing in a chain is not recommended as chains may have state."
" This can lead to unexpected behavior as the "
"same chain instance could be used across multiple datasets. Instead,"
" please pass a chain constructor that creates a new "
"chain with fresh memory each time it is called. This will safeguard"
" against information leakage between dataset examples. "
"\nFor example:\n\n"
"def chain_constructor():\n"
f" return {chain_class}(memory=new_memory, ...)\n\n"
f'run_on_dataset("{dataset_name}", chain_constructor, ...)'
)
return lambda: chain return lambda: chain
elif isinstance(llm_or_chain_factory, BaseLanguageModel): elif isinstance(llm_or_chain_factory, BaseLanguageModel):
return llm_or_chain_factory return llm_or_chain_factory
@ -653,12 +640,9 @@ async def _arun_chain(
) -> Union[dict, str]: ) -> Union[dict, str]:
"""Run a chain asynchronously on inputs.""" """Run a chain asynchronously on inputs."""
inputs_ = inputs if input_mapper is None else input_mapper(inputs) inputs_ = inputs if input_mapper is None else input_mapper(inputs)
if isinstance(chain, Chain): if isinstance(chain, Chain) and isinstance(inputs_, dict) and len(inputs_) == 1:
if isinstance(inputs_, dict) and len(inputs_) == 1: val = next(iter(inputs_.values()))
val = next(iter(inputs_.values())) output = await chain.acall(val, callbacks=callbacks, tags=tags)
output = await chain.acall(val, callbacks=callbacks, tags=tags)
else:
output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
else: else:
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks) runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
output = await chain.ainvoke(inputs_, config=runnable_config) output = await chain.ainvoke(inputs_, config=runnable_config)
@ -781,12 +765,9 @@ def _run_chain(
) -> Union[Dict, str]: ) -> Union[Dict, str]:
"""Run a chain on inputs.""" """Run a chain on inputs."""
inputs_ = inputs if input_mapper is None else input_mapper(inputs) inputs_ = inputs if input_mapper is None else input_mapper(inputs)
if isinstance(chain, Chain): if isinstance(chain, Chain) and isinstance(inputs_, dict) and len(inputs_) == 1:
if isinstance(inputs_, dict) and len(inputs_) == 1: val = next(iter(inputs_.values()))
val = next(iter(inputs_.values())) output = chain(val, callbacks=callbacks, tags=tags)
output = chain(val, callbacks=callbacks, tags=tags)
else:
output = chain(inputs_, callbacks=callbacks, tags=tags)
else: else:
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks) runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
output = chain.invoke(inputs_, config=runnable_config) output = chain.invoke(inputs_, config=runnable_config)
@ -1251,18 +1232,29 @@ def run_on_dataset(
input_mapper, input_mapper,
concurrency_level, concurrency_level,
) )
with runnable_config.get_executor_for_config(configs[0]) as executor: if concurrency_level == 0:
batch_results = list( batch_results = [
executor.map( _run_llm_or_chain(
functools.partial( example,
_run_llm_or_chain, config,
llm_or_chain_factory=wrapped_model, llm_or_chain_factory=wrapped_model,
input_mapper=input_mapper, input_mapper=input_mapper,
), )
examples, for example, config in zip(examples, configs)
configs, ]
else:
with runnable_config.get_executor_for_config(configs[0]) as executor:
batch_results = list(
executor.map(
functools.partial(
_run_llm_or_chain,
llm_or_chain_factory=wrapped_model,
input_mapper=input_mapper,
),
examples,
configs,
)
) )
)
results = _collect_test_results(examples, batch_results, configs, project_name) results = _collect_test_results(examples, batch_results, configs, project_name)
if verbose: if verbose: