mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
Wfh/rm num repetitions (#9425)
The `num_repetitions` option makes it hard to build test run comparison views, and for now we'd probably rather just run the evaluation multiple times as separate runs.
This commit is contained in:
parent
eee0d1d0dd
commit
c29fbede59
@ -8,6 +8,7 @@ import inspect
|
||||
import itertools
|
||||
import logging
|
||||
import uuid
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
@ -662,7 +663,6 @@ async def _arun_chain(
|
||||
async def _arun_llm_or_chain(
|
||||
example: Example,
|
||||
llm_or_chain_factory: MCF,
|
||||
n_repetitions: int,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
callbacks: Optional[List[BaseCallbackHandler]] = None,
|
||||
@ -673,7 +673,6 @@ async def _arun_llm_or_chain(
|
||||
Args:
|
||||
example: The example to run.
|
||||
llm_or_chain_factory: The Chain or language model constructor to run.
|
||||
n_repetitions: The number of times to run the model on each example.
|
||||
tags: Optional tags to add to the run.
|
||||
callbacks: Optional callbacks to use during the run.
|
||||
input_mapper: Optional function to map the input to the expected format.
|
||||
@ -694,31 +693,28 @@ async def _arun_llm_or_chain(
|
||||
chain_or_llm = (
|
||||
"LLM" if isinstance(llm_or_chain_factory, BaseLanguageModel) else "Chain"
|
||||
)
|
||||
for _ in range(n_repetitions):
|
||||
try:
|
||||
if isinstance(llm_or_chain_factory, BaseLanguageModel):
|
||||
output: Any = await _arun_llm(
|
||||
llm_or_chain_factory,
|
||||
example.inputs,
|
||||
tags=tags,
|
||||
callbacks=callbacks,
|
||||
input_mapper=input_mapper,
|
||||
)
|
||||
else:
|
||||
chain = llm_or_chain_factory()
|
||||
output = await _arun_chain(
|
||||
chain,
|
||||
example.inputs,
|
||||
tags=tags,
|
||||
callbacks=callbacks,
|
||||
input_mapper=input_mapper,
|
||||
)
|
||||
outputs.append(output)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"{chain_or_llm} failed for example {example.id}. Error: {e}"
|
||||
try:
|
||||
if isinstance(llm_or_chain_factory, BaseLanguageModel):
|
||||
output: Any = await _arun_llm(
|
||||
llm_or_chain_factory,
|
||||
example.inputs,
|
||||
tags=tags,
|
||||
callbacks=callbacks,
|
||||
input_mapper=input_mapper,
|
||||
)
|
||||
outputs.append({"Error": str(e)})
|
||||
else:
|
||||
chain = llm_or_chain_factory()
|
||||
output = await _arun_chain(
|
||||
chain,
|
||||
example.inputs,
|
||||
tags=tags,
|
||||
callbacks=callbacks,
|
||||
input_mapper=input_mapper,
|
||||
)
|
||||
outputs.append(output)
|
||||
except Exception as e:
|
||||
logger.warning(f"{chain_or_llm} failed for example {example.id}. Error: {e}")
|
||||
outputs.append({"Error": str(e)})
|
||||
if callbacks and previous_example_ids:
|
||||
for example_id, tracer in zip(previous_example_ids, callbacks):
|
||||
if hasattr(tracer, "example_id"):
|
||||
@ -822,7 +818,6 @@ async def _arun_on_examples(
|
||||
*,
|
||||
evaluation: Optional[RunEvalConfig] = None,
|
||||
concurrency_level: int = 5,
|
||||
num_repetitions: int = 1,
|
||||
project_name: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
tags: Optional[List[str]] = None,
|
||||
@ -841,9 +836,6 @@ async def _arun_on_examples(
|
||||
independent calls on each example without carrying over state.
|
||||
evaluation: Optional evaluation configuration to use when evaluating
|
||||
concurrency_level: The number of async tasks to run concurrently.
|
||||
num_repetitions: Number of times to run the model on each example.
|
||||
This is useful when testing success rates or generating confidence
|
||||
intervals.
|
||||
project_name: Project name to use when tracing runs.
|
||||
Defaults to {dataset_name}-{chain class name}-{datetime}.
|
||||
verbose: Whether to print progress.
|
||||
@ -873,7 +865,6 @@ async def _arun_on_examples(
|
||||
result = await _arun_llm_or_chain(
|
||||
example,
|
||||
wrapped_model,
|
||||
num_repetitions,
|
||||
tags=tags,
|
||||
callbacks=callbacks,
|
||||
input_mapper=input_mapper,
|
||||
@ -983,7 +974,6 @@ def _run_chain(
|
||||
def _run_llm_or_chain(
|
||||
example: Example,
|
||||
llm_or_chain_factory: MCF,
|
||||
n_repetitions: int,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
callbacks: Optional[List[BaseCallbackHandler]] = None,
|
||||
@ -995,7 +985,6 @@ def _run_llm_or_chain(
|
||||
Args:
|
||||
example: The example to run.
|
||||
llm_or_chain_factory: The Chain or language model constructor to run.
|
||||
n_repetitions: The number of times to run the model on each example.
|
||||
tags: Optional tags to add to the run.
|
||||
callbacks: Optional callbacks to use during the run.
|
||||
|
||||
@ -1016,32 +1005,31 @@ def _run_llm_or_chain(
|
||||
chain_or_llm = (
|
||||
"LLM" if isinstance(llm_or_chain_factory, BaseLanguageModel) else "Chain"
|
||||
)
|
||||
for _ in range(n_repetitions):
|
||||
try:
|
||||
if isinstance(llm_or_chain_factory, BaseLanguageModel):
|
||||
output: Any = _run_llm(
|
||||
llm_or_chain_factory,
|
||||
example.inputs,
|
||||
callbacks,
|
||||
tags=tags,
|
||||
input_mapper=input_mapper,
|
||||
)
|
||||
else:
|
||||
chain = llm_or_chain_factory()
|
||||
output = _run_chain(
|
||||
chain,
|
||||
example.inputs,
|
||||
callbacks,
|
||||
tags=tags,
|
||||
input_mapper=input_mapper,
|
||||
)
|
||||
outputs.append(output)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"{chain_or_llm} failed for example {example.id} with inputs:"
|
||||
f" {example.inputs}.\nError: {e}",
|
||||
try:
|
||||
if isinstance(llm_or_chain_factory, BaseLanguageModel):
|
||||
output: Any = _run_llm(
|
||||
llm_or_chain_factory,
|
||||
example.inputs,
|
||||
callbacks,
|
||||
tags=tags,
|
||||
input_mapper=input_mapper,
|
||||
)
|
||||
outputs.append({"Error": str(e)})
|
||||
else:
|
||||
chain = llm_or_chain_factory()
|
||||
output = _run_chain(
|
||||
chain,
|
||||
example.inputs,
|
||||
callbacks,
|
||||
tags=tags,
|
||||
input_mapper=input_mapper,
|
||||
)
|
||||
outputs.append(output)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"{chain_or_llm} failed for example {example.id} with inputs:"
|
||||
f" {example.inputs}.\nError: {e}",
|
||||
)
|
||||
outputs.append({"Error": str(e)})
|
||||
if callbacks and previous_example_ids:
|
||||
for example_id, tracer in zip(previous_example_ids, callbacks):
|
||||
if hasattr(tracer, "example_id"):
|
||||
@ -1055,7 +1043,6 @@ def _run_on_examples(
|
||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||
*,
|
||||
evaluation: Optional[RunEvalConfig] = None,
|
||||
num_repetitions: int = 1,
|
||||
project_name: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
tags: Optional[List[str]] = None,
|
||||
@ -1073,9 +1060,6 @@ def _run_on_examples(
|
||||
over the dataset. The Chain constructor is used to permit
|
||||
independent calls on each example without carrying over state.
|
||||
evaluation: Optional evaluation configuration to use when evaluating
|
||||
num_repetitions: Number of times to run the model on each example.
|
||||
This is useful when testing success rates or generating confidence
|
||||
intervals.
|
||||
project_name: Name of the project to store the traces in.
|
||||
Defaults to {dataset_name}-{chain class name}-{datetime}.
|
||||
verbose: Whether to print progress.
|
||||
@ -1110,7 +1094,6 @@ def _run_on_examples(
|
||||
result = _run_llm_or_chain(
|
||||
example,
|
||||
wrapped_model,
|
||||
num_repetitions,
|
||||
tags=tags,
|
||||
callbacks=callbacks,
|
||||
input_mapper=input_mapper,
|
||||
@ -1158,11 +1141,11 @@ async def arun_on_dataset(
|
||||
*,
|
||||
evaluation: Optional[RunEvalConfig] = None,
|
||||
concurrency_level: int = 5,
|
||||
num_repetitions: int = 1,
|
||||
project_name: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
tags: Optional[List[str]] = None,
|
||||
input_mapper: Optional[Callable[[Dict], Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Asynchronously run the Chain or language model on a dataset
|
||||
@ -1177,9 +1160,6 @@ async def arun_on_dataset(
|
||||
independent calls on each example without carrying over state.
|
||||
evaluation: Optional evaluation configuration to use when evaluating
|
||||
concurrency_level: The number of async tasks to run concurrently.
|
||||
num_repetitions: Number of times to run the model on each example.
|
||||
This is useful when testing success rates or generating confidence
|
||||
intervals.
|
||||
project_name: Name of the project to store the traces in.
|
||||
Defaults to {dataset_name}-{chain class name}-{datetime}.
|
||||
verbose: Whether to print progress.
|
||||
@ -1274,6 +1254,13 @@ async def arun_on_dataset(
|
||||
evaluation=evaluation_config,
|
||||
)
|
||||
""" # noqa: E501
|
||||
if kwargs:
|
||||
warnings.warn(
|
||||
"The following arguments are deprecated and will "
|
||||
"be removed in a future release: "
|
||||
f"{kwargs.keys()}.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
|
||||
client, dataset_name, llm_or_chain_factory, project_name
|
||||
)
|
||||
@ -1282,7 +1269,6 @@ async def arun_on_dataset(
|
||||
examples,
|
||||
wrapped_model,
|
||||
concurrency_level=concurrency_level,
|
||||
num_repetitions=num_repetitions,
|
||||
project_name=project_name,
|
||||
verbose=verbose,
|
||||
tags=tags,
|
||||
@ -1323,12 +1309,12 @@ def run_on_dataset(
|
||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||
*,
|
||||
evaluation: Optional[RunEvalConfig] = None,
|
||||
num_repetitions: int = 1,
|
||||
concurrency_level: int = 5,
|
||||
project_name: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
tags: Optional[List[str]] = None,
|
||||
input_mapper: Optional[Callable[[Dict], Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run the Chain or language model on a dataset and store traces
|
||||
@ -1344,9 +1330,6 @@ def run_on_dataset(
|
||||
evaluation: Configuration for evaluators to run on the
|
||||
results of the chain
|
||||
concurrency_level: The number of async tasks to run concurrently.
|
||||
num_repetitions: Number of times to run the model on each example.
|
||||
This is useful when testing success rates or generating confidence
|
||||
intervals.
|
||||
project_name: Name of the project to store the traces in.
|
||||
Defaults to {dataset_name}-{chain class name}-{datetime}.
|
||||
verbose: Whether to print progress.
|
||||
@ -1441,6 +1424,13 @@ def run_on_dataset(
|
||||
evaluation=evaluation_config,
|
||||
)
|
||||
""" # noqa: E501
|
||||
if kwargs:
|
||||
warnings.warn(
|
||||
"The following arguments are deprecated and "
|
||||
"will be removed in a future release: "
|
||||
f"{kwargs.keys()}.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
|
||||
client, dataset_name, llm_or_chain_factory, project_name
|
||||
)
|
||||
@ -1449,7 +1439,6 @@ def run_on_dataset(
|
||||
client,
|
||||
examples,
|
||||
wrapped_model,
|
||||
num_repetitions=num_repetitions,
|
||||
project_name=project_name,
|
||||
verbose=verbose,
|
||||
tags=tags,
|
||||
@ -1464,7 +1453,6 @@ def run_on_dataset(
|
||||
examples,
|
||||
wrapped_model,
|
||||
concurrency_level=concurrency_level,
|
||||
num_repetitions=num_repetitions,
|
||||
project_name=project_name,
|
||||
verbose=verbose,
|
||||
tags=tags,
|
||||
|
@ -181,15 +181,12 @@ def test_run_llm_or_chain_with_input_mapper() -> None:
|
||||
assert "the wrong input" in inputs
|
||||
return {"the right input": inputs["the wrong input"]}
|
||||
|
||||
result = _run_llm_or_chain(
|
||||
example, lambda: mock_chain, n_repetitions=1, input_mapper=input_mapper
|
||||
)
|
||||
result = _run_llm_or_chain(example, lambda: mock_chain, input_mapper=input_mapper)
|
||||
assert len(result) == 1
|
||||
assert result[0] == {"output": "2", "the right input": "1"}
|
||||
bad_result = _run_llm_or_chain(
|
||||
example,
|
||||
lambda: mock_chain,
|
||||
n_repetitions=1,
|
||||
)
|
||||
assert len(bad_result) == 1
|
||||
assert "Error" in bad_result[0]
|
||||
@ -200,9 +197,7 @@ def test_run_llm_or_chain_with_input_mapper() -> None:
|
||||
return "the right input"
|
||||
|
||||
mock_llm = FakeLLM(queries={"the right input": "somenumber"})
|
||||
result = _run_llm_or_chain(
|
||||
example, mock_llm, n_repetitions=1, input_mapper=llm_input_mapper
|
||||
)
|
||||
result = _run_llm_or_chain(example, mock_llm, input_mapper=llm_input_mapper)
|
||||
assert len(result) == 1
|
||||
llm_result = result[0]
|
||||
assert isinstance(llm_result, str)
|
||||
@ -302,14 +297,11 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
async def mock_arun_chain(
|
||||
example: Example,
|
||||
llm_or_chain: Union[BaseLanguageModel, Chain],
|
||||
n_repetitions: int,
|
||||
tags: Optional[List[str]] = None,
|
||||
callbacks: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Dict[str, Any]]:
|
||||
return [
|
||||
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
|
||||
]
|
||||
return [{"result": f"Result for example {example.id}"}]
|
||||
|
||||
def mock_create_project(*args: Any, **kwargs: Any) -> Any:
|
||||
proj = mock.MagicMock()
|
||||
@ -327,20 +319,17 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
client = Client(api_url="http://localhost:1984", api_key="123")
|
||||
chain = mock.MagicMock()
|
||||
chain.input_keys = ["foothing"]
|
||||
num_repetitions = 3
|
||||
results = await arun_on_dataset(
|
||||
dataset_name="test",
|
||||
llm_or_chain_factory=lambda: chain,
|
||||
concurrency_level=2,
|
||||
project_name="test_project",
|
||||
num_repetitions=num_repetitions,
|
||||
client=client,
|
||||
)
|
||||
|
||||
expected = {
|
||||
uuid_: [
|
||||
{"result": f"Result for example {uuid.UUID(uuid_)}"}
|
||||
for _ in range(num_repetitions)
|
||||
{"result": f"Result for example {uuid.UUID(uuid_)}"} for _ in range(1)
|
||||
]
|
||||
for uuid_ in uuids
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user