Run eval in eval mode (#6447)

Sessions created by `run_on_dataset` and `arun_on_dataset` are now created with `mode="eval"`.
master
Zander Chase 11 months ago committed by GitHub
parent 1300a4bc8c
commit 00f276d23f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -278,7 +278,7 @@ async def arun_on_examples(
results: Dict[str, List[Any]] = {}
async def process_example(
- example: Example, tracer: LangChainTracer, job_state: dict
+ example: Example, tracer: Optional[LangChainTracer], job_state: dict
) -> None:
"""Process a single example."""
result = await _arun_llm_or_chain(
@@ -466,6 +466,7 @@ async def arun_on_dataset(
"""
client_ = client or LangChainPlusClient()
session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name)
+ client_.create_session(session_name, mode="eval")
dataset = client_.read_dataset(dataset_name=dataset_name)
examples = client_.list_examples(dataset_id=str(dataset.id))
@@ -517,6 +518,7 @@ def run_on_dataset(
"""
client_ = client or LangChainPlusClient()
session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name)
+ client_.create_session(session_name, mode="eval")
dataset = client_.read_dataset(dataset_name=dataset_name)
examples = client_.list_examples(dataset_id=str(dataset.id))
results = run_on_examples(

8
poetry.lock generated

@@ -3845,13 +3845,13 @@ tests = ["doctest", "pytest", "pytest-mock"]
[[package]]
name = "langchainplus-sdk"
- version = "0.0.10"
+ version = "0.0.15"
description = "Client library to connect to the LangChainPlus LLM Tracing and Evaluation Platform."
optional = false
python-versions = ">=3.8.1,<4.0"
files = [
- {file = "langchainplus_sdk-0.0.10-py3-none-any.whl", hash = "sha256:6ea4013a92a4c33a61d22deb49620577c592a79ee44038b2c751032a71cbc7b6"},
- {file = "langchainplus_sdk-0.0.10.tar.gz", hash = "sha256:4f810b38df74a99d01e5723e653da02f05df3ee922971cccabc365d00c33dbf6"},
+ {file = "langchainplus_sdk-0.0.15-py3-none-any.whl", hash = "sha256:e69bdbc8af6007ef2f774248d2483bbaf2d75712b1acc9ea50eda3b9f6dc567d"},
+ {file = "langchainplus_sdk-0.0.15.tar.gz", hash = "sha256:ce40e9e3b6d42741f0a2aa89f83a12f2648f38690a9dd57e5fe3a56f2f232908"},
]
[package.dependencies]
@@ -10864,4 +10864,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
- content-hash = "abfd5265cf134d614666453b6f4ec958bcf8de6447b4bdad091c333528162d04"
+ content-hash = "1009d76e766a610a009cf900800f854b3a7901d680226fabf8c4e82e98a83c44"

@@ -105,7 +105,7 @@ singlestoredb = {version = "^0.6.1", optional = true}
pyspark = {version = "^3.4.0", optional = true}
tigrisdb = {version = "^1.0.0b6", optional = true}
nebula3-python = {version = "^3.4.0", optional = true}
- langchainplus-sdk = ">=0.0.9"
+ langchainplus-sdk = ">=0.0.13"
awadb = {version = "^0.3.3", optional = true}
azure-search-documents = {version = "11.4.0a20230509004", source = "azure-sdk-dev", optional = true}
# now streamlit requires Python >=3.7, !=3.9.7 So, it is commented out.

@@ -176,12 +176,17 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
]
def mock_create_session(*args: Any, **kwargs: Any) -> None:
pass
with mock.patch.object(
LangChainPlusClient, "read_dataset", new=mock_read_dataset
), mock.patch.object(
LangChainPlusClient, "list_examples", new=mock_list_examples
), mock.patch(
"langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain
), mock.patch.object(
LangChainPlusClient, "create_session", new=mock_create_session
):
client = LangChainPlusClient(api_url="http://localhost:1984", api_key="123")
chain = mock.MagicMock()

Loading…
Cancel
Save