Update Tracer Auth / Reduce Num Calls (#5517)

Update the session creation and calls

---------

Co-authored-by: Ankush Gola <ankush.gola@gmail.com>
searx_updates
Zander Chase 12 months ago committed by GitHub
parent 949729ff5c
commit 20ec1173f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -23,7 +23,6 @@ from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
from langchain.callbacks.tracers.schemas import TracerSession
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
from langchain.callbacks.tracers.wandb import WandbTracer
from langchain.schema import (
@ -99,26 +98,21 @@ def tracing_v2_enabled(
session_name: Optional[str] = None,
*,
example_id: Optional[Union[str, UUID]] = None,
tenant_id: Optional[str] = None,
session_extra: Optional[Dict[str, Any]] = None,
) -> Generator[TracerSession, None, None]:
) -> Generator[None, None, None]:
"""Get the experimental tracer handler in a context manager."""
# Issue a warning that this is experimental
warnings.warn(
"The experimental tracing v2 is in development. "
"The tracing v2 API is in development. "
"This is not yet stable and may change in the future."
)
if isinstance(example_id, str):
example_id = UUID(example_id)
cb = LangChainTracer(
tenant_id=tenant_id,
session_name=session_name,
example_id=example_id,
session_extra=session_extra,
session_name=session_name,
)
session = cb.ensure_session()
tracing_v2_callback_var.set(cb)
yield session
yield
tracing_v2_callback_var.set(None)
@ -919,7 +913,6 @@ def _configure(
else:
try:
handler = LangChainTracer(session_name=tracer_session)
handler.ensure_session()
callback_manager.add_handler(handler, True)
except Exception as e:
logger.warning(

@ -25,10 +25,8 @@ from langchain.callbacks.tracers.schemas import (
RunTypeEnum,
RunUpdate,
TracerSession,
TracerSessionCreate,
)
from langchain.schema import BaseMessage, messages_to_dict
from langchain.utils import raise_for_status_with_text
logger = logging.getLogger(__name__)
@ -65,49 +63,13 @@ retry_decorator = retry(
)
@retry_decorator
def _get_tenant_id(
tenant_id: Optional[str], endpoint: Optional[str], headers: Optional[dict]
) -> str:
"""Get the tenant ID for the LangChain API."""
tenant_id_: Optional[str] = tenant_id or os.getenv("LANGCHAIN_TENANT_ID")
if tenant_id_:
return tenant_id_
endpoint_ = endpoint or get_endpoint()
headers_ = headers or get_headers()
response = None
try:
response = requests.get(endpoint_ + "/tenants", headers=headers_)
raise_for_status_with_text(response)
except HTTPError as e:
if response is not None and response.status_code == 500:
raise LangChainTracerAPIError(
f"Failed to get tenant ID from LangChain API. {e}"
)
else:
raise LangChainTracerUserError(
f"Failed to get tenant ID from LangChain API. {e}"
)
except Exception as e:
raise LangChainTracerError(
f"Failed to get tenant ID from LangChain API. {e}"
) from e
tenants: List[Dict[str, Any]] = response.json()
if not tenants:
raise ValueError(f"No tenants found for URL {endpoint_}")
return tenants[0]["id"]
class LangChainTracer(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
def __init__(
self,
tenant_id: Optional[str] = None,
example_id: Optional[UUID] = None,
session_name: Optional[str] = None,
session_extra: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Initialize the LangChain tracer."""
@ -115,10 +77,8 @@ class LangChainTracer(BaseTracer):
self.session: Optional[TracerSession] = None
self._endpoint = get_endpoint()
self._headers = get_headers()
self.tenant_id = tenant_id
self.example_id = example_id
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
self.session_extra = session_extra
# set max_workers to 1 to process tasks in order
self.executor = ThreadPoolExecutor(max_workers=1)
@ -149,62 +109,20 @@ class LangChainTracer(BaseTracer):
self._start_trace(chat_model_run)
self._on_chat_model_start(chat_model_run)
def ensure_tenant_id(self) -> str:
"""Load or use the tenant ID."""
tenant_id = self.tenant_id or _get_tenant_id(
self.tenant_id, self._endpoint, self._headers
)
self.tenant_id = tenant_id
return tenant_id
@retry_decorator
def ensure_session(self) -> TracerSession:
"""Upsert a session."""
if self.session is not None:
return self.session
tenant_id = self.ensure_tenant_id()
url = f"{self._endpoint}/sessions?upsert=true"
session_create = TracerSessionCreate(
name=self.session_name, extra=self.session_extra, tenant_id=tenant_id
)
response = None
try:
response = requests.post(
url,
data=session_create.json(),
headers=self._headers,
)
response.raise_for_status()
except HTTPError as e:
if response is not None and response.status_code == 500:
raise LangChainTracerAPIError(
f"Failed to upsert session to LangChain API. {e}"
)
else:
raise LangChainTracerUserError(
f"Failed to upsert session to LangChain API. {e}"
)
except Exception as e:
raise LangChainTracerError(
f"Failed to upsert session to LangChain API. {e}"
) from e
self.session = TracerSession(**response.json())
return self.session
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
"""The Langchain Tracer uses Post/Patch rather than persist."""
@retry_decorator
def _persist_run_single(self, run: Run) -> None:
"""Persist a run."""
session = self.ensure_session()
if run.parent_run_id is None:
run.reference_example_id = self.example_id
run_dict = run.dict()
del run_dict["child_runs"]
run_create = RunCreate(**run_dict, session_id=session.id)
run_create = RunCreate(**run_dict, session_name=self.session_name)
response = None
try:
# TODO: Add retries when async
response = requests.post(
f"{self._endpoint}/runs",
data=run_create.json(),

@ -36,12 +36,6 @@ class TracerSessionBase(TracerSessionV1Base):
tenant_id: UUID
class TracerSessionCreate(TracerSessionBase):
"""A creation class for TracerSession."""
id: Optional[UUID]
class TracerSession(TracerSessionBase):
"""TracerSessionV1 schema for the V2 API."""
@ -136,7 +130,7 @@ class Run(RunBase):
class RunCreate(RunBase):
name: str
session_id: UUID
session_name: Optional[str] = None
@root_validator(pre=True)
def add_runtime_env(cls, values: Dict[str, Any]) -> Dict[str, Any]:

@ -10,7 +10,6 @@ from typing import (
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
@ -26,7 +25,8 @@ from requests import Response
from tenacity import retry, stop_after_attempt, wait_fixed
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.tracers.schemas import Run, TracerSession
from langchain.callbacks.tracers.schemas import Run as TracerRun
from langchain.callbacks.tracers.schemas import TracerSession
from langchain.chains.base import Chain
from langchain.client.models import (
APIFeedbackSource,
@ -54,6 +54,10 @@ logger = logging.getLogger(__name__)
MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
class Run(TracerRun):
id: UUID
def _get_link_stem(url: str) -> str:
scheme = urlsplit(url).scheme
netloc_prefix = urlsplit(url).netloc.split(":")[0]
@ -75,7 +79,6 @@ class LangChainPlusClient(BaseSettings):
api_key: Optional[str] = Field(default=None, env="LANGCHAIN_API_KEY")
api_url: str = Field(default="http://localhost:1984", env="LANGCHAIN_ENDPOINT")
tenant_id: Optional[str] = None
@root_validator(pre=True)
def validate_api_key_if_hosted(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@ -87,31 +90,8 @@ class LangChainPlusClient(BaseSettings):
raise ValueError(
"API key must be provided when using hosted LangChain+ API"
)
tenant_id = values.get("tenant_id")
if not tenant_id:
values["tenant_id"] = LangChainPlusClient._get_seeded_tenant_id(
api_url, api_key
)
return values
@staticmethod
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
def _get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
"""Get the tenant ID from the seeded tenant."""
url = f"{api_url}/tenants"
headers = {"x-api-key": api_key} if api_key else {}
response = requests.get(url, headers=headers)
try:
raise_for_status_with_text(response)
except Exception as e:
raise ValueError(
"Unable to get default tenant ID. Please manually provide."
) from e
results: List[dict] = response.json()
if len(results) == 0:
raise ValueError("No seeded tenant found")
return results[0]["id"]
@staticmethod
def _get_session_name(
session_name: Optional[str],
@ -149,18 +129,10 @@ class LangChainPlusClient(BaseSettings):
headers["x-api-key"] = self.api_key
return headers
@property
def query_params(self) -> Dict[str, Any]:
"""Get the headers for the API request."""
return {"tenant_id": self.tenant_id}
def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Response:
"""Make a GET request."""
query_params = self.query_params
if params:
query_params.update(params)
return requests.get(
f"{self.api_url}{path}", headers=self._headers, params=query_params
f"{self.api_url}{path}", headers=self._headers, params=params
)
def upload_dataframe(
@ -192,7 +164,6 @@ class LangChainPlusClient(BaseSettings):
"input_keys": ",".join(input_keys),
"output_keys": ",".join(output_keys),
"description": description,
"tenant_id": self.tenant_id,
}
response = requests.post(
self.api_url + "/datasets/upload",
@ -244,7 +215,7 @@ class LangChainPlusClient(BaseSettings):
) -> TracerSession:
"""Read a session from the LangChain+ API."""
path = "/sessions"
params: Dict[str, Any] = {"limit": 1, "tenant_id": self.tenant_id}
params: Dict[str, Any] = {"limit": 1}
if session_id is not None:
path += f"/{session_id}"
elif session_name is not None:
@ -291,7 +262,6 @@ class LangChainPlusClient(BaseSettings):
) -> Dataset:
"""Create a dataset in the LangChain+ API."""
dataset = DatasetCreate(
tenant_id=self.tenant_id,
name=dataset_name,
description=description,
)
@ -309,7 +279,7 @@ class LangChainPlusClient(BaseSettings):
self, *, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None
) -> Dataset:
path = "/datasets"
params: Dict[str, Any] = {"limit": 1, "tenant_id": self.tenant_id}
params: Dict[str, Any] = {"limit": 1}
if dataset_id is not None:
path += f"/{dataset_id}"
elif dataset_name is not None:

@ -49,7 +49,6 @@ class ExampleUpdate(BaseModel):
class DatasetBase(BaseModel):
"""Dataset base model."""
tenant_id: UUID
name: str
description: Optional[str] = None
@ -68,6 +67,7 @@ class Dataset(DatasetBase):
"""Dataset ORM model."""
id: UUID
tenant_id: UUID
created_at: datetime
modified_at: Optional[datetime] = Field(default=None)

@ -214,7 +214,6 @@ async def _tracer_initializer(session_name: Optional[str]) -> Optional[LangChain
"""
if session_name:
tracer = LangChainTracer(session_name=session_name)
tracer.ensure_session()
return tracer
else:
return None

@ -148,8 +148,7 @@ def test_tracing_v2_context_manager() -> None:
)
if "LANGCHAIN_TRACING_V2" in os.environ:
del os.environ["LANGCHAIN_TRACING_V2"]
with tracing_v2_enabled() as session:
assert session
with tracing_v2_enabled():
agent.run(questions[0]) # this should be traced
agent.run(questions[0]) # this should not be traced

@ -2,14 +2,12 @@
import uuid
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Union
from unittest import mock
import pytest
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.schemas import TracerSession
from langchain.chains.base import Chain
from langchain.client.langchain import (
LangChainPlusClient,
@ -46,39 +44,23 @@ def test_is_localhost() -> None:
assert not _is_localhost("http://example.com:8000")
def test_validate_api_key_if_hosted() -> None:
def mock_get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
return _TENANT_ID
def test_validate_api_key_if_hosted(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
with pytest.raises(ValueError, match="API key must be provided"):
LangChainPlusClient(api_url="http://www.example.com")
with mock.patch.object(
LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
):
with pytest.raises(ValueError, match="API key must be provided"):
LangChainPlusClient(api_url="http://www.example.com")
client = LangChainPlusClient(api_url="http://localhost:8000")
assert client.api_url == "http://localhost:8000"
assert client.api_key is None
with mock.patch.object(
LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
):
client = LangChainPlusClient(api_url="http://localhost:8000")
assert client.api_url == "http://localhost:8000"
assert client.api_key is None
def test_headers(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
assert client._headers == {"x-api-key": "123"}
def test_headers() -> None:
def mock_get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
return _TENANT_ID
with mock.patch.object(
LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
):
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
assert client._headers == {"x-api-key": "123"}
with mock.patch.object(
LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
):
client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
assert client_no_key._headers == {}
client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
assert client_no_key._headers == {}
@mock.patch("langchain.client.langchain.requests.post")
@ -112,7 +94,8 @@ def test_upload_csv(mock_post: mock.Mock) -> None:
mock_post.return_value = mock_response
client = LangChainPlusClient(
api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
api_url="http://localhost:8000",
api_key="123",
)
csv_file = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n"))
@ -196,22 +179,14 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
]
def mock_ensure_session(self: Any, *args: Any, **kwargs: Any) -> TracerSession:
return TracerSession(name="test_session", tenant_id=_TENANT_ID, id=uuid.uuid4())
with mock.patch.object(
LangChainPlusClient, "read_dataset", new=mock_read_dataset
), mock.patch.object(
LangChainPlusClient, "list_examples", new=mock_list_examples
), mock.patch(
"langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain
), mock.patch.object(
LangChainTracer, "ensure_session", new=mock_ensure_session
):
monkeypatch.setenv("LANGCHAIN_TENANT_ID", _TENANT_ID)
client = LangChainPlusClient(
api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
)
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
chain = mock.MagicMock()
num_repetitions = 3
results = await client.arun_on_dataset(

Loading…
Cancel
Save