Wfh/rm num repetitions (#9425)

The `num_repetitions` option makes it hard to build test run comparison views; for now the better approach is just to launch multiple test runs.
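For callers that passed the removed flag, the equivalent is now one test run per repetition. A minimal migration sketch, not part of this commit, assuming a configured LangSmith `Client` and an existing dataset named `my-dataset` with a `question` input key (all hypothetical):

```python
from langchain.chains import LLMChain
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate
from langchain.smith import run_on_dataset
from langsmith import Client

client = Client()  # assumes LANGCHAIN_API_KEY etc. are configured


def make_chain() -> LLMChain:
    """Stand-in factory: returns a fresh chain per call, no external API needed."""
    return LLMChain(
        llm=FakeListLLM(responses=["ok"]),
        prompt=PromptTemplate.from_template("{question}"),
    )


# Where num_repetitions=3 was used before, launch three separate test
# runs; each gets its own project, so they appear as separate columns
# in run-comparison views.
for i in range(3):
    run_on_dataset(
        client=client,
        dataset_name="my-dataset",  # hypothetical dataset
        llm_or_chain_factory=make_chain,
        project_name=f"my-dataset-repetition-{i}",
    )
```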
Author: William FH, 2023-08-18 10:08:39 -07:00 (committed by GitHub)
Parent: eee0d1d0dd
Commit: c29fbede59
2 changed files with 66 additions and 89 deletions

Changed file 1 of 2

@@ -8,6 +8,7 @@ import inspect
 import itertools
 import logging
 import uuid
+import warnings
 from enum import Enum
 from typing import (
     Any,
@@ -662,7 +663,6 @@ async def _arun_chain(
 async def _arun_llm_or_chain(
     example: Example,
     llm_or_chain_factory: MCF,
-    n_repetitions: int,
     *,
     tags: Optional[List[str]] = None,
     callbacks: Optional[List[BaseCallbackHandler]] = None,
@@ -673,7 +673,6 @@ async def _arun_llm_or_chain(
     Args:
         example: The example to run.
         llm_or_chain_factory: The Chain or language model constructor to run.
-        n_repetitions: The number of times to run the model on each example.
         tags: Optional tags to add to the run.
         callbacks: Optional callbacks to use during the run.
         input_mapper: Optional function to map the input to the expected format.
@@ -694,31 +693,28 @@ async def _arun_llm_or_chain(
     chain_or_llm = (
         "LLM" if isinstance(llm_or_chain_factory, BaseLanguageModel) else "Chain"
     )
-    for _ in range(n_repetitions):
-        try:
-            if isinstance(llm_or_chain_factory, BaseLanguageModel):
-                output: Any = await _arun_llm(
-                    llm_or_chain_factory,
-                    example.inputs,
-                    tags=tags,
-                    callbacks=callbacks,
-                    input_mapper=input_mapper,
-                )
-            else:
-                chain = llm_or_chain_factory()
-                output = await _arun_chain(
-                    chain,
-                    example.inputs,
-                    tags=tags,
-                    callbacks=callbacks,
-                    input_mapper=input_mapper,
-                )
-            outputs.append(output)
-        except Exception as e:
-            logger.warning(
-                f"{chain_or_llm} failed for example {example.id}. Error: {e}"
-            )
-            outputs.append({"Error": str(e)})
+    try:
+        if isinstance(llm_or_chain_factory, BaseLanguageModel):
+            output: Any = await _arun_llm(
+                llm_or_chain_factory,
+                example.inputs,
+                tags=tags,
+                callbacks=callbacks,
+                input_mapper=input_mapper,
+            )
+        else:
+            chain = llm_or_chain_factory()
+            output = await _arun_chain(
+                chain,
+                example.inputs,
+                tags=tags,
+                callbacks=callbacks,
+                input_mapper=input_mapper,
+            )
+        outputs.append(output)
+    except Exception as e:
+        logger.warning(f"{chain_or_llm} failed for example {example.id}. Error: {e}")
+        outputs.append({"Error": str(e)})
     if callbacks and previous_example_ids:
         for example_id, tracer in zip(previous_example_ids, callbacks):
             if hasattr(tracer, "example_id"):
@@ -822,7 +818,6 @@ async def _arun_on_examples(
     *,
     evaluation: Optional[RunEvalConfig] = None,
     concurrency_level: int = 5,
-    num_repetitions: int = 1,
     project_name: Optional[str] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
@@ -841,9 +836,6 @@ async def _arun_on_examples(
             independent calls on each example without carrying over state.
         evaluation: Optional evaluation configuration to use when evaluating
         concurrency_level: The number of async tasks to run concurrently.
-        num_repetitions: Number of times to run the model on each example.
-            This is useful when testing success rates or generating confidence
-            intervals.
         project_name: Project name to use when tracing runs.
             Defaults to {dataset_name}-{chain class name}-{datetime}.
         verbose: Whether to print progress.
@@ -873,7 +865,6 @@ async def _arun_on_examples(
         result = await _arun_llm_or_chain(
             example,
             wrapped_model,
-            num_repetitions,
             tags=tags,
             callbacks=callbacks,
             input_mapper=input_mapper,
@@ -983,7 +974,6 @@ def _run_chain(
 def _run_llm_or_chain(
     example: Example,
     llm_or_chain_factory: MCF,
-    n_repetitions: int,
     *,
     tags: Optional[List[str]] = None,
     callbacks: Optional[List[BaseCallbackHandler]] = None,
@@ -995,7 +985,6 @@ def _run_llm_or_chain(
     Args:
         example: The example to run.
         llm_or_chain_factory: The Chain or language model constructor to run.
-        n_repetitions: The number of times to run the model on each example.
         tags: Optional tags to add to the run.
         callbacks: Optional callbacks to use during the run.
         input_mapper: Optional function to map the input to the expected format.
@@ -1016,32 +1005,31 @@ def _run_llm_or_chain(
     chain_or_llm = (
         "LLM" if isinstance(llm_or_chain_factory, BaseLanguageModel) else "Chain"
     )
-    for _ in range(n_repetitions):
-        try:
-            if isinstance(llm_or_chain_factory, BaseLanguageModel):
-                output: Any = _run_llm(
-                    llm_or_chain_factory,
-                    example.inputs,
-                    callbacks,
-                    tags=tags,
-                    input_mapper=input_mapper,
-                )
-            else:
-                chain = llm_or_chain_factory()
-                output = _run_chain(
-                    chain,
-                    example.inputs,
-                    callbacks,
-                    tags=tags,
-                    input_mapper=input_mapper,
-                )
-            outputs.append(output)
-        except Exception as e:
-            logger.warning(
-                f"{chain_or_llm} failed for example {example.id} with inputs:"
-                f" {example.inputs}.\nError: {e}",
-            )
-            outputs.append({"Error": str(e)})
+    try:
+        if isinstance(llm_or_chain_factory, BaseLanguageModel):
+            output: Any = _run_llm(
+                llm_or_chain_factory,
+                example.inputs,
+                callbacks,
+                tags=tags,
+                input_mapper=input_mapper,
+            )
+        else:
+            chain = llm_or_chain_factory()
+            output = _run_chain(
+                chain,
+                example.inputs,
+                callbacks,
+                tags=tags,
+                input_mapper=input_mapper,
+            )
+        outputs.append(output)
+    except Exception as e:
+        logger.warning(
+            f"{chain_or_llm} failed for example {example.id} with inputs:"
+            f" {example.inputs}.\nError: {e}",
+        )
+        outputs.append({"Error": str(e)})
     if callbacks and previous_example_ids:
         for example_id, tracer in zip(previous_example_ids, callbacks):
             if hasattr(tracer, "example_id"):
@@ -1055,7 +1043,6 @@ def _run_on_examples(
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     *,
     evaluation: Optional[RunEvalConfig] = None,
-    num_repetitions: int = 1,
     project_name: Optional[str] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
@@ -1073,9 +1060,6 @@ def _run_on_examples(
             over the dataset. The Chain constructor is used to permit
             independent calls on each example without carrying over state.
         evaluation: Optional evaluation configuration to use when evaluating
-        num_repetitions: Number of times to run the model on each example.
-            This is useful when testing success rates or generating confidence
-            intervals.
         project_name: Name of the project to store the traces in.
             Defaults to {dataset_name}-{chain class name}-{datetime}.
         verbose: Whether to print progress.
@@ -1110,7 +1094,6 @@ def _run_on_examples(
         result = _run_llm_or_chain(
             example,
             wrapped_model,
-            num_repetitions,
             tags=tags,
             callbacks=callbacks,
             input_mapper=input_mapper,
@@ -1158,11 +1141,11 @@ async def arun_on_dataset(
     *,
     evaluation: Optional[RunEvalConfig] = None,
     concurrency_level: int = 5,
-    num_repetitions: int = 1,
     project_name: Optional[str] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    **kwargs: Any,
 ) -> Dict[str, Any]:
     """
     Asynchronously run the Chain or language model on a dataset
@@ -1177,9 +1160,6 @@ async def arun_on_dataset(
             independent calls on each example without carrying over state.
         evaluation: Optional evaluation configuration to use when evaluating
         concurrency_level: The number of async tasks to run concurrently.
-        num_repetitions: Number of times to run the model on each example.
-            This is useful when testing success rates or generating confidence
-            intervals.
         project_name: Name of the project to store the traces in.
             Defaults to {dataset_name}-{chain class name}-{datetime}.
         verbose: Whether to print progress.
@@ -1274,6 +1254,13 @@ async def arun_on_dataset(
             evaluation=evaluation_config,
         )
     """  # noqa: E501
+    if kwargs:
+        warnings.warn(
+            "The following arguments are deprecated and will "
+            "be removed in a future release: "
+            f"{kwargs.keys()}.",
+            DeprecationWarning,
+        )
     wrapped_model, project_name, dataset, examples = _prepare_eval_run(
         client, dataset_name, llm_or_chain_factory, project_name
     )
@@ -1282,7 +1269,6 @@ async def arun_on_dataset(
         examples,
         wrapped_model,
         concurrency_level=concurrency_level,
-        num_repetitions=num_repetitions,
         project_name=project_name,
         verbose=verbose,
         tags=tags,
@@ -1323,12 +1309,12 @@ def run_on_dataset(
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     *,
     evaluation: Optional[RunEvalConfig] = None,
-    num_repetitions: int = 1,
     concurrency_level: int = 5,
     project_name: Optional[str] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    **kwargs: Any,
 ) -> Dict[str, Any]:
     """
     Run the Chain or language model on a dataset and store traces
@@ -1344,9 +1330,6 @@ def run_on_dataset(
         evaluation: Configuration for evaluators to run on the
             results of the chain
         concurrency_level: The number of async tasks to run concurrently.
-        num_repetitions: Number of times to run the model on each example.
-            This is useful when testing success rates or generating confidence
-            intervals.
         project_name: Name of the project to store the traces in.
             Defaults to {dataset_name}-{chain class name}-{datetime}.
         verbose: Whether to print progress.
@@ -1441,6 +1424,13 @@ def run_on_dataset(
             evaluation=evaluation_config,
         )
     """  # noqa: E501
+    if kwargs:
+        warnings.warn(
+            "The following arguments are deprecated and "
+            "will be removed in a future release: "
+            f"{kwargs.keys()}.",
+            DeprecationWarning,
+        )
     wrapped_model, project_name, dataset, examples = _prepare_eval_run(
         client, dataset_name, llm_or_chain_factory, project_name
     )
@@ -1449,7 +1439,6 @@ def run_on_dataset(
             client,
             examples,
             wrapped_model,
-            num_repetitions=num_repetitions,
             project_name=project_name,
             verbose=verbose,
             tags=tags,
@@ -1464,7 +1453,6 @@ def run_on_dataset(
                 examples,
                 wrapped_model,
                 concurrency_level=concurrency_level,
-                num_repetitions=num_repetitions,
                 project_name=project_name,
                 verbose=verbose,
                 tags=tags,
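Note the new `**kwargs: Any` on both `run_on_dataset` and `arun_on_dataset`: passing the retired argument is no longer a `TypeError`; it is absorbed and reported. A rough sketch of the resulting behavior, reusing the hypothetical `client` and `make_chain` from the migration sketch near the top:

```python
import warnings

from langchain.smith import run_on_dataset

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    run_on_dataset(
        client=client,  # hypothetical client from the earlier sketch
        dataset_name="my-dataset",
        llm_or_chain_factory=make_chain,
        num_repetitions=3,  # retired flag: swallowed by **kwargs, run executes once
    )

# Warning text per the diff above, e.g.:
#   DeprecationWarning: The following arguments are deprecated and will
#   be removed in a future release: dict_keys(['num_repetitions']).
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```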

Changed file 2 of 2

@@ -181,15 +181,12 @@ def test_run_llm_or_chain_with_input_mapper() -> None:
         assert "the wrong input" in inputs
         return {"the right input": inputs["the wrong input"]}

-    result = _run_llm_or_chain(
-        example, lambda: mock_chain, n_repetitions=1, input_mapper=input_mapper
-    )
+    result = _run_llm_or_chain(example, lambda: mock_chain, input_mapper=input_mapper)
     assert len(result) == 1
     assert result[0] == {"output": "2", "the right input": "1"}
     bad_result = _run_llm_or_chain(
         example,
         lambda: mock_chain,
-        n_repetitions=1,
     )
     assert len(bad_result) == 1
     assert "Error" in bad_result[0]
@@ -200,9 +197,7 @@ def test_run_llm_or_chain_with_input_mapper() -> None:
         return "the right input"

     mock_llm = FakeLLM(queries={"the right input": "somenumber"})
-    result = _run_llm_or_chain(
-        example, mock_llm, n_repetitions=1, input_mapper=llm_input_mapper
-    )
+    result = _run_llm_or_chain(example, mock_llm, input_mapper=llm_input_mapper)
     assert len(result) == 1
     llm_result = result[0]
     assert isinstance(llm_result, str)
@@ -302,14 +297,11 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
     async def mock_arun_chain(
         example: Example,
         llm_or_chain: Union[BaseLanguageModel, Chain],
-        n_repetitions: int,
         tags: Optional[List[str]] = None,
         callbacks: Optional[Any] = None,
         **kwargs: Any,
     ) -> List[Dict[str, Any]]:
-        return [
-            {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
-        ]
+        return [{"result": f"Result for example {example.id}"}]

     def mock_create_project(*args: Any, **kwargs: Any) -> Any:
         proj = mock.MagicMock()
@@ -327,20 +319,17 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
     client = Client(api_url="http://localhost:1984", api_key="123")
     chain = mock.MagicMock()
     chain.input_keys = ["foothing"]
-    num_repetitions = 3
     results = await arun_on_dataset(
         dataset_name="test",
         llm_or_chain_factory=lambda: chain,
         concurrency_level=2,
         project_name="test_project",
-        num_repetitions=num_repetitions,
         client=client,
     )

     expected = {
         uuid_: [
-            {"result": f"Result for example {uuid.UUID(uuid_)}"}
-            for _ in range(num_repetitions)
+            {"result": f"Result for example {uuid.UUID(uuid_)}"} for _ in range(1)
         ]
         for uuid_ in uuids
     }