diff --git a/langchain/client/__init__.py b/langchain/client/__init__.py
new file mode 100644
index 0000000000..89a34615a5
--- /dev/null
+++ b/langchain/client/__init__.py
@@ -0,0 +1,6 @@
+"""LangChain+ Client."""
+
+
+from langchain.client.langchain import LangChainPlusClient
+
+__all__ = ["LangChainPlusClient"]
diff --git a/langchain/client/langchain.py b/langchain/client/langchain.py
new file mode 100644
index 0000000000..5e01e19014
--- /dev/null
+++ b/langchain/client/langchain.py
@@ -0,0 +1,544 @@
+from __future__ import annotations
+
+import asyncio
+import functools
+import logging
+import socket
+from datetime import datetime
+from io import BytesIO
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Coroutine,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
+from urllib.parse import urlsplit
+from uuid import UUID
+
+import requests
+from pydantic import BaseSettings, Field, root_validator
+from requests import Response
+
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import tracing_v2_enabled
+from langchain.callbacks.tracers.langchain import LangChainTracerV2
+from langchain.chains.base import Chain
+from langchain.chat_models.base import BaseChatModel
+from langchain.client.models import Dataset, DatasetCreate, Example, ExampleCreate
+from langchain.client.utils import parse_chat_messages
+from langchain.llms.base import BaseLLM
+from langchain.schema import ChatResult, LLMResult
+from langchain.utils import raise_for_status_with_text, xor_args
+
+if TYPE_CHECKING:
+ import pandas as pd
+
+logger = logging.getLogger(__name__)
+
+
+def _get_link_stem(url: str) -> str:
+ scheme = urlsplit(url).scheme
+ netloc_prefix = urlsplit(url).netloc.split(":")[0]
+ return f"{scheme}://{netloc_prefix}"
+
+
+def _is_localhost(url: str) -> bool:
+ """Check if the URL is localhost."""
+ try:
+ netloc = urlsplit(url).netloc.split(":")[0]
+ ip = socket.gethostbyname(netloc)
+ return ip == "127.0.0.1" or ip.startswith("0.0.0.0") or ip.startswith("::")
+ except socket.gaierror:
+ return False
+
+
+class LangChainPlusClient(BaseSettings):
+ """Client for interacting with the LangChain+ API."""
+
+ api_key: Optional[str] = Field(default=None, env="LANGCHAIN_API_KEY")
+ api_url: str = Field(..., env="LANGCHAIN_ENDPOINT")
+ tenant_id: str = Field(..., env="LANGCHAIN_TENANT_ID")
+
    @root_validator(pre=True)
    def validate_api_key_if_hosted(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Verify API key is provided if url not localhost.

        Runs before field validation (``pre=True``) so it can also fill in a
        missing ``tenant_id`` for local deployments before the required-field
        check fires.
        """
        api_url: str = values.get("api_url", "http://localhost:8000")
        api_key: Optional[str] = values.get("api_key")
        if not _is_localhost(api_url):
            if not api_key:
                raise ValueError(
                    "API key must be provided when using hosted LangChain+ API"
                )
        else:
            tenant_id = values.get("tenant_id")
            if not tenant_id:
                # Local servers seed a default tenant; resolve it so users do
                # not need to configure LANGCHAIN_TENANT_ID explicitly.
                values["tenant_id"] = LangChainPlusClient._get_seeded_tenant_id(
                    api_url, api_key
                )
        return values
+
+ @staticmethod
+ def _get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
+ """Get the tenant ID from the seeded tenant."""
+ url = f"{api_url}/tenants"
+ headers = {"authorization": f"Bearer {api_key}"} if api_key else {}
+ response = requests.get(url, headers=headers)
+ try:
+ raise_for_status_with_text(response)
+ except Exception as e:
+ raise ValueError(
+ "Unable to get seeded tenant ID. Please manually provide."
+ ) from e
+ results: List[dict] = response.json()
+ breakpoint()
+ if len(results) == 0:
+ raise ValueError("No seeded tenant found")
+ return results[0]["id"]
+
+ def _repr_html_(self) -> str:
+ """Return an HTML representation of the instance with a link to the URL."""
+ link = _get_link_stem(self.api_url)
+ return f'LangChain+ Client'
+
+ def __repr__(self) -> str:
+ """Return a string representation of the instance with a link to the URL."""
+ return f"LangChainPlusClient (API URL: {self.api_url})"
+
+ @property
+ def _headers(self) -> Dict[str, str]:
+ """Get the headers for the API request."""
+ headers = {}
+ if self.api_key:
+ headers["authorization"] = f"Bearer {self.api_key}"
+ return headers
+
+ @property
+ def query_params(self) -> Dict[str, str]:
+ """Get the headers for the API request."""
+ return {"tenant_id": self.tenant_id}
+
+ def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Response:
+ """Make a GET request."""
+ query_params = self.query_params
+ if params:
+ query_params.update(params)
+ return requests.get(
+ f"{self.api_url}{path}", headers=self._headers, params=query_params
+ )
+
    def upload_dataframe(
        self,
        df: pd.DataFrame,
        name: str,
        description: str,
        input_keys: List[str],
        output_keys: List[str],
    ) -> Dataset:
        """Upload a dataframe as individual examples to the LangChain+ API.

        Args:
            df: DataFrame whose columns include every input/output key.
            name: Name for the new dataset.
            description: Description for the new dataset.
            input_keys: Columns treated as example inputs.
            output_keys: Columns treated as example outputs.

        Returns:
            The created dataset. Examples are uploaded one request per row.
        """
        dataset = self.create_dataset(dataset_name=name, description=description)
        for row in df.itertuples():
            inputs = {key: getattr(row, key) for key in input_keys}
            outputs = {key: getattr(row, key) for key in output_keys}
            self.create_example(inputs, outputs=outputs, dataset_id=dataset.id)
        return dataset
+
    def upload_csv(
        self,
        csv_file: Union[str, Tuple[str, BytesIO]],
        description: str,
        input_keys: List[str],
        output_keys: List[str],
    ) -> Dataset:
        """Upload a CSV file to the LangChain+ API.

        Args:
            csv_file: Either a path to a CSV file or a ``(name, buffer)``
                tuple as accepted by ``requests``' ``files`` parameter.
            description: Description for the created dataset.
            input_keys: CSV columns treated as example inputs.
            output_keys: CSV columns treated as example outputs.

        Returns:
            The created dataset.

        Raises:
            ValueError: If a dataset with the same file name already exists.
        """
        files = {"file": csv_file}
        data = {
            "input_keys": ",".join(input_keys),
            "output_keys": ",".join(output_keys),
            "description": description,
            "tenant_id": self.tenant_id,
        }
        response = requests.post(
            self.api_url + "/datasets/upload",
            headers=self._headers,
            data=data,
            files=files,
        )
        raise_for_status_with_text(response)
        result = response.json()
        # TODO: Make this more robust server-side
        # The server may report a duplicate via a 2xx body with a "detail"
        # message rather than an error status; surface it as a ValueError.
        if "detail" in result and "already exists" in result["detail"]:
            file_name = csv_file if isinstance(csv_file, str) else csv_file[0]
            file_name = file_name.split("/")[-1]
            raise ValueError(f"Dataset {file_name} already exists")
        return Dataset(**result)
+
+ def create_dataset(self, dataset_name: str, description: str) -> Dataset:
+ """Create a dataset in the LangChain+ API."""
+ dataset = DatasetCreate(
+ tenant_id=self.tenant_id,
+ name=dataset_name,
+ description=description,
+ )
+ response = requests.post(
+ self.api_url + "/datasets",
+ headers=self._headers,
+ data=dataset.json(),
+ )
+ raise_for_status_with_text(response)
+ return Dataset(**response.json())
+
    @xor_args(("dataset_name", "dataset_id"))
    def read_dataset(
        self, *, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None
    ) -> Dataset:
        """Read a dataset by name or ID.

        Keyword-only; exactly one of ``dataset_name``/``dataset_id`` must be
        given (``xor_args`` enforces mutual exclusion).

        Raises:
            ValueError: If neither identifier is provided, or the named
                dataset does not exist.
        """
        path = "/datasets"
        params: Dict[str, Any] = {"limit": 1, "tenant_id": self.tenant_id}
        if dataset_id is not None:
            path += f"/{dataset_id}"
        elif dataset_name is not None:
            params["name"] = dataset_name
        else:
            raise ValueError("Must provide dataset_name or dataset_id")
        response = self._get(
            path,
            params=params,
        )
        raise_for_status_with_text(response)
        result = response.json()
        # Name lookups hit the collection endpoint and return a list;
        # ID lookups return a single object.
        if isinstance(result, list):
            if len(result) == 0:
                raise ValueError(f"Dataset {dataset_name} not found")
            return Dataset(**result[0])
        return Dataset(**result)
+
+ def list_datasets(self, limit: int = 100) -> Iterable[Dataset]:
+ """List the datasets on the LangChain+ API."""
+ response = self._get("/datasets", params={"limit": limit})
+ raise_for_status_with_text(response)
+ return [Dataset(**dataset) for dataset in response.json()]
+
+ @xor_args(("dataset_id", "dataset_name"))
+ def delete_dataset(
+ self, *, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None
+ ) -> Dataset:
+ """Delete a dataset by ID or name."""
+ if dataset_name is not None:
+ dataset_id = self.read_dataset(dataset_name=dataset_name).id
+ if dataset_id is None:
+ raise ValueError("Must provide either dataset name or ID")
+ response = requests.delete(
+ f"{self.api_url}/datasets/{dataset_id}",
+ headers=self._headers,
+ )
+ raise_for_status_with_text(response)
+ return response.json()
+
+ @xor_args(("dataset_id", "dataset_name"))
+ def create_example(
+ self,
+ inputs: Dict[str, Any],
+ dataset_id: Optional[UUID] = None,
+ dataset_name: Optional[str] = None,
+ created_at: Optional[datetime] = None,
+ outputs: Dict[str, Any] | None = None,
+ ) -> Example:
+ """Create a dataset example in the LangChain+ API."""
+ if dataset_id is None:
+ dataset_id = self.read_dataset(dataset_name).id
+
+ data = {
+ "inputs": inputs,
+ "outputs": outputs,
+ "dataset_id": dataset_id,
+ }
+ if created_at:
+ data["created_at"] = created_at.isoformat()
+ example = ExampleCreate(**data)
+ response = requests.post(
+ f"{self.api_url}/examples", headers=self._headers, data=example.json()
+ )
+ raise_for_status_with_text(response)
+ result = response.json()
+ return Example(**result)
+
+ def read_example(self, example_id: str) -> Example:
+ """Read an example from the LangChain+ API."""
+ response = self._get(f"/examples/{example_id}")
+ raise_for_status_with_text(response)
+ return Example(**response.json())
+
+ def list_examples(
+ self, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None
+ ) -> Iterable[Example]:
+ """List the datasets on the LangChain+ API."""
+ params = {}
+ if dataset_id is not None:
+ params["dataset"] = dataset_id
+ elif dataset_name is not None:
+ dataset_id = self.read_dataset(dataset_name=dataset_name).id
+ params["dataset"] = dataset_id
+ else:
+ pass
+ response = self._get("/examples", params=params)
+ raise_for_status_with_text(response)
+ return [Example(**dataset) for dataset in response.json()]
+
    @staticmethod
    async def _arun_llm(
        llm: BaseLanguageModel,
        inputs: Dict[str, Any],
        langchain_tracer: LangChainTracerV2,
    ) -> Union[LLMResult, ChatResult]:
        """Asynchronously run a language model over an example's prompts.

        Plain LLMs receive inputs["prompts"] directly; chat models get each
        prompt string parsed into a message list first.
        Assumes inputs contains a "prompts" key — TODO confirm against
        dataset schema.
        """
        if isinstance(llm, BaseLLM):
            llm_prompts: List[str] = inputs["prompts"]
            llm_output = await llm.agenerate(llm_prompts, callbacks=[langchain_tracer])
        elif isinstance(llm, BaseChatModel):
            chat_prompts: List[str] = inputs["prompts"]
            messages = [
                parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
            ]
            llm_output = await llm.agenerate(messages, callbacks=[langchain_tracer])
        else:
            raise ValueError(f"Unsupported LLM type {type(llm)}")
        return llm_output
+
+ @staticmethod
+ async def _arun_llm_or_chain(
+ example: Example,
+ langchain_tracer: LangChainTracerV2,
+ llm_or_chain: Union[Chain, BaseLanguageModel],
+ n_repetitions: int,
+ ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
+ """Run the chain asynchronously."""
+ previous_example_id = langchain_tracer.example_id
+ langchain_tracer.example_id = example.id
+ outputs = []
+ for _ in range(n_repetitions):
+ try:
+ if isinstance(llm_or_chain, BaseLanguageModel):
+ output: Any = await LangChainPlusClient._arun_llm(
+ llm_or_chain, example.inputs, langchain_tracer
+ )
+ else:
+ output = await llm_or_chain.arun(
+ example.inputs, callbacks=[langchain_tracer]
+ )
+ outputs.append(output)
+ except Exception as e:
+ logger.warning(f"Chain failed for example {example.id}. Error: {e}")
+ outputs.append({"Error": str(e)})
+ finally:
+ langchain_tracer.example_id = previous_example_id
+ return outputs
+
    @staticmethod
    async def _gather_with_concurrency(
        n: int,
        initializer: Callable[[], Coroutine[Any, Any, Tuple[LangChainTracerV2, Dict]]],
        *async_funcs: Callable[[LangChainTracerV2, Dict], Coroutine[Any, Any, Any]],
    ) -> List[Any]:
        """
        Run coroutines with a concurrency limit.

        Args:
            n: The maximum number of concurrent tasks.
            initializer: A coroutine that initializes shared resources for the tasks.
            async_funcs: The async_funcs to be run concurrently.

        Returns:
            A list of results from the coroutines.
        """
        semaphore = asyncio.Semaphore(n)
        # All tasks share one tracer and one mutable job_state dict; the
        # semaphore caps how many run at once.
        tracer, job_state = await initializer()

        async def run_coroutine_with_semaphore(
            async_func: Callable[[LangChainTracerV2, Dict], Coroutine[Any, Any, Any]]
        ) -> Any:
            async with semaphore:
                return await async_func(tracer, job_state)

        # gather preserves input order, so results line up with async_funcs.
        return await asyncio.gather(
            *(run_coroutine_with_semaphore(function) for function in async_funcs)
        )
+
    async def _tracer_initializer(
        self, session_name: str
    ) -> Tuple[LangChainTracerV2, dict]:
        """
        Initialize a tracer to share across tasks.

        Args:
            session_name: The session name for the tracer.

        Returns:
            A LangChainTracerV2 instance with an active session, plus a
            mutable job-state dict (tracks num_processed across tasks).
        """
        job_state = {"num_processed": 0}
        # NOTE(review): the tracer (and its session) escape this context
        # manager via the `return` and are used after it exits — confirm the
        # session remains valid once tracing_v2_enabled unwinds.
        with tracing_v2_enabled(session_name=session_name) as session:
            tracer = LangChainTracerV2()
            tracer.session = session
            return tracer, job_state
+
    async def arun_on_dataset(
        self,
        dataset_name: str,
        llm_or_chain: Union[Chain, BaseLanguageModel],
        concurrency_level: int = 5,
        num_repetitions: int = 1,
        session_name: Optional[str] = None,
        verbose: bool = False,
    ) -> Dict[str, Any]:
        """
        Run the chain on a dataset and store traces to the specified session name.

        Args:
            dataset_name: Name of the dataset to run the chain on.
            llm_or_chain: Chain or language model to run over the dataset.
            concurrency_level: The number of async tasks to run concurrently.
            num_repetitions: Number of times to run the model on each example.
                This is useful when testing success rates or generating confidence
                intervals.
            session_name: Name of the session to store the traces in.
                Defaults to {dataset_name}-{chain class name}-{datetime}.
            verbose: Whether to print progress.

        Returns:
            A dictionary mapping example ids to the model outputs.
        """
        if session_name is None:
            current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            session_name = (
                f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}"
            )
        dataset = self.read_dataset(dataset_name=dataset_name)
        examples = self.list_examples(dataset_id=str(dataset.id))
        results: Dict[str, List[Any]] = {}

        async def process_example(
            example: Example, tracer: LangChainTracerV2, job_state: dict
        ) -> None:
            """Process a single example."""
            result = await LangChainPlusClient._arun_llm_or_chain(
                example,
                tracer,
                llm_or_chain,
                num_repetitions,
            )
            # `results` and `job_state` are closed over and mutated from the
            # concurrently running tasks (single event loop, so no lock needed).
            results[str(example.id)] = result
            job_state["num_processed"] += 1
            if verbose:
                print(
                    f"Processed examples: {job_state['num_processed']}",
                    end="\r",
                    flush=True,
                )

        await self._gather_with_concurrency(
            concurrency_level,
            functools.partial(self._tracer_initializer, session_name),
            *(functools.partial(process_example, e) for e in examples),
        )
        return results
+
+ @staticmethod
+ def run_llm(
+ llm: BaseLanguageModel,
+ inputs: Dict[str, Any],
+ langchain_tracer: LangChainTracerV2,
+ ) -> Union[LLMResult, ChatResult]:
+ """Run the language model on the example."""
+ if isinstance(llm, BaseLLM):
+ llm_prompts: List[str] = inputs["prompts"]
+ llm_output = llm.generate(llm_prompts, callbacks=[langchain_tracer])
+ elif isinstance(llm, BaseChatModel):
+ chat_prompts: List[str] = inputs["prompts"]
+ messages = [
+ parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
+ ]
+ llm_output = llm.generate(messages, callbacks=[langchain_tracer])
+ else:
+ raise ValueError(f"Unsupported LLM type {type(llm)}")
+ return llm_output
+
+ @staticmethod
+ def run_llm_or_chain(
+ example: Example,
+ langchain_tracer: LangChainTracerV2,
+ llm_or_chain: Union[Chain, BaseLanguageModel],
+ n_repetitions: int,
+ ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
+ """Run the chain synchronously."""
+ previous_example_id = langchain_tracer.example_id
+ langchain_tracer.example_id = example.id
+ outputs = []
+ for _ in range(n_repetitions):
+ try:
+ if isinstance(llm_or_chain, BaseLanguageModel):
+ output: Any = LangChainPlusClient.run_llm(
+ llm_or_chain, example.inputs, langchain_tracer
+ )
+ else:
+ output = llm_or_chain.run(
+ example.inputs, callbacks=[langchain_tracer]
+ )
+ outputs.append(output)
+ except Exception as e:
+ logger.warning(f"Chain failed for example {example.id}. Error: {e}")
+ outputs.append({"Error": str(e)})
+ finally:
+ langchain_tracer.example_id = previous_example_id
+ return outputs
+
    def run_on_dataset(
        self,
        dataset_name: str,
        llm_or_chain: Union[Chain, BaseLanguageModel],
        num_repetitions: int = 1,
        session_name: Optional[str] = None,
        verbose: bool = False,
    ) -> Dict[str, Any]:
        """Run the chain on a dataset and store traces to the specified session name.

        Synchronous counterpart of `arun_on_dataset`; examples are processed
        one at a time, so there is no concurrency parameter.

        Args:
            dataset_name: Name of the dataset to run the chain on.
            llm_or_chain: Chain or language model to run over the dataset.
            num_repetitions: Number of times to run the model on each example.
                This is useful when testing success rates or generating confidence
                intervals.
            session_name: Name of the session to store the traces in.
                Defaults to {dataset_name}-{chain class name}-{datetime}.
            verbose: Whether to print progress.

        Returns:
            A dictionary mapping example ids to the model outputs.
        """
        if session_name is None:
            current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            session_name = (
                f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}"
            )
        dataset = self.read_dataset(dataset_name=dataset_name)
        examples = self.list_examples(dataset_id=str(dataset.id))
        results: Dict[str, Any] = {}
        with tracing_v2_enabled(session_name=session_name) as session:
            tracer = LangChainTracerV2()
            tracer.session = session

            for i, example in enumerate(examples):
                result = self.run_llm_or_chain(
                    example,
                    tracer,
                    llm_or_chain,
                    num_repetitions,
                )
                if verbose:
                    print(f"{i+1} processed", flush=True, end="\r")
                results[str(example.id)] = result
        return results
diff --git a/langchain/client/models.py b/langchain/client/models.py
new file mode 100644
index 0000000000..a7a19b10ae
--- /dev/null
+++ b/langchain/client/models.py
@@ -0,0 +1,54 @@
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+from uuid import UUID
+
+from pydantic import BaseModel, Field
+
+from langchain.callbacks.tracers.schemas import Run
+
+
class ExampleBase(BaseModel):
    """Shared fields for dataset examples."""

    # Dataset this example belongs to.
    dataset_id: UUID
    # Input field name -> value fed to the model/chain.
    inputs: Dict[str, Any]
    # Optional expected outputs; None when the example is unlabeled.
    outputs: Optional[Dict[str, Any]] = Field(default=None)
+
+
class ExampleCreate(ExampleBase):
    """Payload for creating an example; the server may assign the ID."""

    # Optional client-supplied ID; server generates one when omitted.
    id: Optional[UUID]
    # Defaults to the request time when not supplied by the caller.
    created_at: datetime = Field(default_factory=datetime.utcnow)
+
+
class Example(ExampleBase):
    """Example as returned by the API (server-assigned fields present)."""

    id: UUID
    created_at: datetime
    # Set on update; None for never-modified examples.
    modified_at: Optional[datetime] = Field(default=None)
    # Traced runs recorded against this example.
    runs: List[Run] = Field(default_factory=list)
+
+
class DatasetBase(BaseModel):
    """Shared fields for datasets."""

    # Tenant that owns the dataset.
    tenant_id: UUID
    name: str
    description: str
+
+
class DatasetCreate(DatasetBase):
    """Payload for creating a dataset; the server may assign the ID."""

    # Optional client-supplied ID; server generates one when omitted.
    id: Optional[UUID]
    # Defaults to the request time when not supplied by the caller.
    created_at: datetime = Field(default_factory=datetime.utcnow)
+
+
class Dataset(DatasetBase):
    """Dataset as returned by the API (server-assigned fields present)."""

    id: UUID
    created_at: datetime
    # Set on update; None for never-modified datasets.
    modified_at: Optional[datetime] = Field(default=None)
diff --git a/langchain/client/utils.py b/langchain/client/utils.py
new file mode 100644
index 0000000000..f7ce264ce8
--- /dev/null
+++ b/langchain/client/utils.py
@@ -0,0 +1,42 @@
+"""Client Utils."""
+import re
+from typing import Dict, List, Optional, Sequence, Type, Union
+
+from langchain.schema import (
+ AIMessage,
+ BaseMessage,
+ ChatMessage,
+ HumanMessage,
+ SystemMessage,
+)
+
# Message classes that map 1:1 onto a default role prefix.
_DEFAULT_MESSAGES_T = Union[Type[HumanMessage], Type[SystemMessage], Type[AIMessage]]
# Role prefix -> concrete message class; roles outside this map fall back to
# a generic ChatMessage in parse_chat_messages.
_RESOLUTION_MAP: Dict[str, _DEFAULT_MESSAGES_T] = {
    "Human": HumanMessage,
    "AI": AIMessage,
    "System": SystemMessage,
}
+
+
def parse_chat_messages(
    input_text: str, roles: Optional[Sequence[str]] = None
) -> List[BaseMessage]:
    """Parse chat messages from a string. This is not robust.

    Args:
        input_text: Transcript of ``Role: message`` segments; a message may
            span multiple lines and runs until the next role prefix.
        roles: Role names to split on. Defaults to ["Human", "AI", "System"].

    Returns:
        One message per role-prefixed segment. Roles in the default
        resolution map yield typed messages; any other role yields a
        ChatMessage with the role preserved.
    """
    roles = roles or ["Human", "AI", "System"]
    roles_pattern = "|".join(roles)
    # Fix: the named groups had lost their angle brackets ("(?P{...})" is
    # invalid regex). The loop below reads match.group("entity") and
    # match.group("message"), so the groups must be (?P<entity>...) and
    # (?P<message>...). The lazy body plus the lookahead stops each message
    # at the next "Role: " prefix or end of input.
    pattern = (
        rf"(?P<entity>{roles_pattern}): (?P<message>"
        rf"(?:.*\n?)*?)(?=(?:{roles_pattern}): |\Z)"
    )
    matches = re.finditer(pattern, input_text, re.MULTILINE)

    results: List[BaseMessage] = []
    for match in matches:
        entity = match.group("entity")
        message = match.group("message").rstrip("\n")
        if entity in _RESOLUTION_MAP:
            results.append(_RESOLUTION_MAP[entity](content=message))
        else:
            results.append(ChatMessage(role=entity, content=message))

    return results
diff --git a/tests/unit_tests/client/__init__.py b/tests/unit_tests/client/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/unit_tests/client/test_langchain.py b/tests/unit_tests/client/test_langchain.py
new file mode 100644
index 0000000000..efa0497179
--- /dev/null
+++ b/tests/unit_tests/client/test_langchain.py
@@ -0,0 +1,234 @@
+"""Test the LangChain+ client."""
+import uuid
+from datetime import datetime
+from io import BytesIO
+from typing import Any, Dict, List, Optional, Union
+from unittest import mock
+
+import pytest
+
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.tracers.langchain import LangChainTracerV2
+from langchain.callbacks.tracers.schemas import TracerSessionV2
+from langchain.chains.base import Chain
+from langchain.client.langchain import (
+ LangChainPlusClient,
+ _get_link_stem,
+ _is_localhost,
+)
+from langchain.client.models import Dataset, Example
+
# Fixed timestamp and tenant ID so expected values are stable across tests.
_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
+
+
@pytest.mark.parametrize(
    "api_url, expected_url",
    [
        ("http://localhost:8000", "http://localhost"),
        ("http://www.example.com", "http://www.example.com"),
        (
            "https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
            "https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
        ),
        ("https://www.langchain.com/path/to/nowhere", "https://www.langchain.com"),
    ],
)
def test_link_split(api_url: str, expected_url: str) -> None:
    """Test the link splitting handles both localhost and deployed urls."""
    # Ports and paths are stripped; scheme and host are preserved.
    assert _get_link_stem(api_url) == expected_url
+
+
def test_is_localhost() -> None:
    """Hostnames resolving to loopback/any-address count as localhost."""
    assert _is_localhost("http://localhost:8000")
    assert _is_localhost("http://127.0.0.1:8000")
    assert _is_localhost("http://0.0.0.0:8000")
    assert not _is_localhost("http://example.com:8000")
+
+
def test_validate_api_key_if_hosted() -> None:
    """Hosted API URLs require an api_key; localhost does not."""

    def mock_get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
        # Avoid a real HTTP call during tenant-id resolution.
        return _TENANT_ID

    with mock.patch.object(
        LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
    ):
        with pytest.raises(ValueError, match="API key must be provided"):
            LangChainPlusClient(api_url="http://www.example.com")

    with mock.patch.object(
        LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
    ):
        client = LangChainPlusClient(api_url="http://localhost:8000")
        assert client.api_url == "http://localhost:8000"
        assert client.api_key is None
+
+
def test_headers() -> None:
    """_headers carries a Bearer token only when an api_key is set."""

    def mock_get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
        # Avoid a real HTTP call during tenant-id resolution.
        return _TENANT_ID

    with mock.patch.object(
        LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
    ):
        client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
        assert client._headers == {"authorization": "Bearer 123"}

    with mock.patch.object(
        LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
    ):
        client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
        assert client_no_key._headers == {}
+
+
@mock.patch("langchain.client.langchain.requests.post")
def test_upload_csv(mock_post: mock.Mock) -> None:
    """Uploading a CSV returns a Dataset built from the mocked response."""
    dataset_id = str(uuid.uuid4())
    examples = [
        Example(
            id=str(uuid.uuid4()),
            created_at=_CREATED_AT,
            inputs={"input": inp},
            outputs={"output": out},
            dataset_id=dataset_id,
        )
        for inp, out in (("1", "2"), ("3", "4"))
    ]
    mock_response = mock.Mock()
    mock_response.json.return_value = {
        "id": dataset_id,
        "name": "test.csv",
        "description": "Test dataset",
        "owner_id": "the owner",
        "created_at": _CREATED_AT,
        "examples": examples,
        "tenant_id": _TENANT_ID,
    }
    mock_post.return_value = mock_response

    client = LangChainPlusClient(
        api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
    )
    csv_file = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n"))

    dataset = client.upload_csv(
        csv_file, "Test dataset", input_keys=["input"], output_keys=["output"]
    )

    assert dataset.id == uuid.UUID(dataset_id)
    assert dataset.name == "test.csv"
    assert dataset.description == "Test dataset"
+
+
@pytest.mark.asyncio
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
    """arun_on_dataset maps every example id to its repeated results."""
    dataset = Dataset(
        id=uuid.uuid4(),
        name="test",
        description="Test dataset",
        owner_id="owner",
        created_at=_CREATED_AT,
        tenant_id=_TENANT_ID,
    )
    uuids = [
        "0c193153-2309-4704-9a47-17aee4fb25c8",
        "0d11b5fd-8e66-4485-b696-4b55155c0c05",
        "90d696f0-f10d-4fd0-b88b-bfee6df08b84",
        "4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
        "7b5a524c-80fa-4960-888e-7d380f9a11ee",
    ]
    # inputs/outputs are the consecutive odd/even numbers 1..10 as strings.
    examples = [
        Example(
            id=example_id,
            created_at=_CREATED_AT,
            inputs={"input": str(2 * i + 1)},
            outputs={"output": str(2 * i + 2)},
            dataset_id=str(uuid.uuid4()),
        )
        for i, example_id in enumerate(uuids)
    ]

    def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
        return dataset

    def mock_list_examples(*args: Any, **kwargs: Any) -> List[Example]:
        return examples

    async def mock_arun_chain(
        example: Example,
        tracer: Any,
        llm_or_chain: Union[BaseLanguageModel, Chain],
        n_repetitions: int,
    ) -> List[Dict[str, Any]]:
        return [
            {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
        ]

    def mock_load_session(
        self: Any, name: str, *args: Any, **kwargs: Any
    ) -> TracerSessionV2:
        return TracerSessionV2(name=name, tenant_id=_TENANT_ID, id=uuid.uuid4())

    with mock.patch.object(
        LangChainPlusClient, "read_dataset", new=mock_read_dataset
    ), mock.patch.object(
        LangChainPlusClient, "list_examples", new=mock_list_examples
    ), mock.patch.object(
        LangChainPlusClient, "_arun_llm_or_chain", new=mock_arun_chain
    ), mock.patch.object(
        LangChainTracerV2, "load_session", new=mock_load_session
    ):
        monkeypatch.setenv("LANGCHAIN_TENANT_ID", _TENANT_ID)
        client = LangChainPlusClient(
            api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
        )
        chain = mock.MagicMock()
        num_repetitions = 3
        results = await client.arun_on_dataset(
            dataset_name="test",
            llm_or_chain=chain,
            concurrency_level=2,
            session_name="test_session",
            num_repetitions=num_repetitions,
        )

    expected = {
        example_id: [
            {"result": f"Result for example {uuid.UUID(example_id)}"}
            for _ in range(num_repetitions)
        ]
        for example_id in uuids
    }
    assert results == expected
diff --git a/tests/unit_tests/client/test_utils.py b/tests/unit_tests/client/test_utils.py
new file mode 100644
index 0000000000..7ee405c55f
--- /dev/null
+++ b/tests/unit_tests/client/test_utils.py
@@ -0,0 +1,70 @@
+"""Test LangChain+ Client Utils."""
+
+from typing import List
+
+from langchain.client.utils import parse_chat_messages
+from langchain.schema import (
+ AIMessage,
+ BaseMessage,
+ ChatMessage,
+ HumanMessage,
+ SystemMessage,
+)
+
+
def test_parse_chat_messages() -> None:
    """Test that chat messages are parsed correctly."""
    # One single-line message per default role prefix.
    input_text = (
        "Human: I am human roar\nAI: I am AI beep boop\nSystem: I am a system message"
    )
    expected = [
        HumanMessage(content="I am human roar"),
        AIMessage(content="I am AI beep boop"),
        SystemMessage(content="I am a system message"),
    ]
    assert parse_chat_messages(input_text) == expected
+
+
def test_parse_chat_messages_empty_input() -> None:
    """Test that an empty input string returns an empty list."""
    input_text = ""
    expected: List[BaseMessage] = []
    assert parse_chat_messages(input_text) == expected
+
+
def test_parse_chat_messages_multiline_messages() -> None:
    """Test that multiline messages are parsed correctly."""
    # Each message continues across newlines until the next role prefix.
    input_text = (
        "Human: I am a human\nand I roar\nAI: I am an AI\nand I"
        " beep boop\nSystem: I am a system\nand a message"
    )
    expected = [
        HumanMessage(content="I am a human\nand I roar"),
        AIMessage(content="I am an AI\nand I beep boop"),
        SystemMessage(content="I am a system\nand a message"),
    ]
    assert parse_chat_messages(input_text) == expected
+
+
def test_parse_chat_messages_custom_roles() -> None:
    """Test that custom roles are parsed correctly."""
    # Roles outside the default map become generic ChatMessage objects.
    input_text = "Client: I need help\nAgent: I'm here to help\nClient: Thank you"
    expected = [
        ChatMessage(role="Client", content="I need help"),
        ChatMessage(role="Agent", content="I'm here to help"),
        ChatMessage(role="Client", content="Thank you"),
    ]
    assert parse_chat_messages(input_text, roles=["Client", "Agent"]) == expected
+
+
def test_parse_chat_messages_embedded_roles() -> None:
    """Test that messages with embedded role references are parsed correctly."""
    # A "Role:" token mid-line must not start a new message; only prefixes
    # at a line start do.
    input_text = (
        "Human: Oh ai what if you said AI: foo bar?"
        "\nAI: Well, that would be interesting!"
    )
    expected = [
        HumanMessage(content="Oh ai what if you said AI: foo bar?"),
        AIMessage(content="Well, that would be interesting!"),
    ]
    assert parse_chat_messages(input_text) == expected