Wfh/allow nonparallel (#10914)

pull/10920/head
William FH 11 months ago committed by GitHub
parent bb3e6cb427
commit ee8653f62c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -137,19 +137,6 @@ def _wrap_in_chain_factory(
"(memory=new_memory, ...)\n\n"
f'run_on_dataset("{dataset_name}", chain_constructor, ...)'
)
logger.warning(
"Directly passing in a chain is not recommended as chains may have state."
" This can lead to unexpected behavior as the "
"same chain instance could be used across multiple datasets. Instead,"
" please pass a chain constructor that creates a new "
"chain with fresh memory each time it is called. This will safeguard"
" against information leakage between dataset examples. "
"\nFor example:\n\n"
"def chain_constructor():\n"
f" return {chain_class}(memory=new_memory, ...)\n\n"
f'run_on_dataset("{dataset_name}", chain_constructor, ...)'
)
return lambda: chain
elif isinstance(llm_or_chain_factory, BaseLanguageModel):
return llm_or_chain_factory
@ -653,12 +640,9 @@ async def _arun_chain(
) -> Union[dict, str]:
"""Run a chain asynchronously on inputs."""
inputs_ = inputs if input_mapper is None else input_mapper(inputs)
if isinstance(chain, Chain):
if isinstance(inputs_, dict) and len(inputs_) == 1:
val = next(iter(inputs_.values()))
output = await chain.acall(val, callbacks=callbacks, tags=tags)
else:
output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
if isinstance(chain, Chain) and isinstance(inputs_, dict) and len(inputs_) == 1:
val = next(iter(inputs_.values()))
output = await chain.acall(val, callbacks=callbacks, tags=tags)
else:
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
output = await chain.ainvoke(inputs_, config=runnable_config)
@ -781,12 +765,9 @@ def _run_chain(
) -> Union[Dict, str]:
"""Run a chain on inputs."""
inputs_ = inputs if input_mapper is None else input_mapper(inputs)
if isinstance(chain, Chain):
if isinstance(inputs_, dict) and len(inputs_) == 1:
val = next(iter(inputs_.values()))
output = chain(val, callbacks=callbacks, tags=tags)
else:
output = chain(inputs_, callbacks=callbacks, tags=tags)
if isinstance(chain, Chain) and isinstance(inputs_, dict) and len(inputs_) == 1:
val = next(iter(inputs_.values()))
output = chain(val, callbacks=callbacks, tags=tags)
else:
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
output = chain.invoke(inputs_, config=runnable_config)
@ -1251,18 +1232,29 @@ def run_on_dataset(
input_mapper,
concurrency_level,
)
with runnable_config.get_executor_for_config(configs[0]) as executor:
batch_results = list(
executor.map(
functools.partial(
_run_llm_or_chain,
llm_or_chain_factory=wrapped_model,
input_mapper=input_mapper,
),
examples,
configs,
if concurrency_level == 0:
batch_results = [
_run_llm_or_chain(
example,
config,
llm_or_chain_factory=wrapped_model,
input_mapper=input_mapper,
)
for example, config in zip(examples, configs)
]
else:
with runnable_config.get_executor_for_config(configs[0]) as executor:
batch_results = list(
executor.map(
functools.partial(
_run_llm_or_chain,
llm_or_chain_factory=wrapped_model,
input_mapper=input_mapper,
),
examples,
configs,
)
)
)
results = _collect_test_results(examples, batch_results, configs, project_name)
if verbose:

Loading…
Cancel
Save