From 00f276d23faf6ebe50965f1ef5c6d1c7e5f3b79b Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Mon, 19 Jun 2023 18:31:38 -0700 Subject: [PATCH] Run eval in eval mode (#6447) For the `run_on_dataset` sessions --- langchain/client/runner_utils.py | 4 +++- poetry.lock | 8 ++++---- pyproject.toml | 2 +- tests/unit_tests/client/test_runner_utils.py | 5 +++++ 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/langchain/client/runner_utils.py b/langchain/client/runner_utils.py index bb295f39..99066ea6 100644 --- a/langchain/client/runner_utils.py +++ b/langchain/client/runner_utils.py @@ -278,7 +278,7 @@ async def arun_on_examples( results: Dict[str, List[Any]] = {} async def process_example( - example: Example, tracer: LangChainTracer, job_state: dict + example: Example, tracer: Optional[LangChainTracer], job_state: dict ) -> None: """Process a single example.""" result = await _arun_llm_or_chain( @@ -466,6 +466,7 @@ async def arun_on_dataset( """ client_ = client or LangChainPlusClient() session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name) + client_.create_session(session_name, mode="eval") dataset = client_.read_dataset(dataset_name=dataset_name) examples = client_.list_examples(dataset_id=str(dataset.id)) @@ -517,6 +518,7 @@ def run_on_dataset( """ client_ = client or LangChainPlusClient() session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name) + client_.create_session(session_name, mode="eval") dataset = client_.read_dataset(dataset_name=dataset_name) examples = client_.list_examples(dataset_id=str(dataset.id)) results = run_on_examples( diff --git a/poetry.lock b/poetry.lock index 826cb88d..729f4001 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3845,13 +3845,13 @@ tests = ["doctest", "pytest", "pytest-mock"] [[package]] name = "langchainplus-sdk" -version = "0.0.10" +version = "0.0.15" description = "Client library to connect to the LangChainPlus LLM Tracing and Evaluation Platform." optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "langchainplus_sdk-0.0.10-py3-none-any.whl", hash = "sha256:6ea4013a92a4c33a61d22deb49620577c592a79ee44038b2c751032a71cbc7b6"}, - {file = "langchainplus_sdk-0.0.10.tar.gz", hash = "sha256:4f810b38df74a99d01e5723e653da02f05df3ee922971cccabc365d00c33dbf6"}, + {file = "langchainplus_sdk-0.0.15-py3-none-any.whl", hash = "sha256:e69bdbc8af6007ef2f774248d2483bbaf2d75712b1acc9ea50eda3b9f6dc567d"}, + {file = "langchainplus_sdk-0.0.15.tar.gz", hash = "sha256:ce40e9e3b6d42741f0a2aa89f83a12f2648f38690a9dd57e5fe3a56f2f232908"}, ] [package.dependencies] @@ -10864,4 +10864,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "abfd5265cf134d614666453b6f4ec958bcf8de6447b4bdad091c333528162d04" +content-hash = "1009d76e766a610a009cf900800f854b3a7901d680226fabf8c4e82e98a83c44" diff --git a/pyproject.toml b/pyproject.toml index 79c7a56d..21f1b96d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,7 +105,7 @@ singlestoredb = {version = "^0.6.1", optional = true} pyspark = {version = "^3.4.0", optional = true} tigrisdb = {version = "^1.0.0b6", optional = true} nebula3-python = {version = "^3.4.0", optional = true} -langchainplus-sdk = ">=0.0.9" +langchainplus-sdk = ">=0.0.13" awadb = {version = "^0.3.3", optional = true} azure-search-documents = {version = "11.4.0a20230509004", source = "azure-sdk-dev", optional = true} # now streamlit requires Python >=3.7, !=3.9.7 So, it is commented out. diff --git a/tests/unit_tests/client/test_runner_utils.py b/tests/unit_tests/client/test_runner_utils.py index 4487657e..2c0ccd44 100644 --- a/tests/unit_tests/client/test_runner_utils.py +++ b/tests/unit_tests/client/test_runner_utils.py @@ -176,12 +176,17 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None: {"result": f"Result for example {example.id}"} for _ in range(n_repetitions) ] + def mock_create_session(*args: Any, **kwargs: Any) -> None: + pass + with mock.patch.object( LangChainPlusClient, "read_dataset", new=mock_read_dataset ), mock.patch.object( LangChainPlusClient, "list_examples", new=mock_list_examples ), mock.patch( "langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain + ), mock.patch.object( + LangChainPlusClient, "create_session", new=mock_create_session ): client = LangChainPlusClient(api_url="http://localhost:1984", api_key="123") chain = mock.MagicMock()