diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index e85665e1..ebcbca8c 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -23,7 +23,6 @@ from langchain.callbacks.openai_info import OpenAICallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1 -from langchain.callbacks.tracers.schemas import TracerSession from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler from langchain.callbacks.tracers.wandb import WandbTracer from langchain.schema import ( @@ -99,26 +98,21 @@ def tracing_v2_enabled( session_name: Optional[str] = None, *, example_id: Optional[Union[str, UUID]] = None, - tenant_id: Optional[str] = None, - session_extra: Optional[Dict[str, Any]] = None, -) -> Generator[TracerSession, None, None]: +) -> Generator[None, None, None]: """Get the experimental tracer handler in a context manager.""" # Issue a warning that this is experimental warnings.warn( - "The experimental tracing v2 is in development. " + "The tracing v2 API is in development. " "This is not yet stable and may change in the future." ) if isinstance(example_id, str): example_id = UUID(example_id) cb = LangChainTracer( - tenant_id=tenant_id, - session_name=session_name, example_id=example_id, - session_extra=session_extra, + session_name=session_name, ) - session = cb.ensure_session() tracing_v2_callback_var.set(cb) - yield session + yield tracing_v2_callback_var.set(None) @@ -919,7 +913,6 @@ def _configure( else: try: handler = LangChainTracer(session_name=tracer_session) - handler.ensure_session() callback_manager.add_handler(handler, True) except Exception as e: logger.warning( diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index cda5e8d8..bbbccfe8 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -25,10 +25,8 @@ from langchain.callbacks.tracers.schemas import ( RunTypeEnum, RunUpdate, TracerSession, - TracerSessionCreate, ) from langchain.schema import BaseMessage, messages_to_dict -from langchain.utils import raise_for_status_with_text logger = logging.getLogger(__name__) @@ -65,49 +63,13 @@ retry_decorator = retry( ) -@retry_decorator -def _get_tenant_id( - tenant_id: Optional[str], endpoint: Optional[str], headers: Optional[dict] -) -> str: - """Get the tenant ID for the LangChain API.""" - tenant_id_: Optional[str] = tenant_id or os.getenv("LANGCHAIN_TENANT_ID") - if tenant_id_: - return tenant_id_ - endpoint_ = endpoint or get_endpoint() - headers_ = headers or get_headers() - response = None - try: - response = requests.get(endpoint_ + "/tenants", headers=headers_) - raise_for_status_with_text(response) - except HTTPError as e: - if response is not None and response.status_code == 500: - raise LangChainTracerAPIError( - f"Failed to get tenant ID from LangChain API. {e}" - ) - else: - raise LangChainTracerUserError( - f"Failed to get tenant ID from LangChain API. {e}" - ) - except Exception as e: - raise LangChainTracerError( - f"Failed to get tenant ID from LangChain API. {e}" - ) from e - - tenants: List[Dict[str, Any]] = response.json() - if not tenants: - raise ValueError(f"No tenants found for URL {endpoint_}") - return tenants[0]["id"] - - class LangChainTracer(BaseTracer): """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" def __init__( self, - tenant_id: Optional[str] = None, example_id: Optional[UUID] = None, session_name: Optional[str] = None, - session_extra: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: """Initialize the LangChain tracer.""" @@ -115,10 +77,8 @@ class LangChainTracer(BaseTracer): self.session: Optional[TracerSession] = None self._endpoint = get_endpoint() self._headers = get_headers() - self.tenant_id = tenant_id self.example_id = example_id self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default") - self.session_extra = session_extra # set max_workers to 1 to process tasks in order self.executor = ThreadPoolExecutor(max_workers=1) @@ -149,62 +109,20 @@ class LangChainTracer(BaseTracer): self._start_trace(chat_model_run) self._on_chat_model_start(chat_model_run) - def ensure_tenant_id(self) -> str: - """Load or use the tenant ID.""" - tenant_id = self.tenant_id or _get_tenant_id( - self.tenant_id, self._endpoint, self._headers - ) - self.tenant_id = tenant_id - return tenant_id - - @retry_decorator - def ensure_session(self) -> TracerSession: - """Upsert a session.""" - if self.session is not None: - return self.session - tenant_id = self.ensure_tenant_id() - url = f"{self._endpoint}/sessions?upsert=true" - session_create = TracerSessionCreate( - name=self.session_name, extra=self.session_extra, tenant_id=tenant_id - ) - response = None - try: - response = requests.post( - url, - data=session_create.json(), - headers=self._headers, - ) - response.raise_for_status() - except HTTPError as e: - if response is not None and response.status_code == 500: - raise LangChainTracerAPIError( - f"Failed to upsert session to LangChain API. {e}" - ) - else: - raise LangChainTracerUserError( - f"Failed to upsert session to LangChain API. {e}" - ) - except Exception as e: - raise LangChainTracerError( - f"Failed to upsert session to LangChain API. {e}" - ) from e - self.session = TracerSession(**response.json()) - return self.session - def _persist_run(self, run: Run) -> None: - """Persist a run.""" + """The Langchain Tracer uses Post/Patch rather than persist.""" @retry_decorator def _persist_run_single(self, run: Run) -> None: """Persist a run.""" - session = self.ensure_session() if run.parent_run_id is None: run.reference_example_id = self.example_id run_dict = run.dict() del run_dict["child_runs"] - run_create = RunCreate(**run_dict, session_id=session.id) + run_create = RunCreate(**run_dict, session_name=self.session_name) response = None try: + # TODO: Add retries when async response = requests.post( f"{self._endpoint}/runs", data=run_create.json(), diff --git a/langchain/callbacks/tracers/schemas.py b/langchain/callbacks/tracers/schemas.py index 8c5a7d92..4816b8b9 100644 --- a/langchain/callbacks/tracers/schemas.py +++ b/langchain/callbacks/tracers/schemas.py @@ -36,12 +36,6 @@ class TracerSessionBase(TracerSessionV1Base): tenant_id: UUID -class TracerSessionCreate(TracerSessionBase): - """A creation class for TracerSession.""" - - id: Optional[UUID] - - class TracerSession(TracerSessionBase): """TracerSessionV1 schema for the V2 API.""" @@ -136,7 +130,7 @@ class Run(RunBase): class RunCreate(RunBase): name: str - session_id: UUID + session_name: Optional[str] = None @root_validator(pre=True) def add_runtime_env(cls, values: Dict[str, Any]) -> Dict[str, Any]: diff --git a/langchain/client/langchain.py b/langchain/client/langchain.py index 837cff47..29d13924 100644 --- a/langchain/client/langchain.py +++ b/langchain/client/langchain.py @@ -10,7 +10,6 @@ from typing import ( Callable, Dict, Iterator, - List, Mapping, Optional, Sequence, @@ -26,7 +25,8 @@ from requests import Response from tenacity import retry, stop_after_attempt, wait_fixed from langchain.base_language import BaseLanguageModel -from langchain.callbacks.tracers.schemas import Run, TracerSession +from langchain.callbacks.tracers.schemas import Run as TracerRun +from langchain.callbacks.tracers.schemas import TracerSession from langchain.chains.base import Chain from langchain.client.models import ( APIFeedbackSource, @@ -54,6 +54,10 @@ logger = logging.getLogger(__name__) MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel] +class Run(TracerRun): + id: UUID + + def _get_link_stem(url: str) -> str: scheme = urlsplit(url).scheme netloc_prefix = urlsplit(url).netloc.split(":")[0] @@ -75,7 +79,6 @@ class LangChainPlusClient(BaseSettings): api_key: Optional[str] = Field(default=None, env="LANGCHAIN_API_KEY") api_url: str = Field(default="http://localhost:1984", env="LANGCHAIN_ENDPOINT") - tenant_id: Optional[str] = None @root_validator(pre=True) def validate_api_key_if_hosted(cls, values: Dict[str, Any]) -> Dict[str, Any]: @@ -87,31 +90,8 @@ class LangChainPlusClient(BaseSettings): raise ValueError( "API key must be provided when using hosted LangChain+ API" ) - tenant_id = values.get("tenant_id") - if not tenant_id: - values["tenant_id"] = LangChainPlusClient._get_seeded_tenant_id( - api_url, api_key - ) return values - @staticmethod - @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5)) - def _get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str: - """Get the tenant ID from the seeded tenant.""" - url = f"{api_url}/tenants" - headers = {"x-api-key": api_key} if api_key else {} - response = requests.get(url, headers=headers) - try: - raise_for_status_with_text(response) - except Exception as e: - raise ValueError( - "Unable to get default tenant ID. Please manually provide." - ) from e - results: List[dict] = response.json() - if len(results) == 0: - raise ValueError("No seeded tenant found") - return results[0]["id"] - @staticmethod def _get_session_name( session_name: Optional[str], @@ -149,18 +129,10 @@ class LangChainPlusClient(BaseSettings): headers["x-api-key"] = self.api_key return headers - @property - def query_params(self) -> Dict[str, Any]: - """Get the headers for the API request.""" - return {"tenant_id": self.tenant_id} - def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Response: """Make a GET request.""" - query_params = self.query_params - if params: - query_params.update(params) return requests.get( - f"{self.api_url}{path}", headers=self._headers, params=query_params + f"{self.api_url}{path}", headers=self._headers, params=params ) def upload_dataframe( @@ -192,7 +164,6 @@ class LangChainPlusClient(BaseSettings): "input_keys": ",".join(input_keys), "output_keys": ",".join(output_keys), "description": description, - "tenant_id": self.tenant_id, } response = requests.post( self.api_url + "/datasets/upload", @@ -244,7 +215,7 @@ class LangChainPlusClient(BaseSettings): ) -> TracerSession: """Read a session from the LangChain+ API.""" path = "/sessions" - params: Dict[str, Any] = {"limit": 1, "tenant_id": self.tenant_id} + params: Dict[str, Any] = {"limit": 1} if session_id is not None: path += f"/{session_id}" elif session_name is not None: @@ -291,7 +262,6 @@ class LangChainPlusClient(BaseSettings): ) -> Dataset: """Create a dataset in the LangChain+ API.""" dataset = DatasetCreate( - tenant_id=self.tenant_id, name=dataset_name, description=description, ) @@ -309,7 +279,7 @@ class LangChainPlusClient(BaseSettings): self, *, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None ) -> Dataset: path = "/datasets" - params: Dict[str, Any] = {"limit": 1, "tenant_id": self.tenant_id} + params: Dict[str, Any] = {"limit": 1} if dataset_id is not None: path += f"/{dataset_id}" elif dataset_name is not None: diff --git a/langchain/client/models.py b/langchain/client/models.py index ae9786f7..86f3a2fe 100644 --- a/langchain/client/models.py +++ b/langchain/client/models.py @@ -49,7 +49,6 @@ class ExampleUpdate(BaseModel): class DatasetBase(BaseModel): """Dataset base model.""" - tenant_id: UUID name: str description: Optional[str] = None @@ -68,6 +67,7 @@ class Dataset(DatasetBase): """Dataset ORM model.""" id: UUID + tenant_id: UUID created_at: datetime modified_at: Optional[datetime] = Field(default=None) diff --git a/langchain/client/runner_utils.py b/langchain/client/runner_utils.py index 2a59fbde..bba9b971 100644 --- a/langchain/client/runner_utils.py +++ b/langchain/client/runner_utils.py @@ -214,7 +214,6 @@ async def _tracer_initializer(session_name: Optional[str]) -> Optional[LangChain """ if session_name: tracer = LangChainTracer(session_name=session_name) - tracer.ensure_session() return tracer else: return None diff --git a/tests/integration_tests/callbacks/test_langchain_tracer.py b/tests/integration_tests/callbacks/test_langchain_tracer.py index 1caf5ebc..ca33cf91 100644 --- a/tests/integration_tests/callbacks/test_langchain_tracer.py +++ b/tests/integration_tests/callbacks/test_langchain_tracer.py @@ -148,8 +148,7 @@ def test_tracing_v2_context_manager() -> None: ) if "LANGCHAIN_TRACING_V2" in os.environ: del os.environ["LANGCHAIN_TRACING_V2"] - with tracing_v2_enabled() as session: - assert session + with tracing_v2_enabled(): agent.run(questions[0]) # this should be traced agent.run(questions[0]) # this should not be traced diff --git a/tests/unit_tests/client/test_langchain.py b/tests/unit_tests/client/test_langchain.py index b45ff3ca..3be03487 100644 --- a/tests/unit_tests/client/test_langchain.py +++ b/tests/unit_tests/client/test_langchain.py @@ -2,14 +2,12 @@ import uuid from datetime import datetime from io import BytesIO -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Union from unittest import mock import pytest from langchain.base_language import BaseLanguageModel -from langchain.callbacks.tracers.langchain import LangChainTracer -from langchain.callbacks.tracers.schemas import TracerSession from langchain.chains.base import Chain from langchain.client.langchain import ( LangChainPlusClient, @@ -46,39 +44,23 @@ def test_is_localhost() -> None: assert not _is_localhost("http://example.com:8000") -def test_validate_api_key_if_hosted() -> None: - def mock_get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str: - return _TENANT_ID +def test_validate_api_key_if_hosted(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False) + with pytest.raises(ValueError, match="API key must be provided"): + LangChainPlusClient(api_url="http://www.example.com") - with mock.patch.object( - LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id - ): - with pytest.raises(ValueError, match="API key must be provided"): - LangChainPlusClient(api_url="http://www.example.com") + client = LangChainPlusClient(api_url="http://localhost:8000") + assert client.api_url == "http://localhost:8000" + assert client.api_key is None - with mock.patch.object( - LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id - ): - client = LangChainPlusClient(api_url="http://localhost:8000") - assert client.api_url == "http://localhost:8000" - assert client.api_key is None +def test_headers(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False) + client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123") + assert client._headers == {"x-api-key": "123"} -def test_headers() -> None: - def mock_get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str: - return _TENANT_ID - - with mock.patch.object( - LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id - ): - client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123") - assert client._headers == {"x-api-key": "123"} - - with mock.patch.object( - LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id - ): - client_no_key = LangChainPlusClient(api_url="http://localhost:8000") - assert client_no_key._headers == {} + client_no_key = LangChainPlusClient(api_url="http://localhost:8000") + assert client_no_key._headers == {} @mock.patch("langchain.client.langchain.requests.post") @@ -112,7 +94,8 @@ def test_upload_csv(mock_post: mock.Mock) -> None: mock_post.return_value = mock_response client = LangChainPlusClient( - api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID + api_url="http://localhost:8000", + api_key="123", ) csv_file = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n")) @@ -196,22 +179,14 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None: {"result": f"Result for example {example.id}"} for _ in range(n_repetitions) ] - def mock_ensure_session(self: Any, *args: Any, **kwargs: Any) -> TracerSession: - return TracerSession(name="test_session", tenant_id=_TENANT_ID, id=uuid.uuid4()) - with mock.patch.object( LangChainPlusClient, "read_dataset", new=mock_read_dataset ), mock.patch.object( LangChainPlusClient, "list_examples", new=mock_list_examples ), mock.patch( "langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain - ), mock.patch.object( - LangChainTracer, "ensure_session", new=mock_ensure_session ): - monkeypatch.setenv("LANGCHAIN_TENANT_ID", _TENANT_ID) - client = LangChainPlusClient( - api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID - ) + client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123") chain = mock.MagicMock() num_repetitions = 3 results = await client.arun_on_dataset(