From e76e68b211839743154c7a2b161bb5087e88967c Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Wed, 24 May 2023 14:06:03 -0700 Subject: [PATCH] Add Delete Session Method (#5193) --- langchain/client/langchain.py | 20 ++++++- tests/integration_tests/client/__init__.py | 0 tests/integration_tests/client/test_client.py | 52 +++++++++++++++++++ 3 files changed, 70 insertions(+), 2 deletions(-) create mode 100644 tests/integration_tests/client/__init__.py create mode 100644 tests/integration_tests/client/test_client.py diff --git a/langchain/client/langchain.py b/langchain/client/langchain.py index 4f0191ed..bd51608e 100644 --- a/langchain/client/langchain.py +++ b/langchain/client/langchain.py @@ -200,7 +200,7 @@ class LangChainPlusClient(BaseSettings): return Dataset(**result) @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5)) - def read_run(self, run_id: str) -> Run: + def read_run(self, run_id: Union[str, UUID]) -> Run: """Read a run from the LangChain+ API.""" response = self._get(f"/runs/{run_id}") raise_for_status_with_text(response) @@ -268,6 +268,22 @@ class LangChainPlusClient(BaseSettings): raise_for_status_with_text(response) yield from [TracerSession(**session) for session in response.json()] + @xor_args(("session_name", "session_id")) + def delete_session( + self, *, session_name: Optional[str] = None, session_id: Optional[str] = None + ) -> None: + """Delete a session from the LangChain+ API.""" + if session_name is not None: + session_id = self.read_session(session_name=session_name).id + elif session_id is None: + raise ValueError("Must provide session_name or session_id") + response = requests.delete( + self.api_url + f"/sessions/{session_id}", + headers=self._headers, + ) + raise_for_status_with_text(response) + return None + def create_dataset(self, dataset_name: str, description: str) -> Dataset: """Create a dataset in the LangChain+ API.""" dataset = DatasetCreate( @@ -360,7 +376,7 @@ class LangChainPlusClient(BaseSettings): return Example(**result) @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5)) - def read_example(self, example_id: str) -> Example: + def read_example(self, example_id: Union[str, UUID]) -> Example: """Read an example from the LangChain+ API.""" response = self._get(f"/examples/{example_id}") raise_for_status_with_text(response) diff --git a/tests/integration_tests/client/__init__.py b/tests/integration_tests/client/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/client/test_client.py b/tests/integration_tests/client/test_client.py new file mode 100644 index 00000000..66ccb29e --- /dev/null +++ b/tests/integration_tests/client/test_client.py @@ -0,0 +1,52 @@ +"""LangChain+ langchain_client Integration Tests.""" +from uuid import uuid4 + +import pytest +from tenacity import RetryError + +from langchain.callbacks.manager import tracing_v2_enabled +from langchain.client import LangChainPlusClient +from langchain.tools.base import tool + + +@pytest.fixture +def langchain_client(monkeypatch: pytest.MonkeyPatch) -> LangChainPlusClient: + monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984") + return LangChainPlusClient() + + +def test_sessions( + langchain_client: LangChainPlusClient, monkeypatch: pytest.MonkeyPatch +) -> None: + """Test sessions.""" + session_names = set([session.name for session in langchain_client.list_sessions()]) + new_session = f"Session {uuid4()}" + assert new_session not in session_names + + @tool + def example_tool() -> str: + """Call me, maybe.""" + return "test_tool" + + monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984") + with tracing_v2_enabled(session_name=new_session): + example_tool({}) + session = langchain_client.read_session(session_name=new_session) + assert session.name == new_session + session_names = set([sess.name for sess in langchain_client.list_sessions()]) + assert new_session in session_names + runs = list(langchain_client.list_runs(session_name=new_session)) + session_id_runs = list(langchain_client.list_runs(session_id=session.id)) + assert len(runs) == len(session_id_runs) == 1 + assert runs[0].id == session_id_runs[0].id + langchain_client.delete_session(session_name=new_session) + + with pytest.raises(RetryError): + langchain_client.read_session(session_name=new_session) + assert new_session not in set( + [sess.name for sess in langchain_client.list_sessions()] + ) + with pytest.raises(RetryError): + langchain_client.delete_session(session_name=new_session) + with pytest.raises(RetryError): + langchain_client.read_run(run_id=str(runs[0].id))