diff --git a/langchain/client/__init__.py b/langchain/client/__init__.py
new file mode 100644
index 00000000..89a34615
--- /dev/null
+++ b/langchain/client/__init__.py
@@ -0,0 +1,6 @@
+"""LangChain+ Client."""
+
+
+from langchain.client.langchain import LangChainPlusClient
+
+__all__ = ["LangChainPlusClient"]
diff --git a/langchain/client/langchain.py b/langchain/client/langchain.py
new file mode 100644
index 00000000..5e01e190
--- /dev/null
+++ b/langchain/client/langchain.py
@@ -0,0 +1,544 @@
+from __future__ import annotations
+
+import asyncio
+import functools
+import logging
+import socket
+from datetime import datetime
+from io import BytesIO
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Coroutine,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
+from urllib.parse import urlsplit
+from uuid import UUID
+
+import requests
+from pydantic import BaseSettings, Field, root_validator
+from requests import Response
+
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import tracing_v2_enabled
+from langchain.callbacks.tracers.langchain import LangChainTracerV2
+from langchain.chains.base import Chain
+from langchain.chat_models.base import BaseChatModel
+from langchain.client.models import Dataset, DatasetCreate, Example, ExampleCreate
+from langchain.client.utils import parse_chat_messages
+from langchain.llms.base import BaseLLM
+from langchain.schema import ChatResult, LLMResult
+from langchain.utils import raise_for_status_with_text, xor_args
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+logger = logging.getLogger(__name__)
+
+
+def _get_link_stem(url: str) -> str:
+    scheme = urlsplit(url).scheme
+    netloc_prefix = urlsplit(url).netloc.split(":")[0]
+    return f"{scheme}://{netloc_prefix}"
+
+
+def _is_localhost(url: str) -> bool:
+    """Check if the URL is localhost."""
+    try:
+        netloc = urlsplit(url).netloc.split(":")[0]
+        ip = socket.gethostbyname(netloc)
+        return ip == "127.0.0.1" or ip.startswith("0.0.0.0") or ip.startswith("::")
+    except socket.gaierror:
+        return False
+
+
+class LangChainPlusClient(BaseSettings):
+    """Client for interacting with the LangChain+ API."""
+
+    api_key: Optional[str] = Field(default=None, env="LANGCHAIN_API_KEY")
+    api_url: str = Field(..., env="LANGCHAIN_ENDPOINT")
+    tenant_id: str = Field(..., env="LANGCHAIN_TENANT_ID")
+
+    @root_validator(pre=True)
+    def validate_api_key_if_hosted(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Verify an API key is provided if the URL is not localhost."""
+        api_url: str = values.get("api_url", "http://localhost:8000")
+        api_key: Optional[str] = values.get("api_key")
+        if not _is_localhost(api_url):
+            if not api_key:
+                raise ValueError(
+                    "API key must be provided when using hosted LangChain+ API"
+                )
+        else:
+            tenant_id = values.get("tenant_id")
+            if not tenant_id:
+                values["tenant_id"] = LangChainPlusClient._get_seeded_tenant_id(
+                    api_url, api_key
+                )
+        return values
+
+    @staticmethod
+    def _get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
+        """Get the tenant ID from the seeded tenant."""
+        url = f"{api_url}/tenants"
+        headers = {"authorization": f"Bearer {api_key}"} if api_key else {}
+        response = requests.get(url, headers=headers)
+        try:
+            raise_for_status_with_text(response)
+        except Exception as e:
+            raise ValueError(
+                "Unable to get seeded tenant ID. Please provide it manually."
+            ) from e
+        results: List[dict] = response.json()
+        if len(results) == 0:
+            raise ValueError("No seeded tenant found")
+        return results[0]["id"]
+
+    def _repr_html_(self) -> str:
+        """Return an HTML representation of the instance with a link to the URL."""
+        link = _get_link_stem(self.api_url)
+        return f'<a href="{link}" target="_blank" rel="noopener">LangChain+ Client</a>'
+
+    def __repr__(self) -> str:
+        """Return a string representation of the instance with the API URL."""
+        return f"LangChainPlusClient (API URL: {self.api_url})"
+
+    @property
+    def _headers(self) -> Dict[str, str]:
+        """Get the headers for the API request."""
+        headers = {}
+        if self.api_key:
+            headers["authorization"] = f"Bearer {self.api_key}"
+        return headers
+
+    @property
+    def query_params(self) -> Dict[str, str]:
+        """Get the query parameters for the API request."""
+        return {"tenant_id": self.tenant_id}
+
+    def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Response:
+        """Make a GET request."""
+        query_params = self.query_params
+        if params:
+            query_params.update(params)
+        return requests.get(
+            f"{self.api_url}{path}", headers=self._headers, params=query_params
+        )
+
+    def upload_dataframe(
+        self,
+        df: pd.DataFrame,
+        name: str,
+        description: str,
+        input_keys: List[str],
+        output_keys: List[str],
+    ) -> Dataset:
+        """Upload a dataframe as individual examples to the LangChain+ API."""
+        dataset = self.create_dataset(dataset_name=name, description=description)
+        for row in df.itertuples():
+            inputs = {key: getattr(row, key) for key in input_keys}
+            outputs = {key: getattr(row, key) for key in output_keys}
+            self.create_example(inputs, outputs=outputs, dataset_id=dataset.id)
+        return dataset
+
+    def upload_csv(
+        self,
+        csv_file: Union[str, Tuple[str, BytesIO]],
+        description: str,
+        input_keys: List[str],
+        output_keys: List[str],
+    ) -> Dataset:
+        """Upload a CSV file to the LangChain+ API."""
+        files = {"file": csv_file}
+        data = {
+            "input_keys": ",".join(input_keys),
+            "output_keys": ",".join(output_keys),
+            "description": description,
+            "tenant_id": self.tenant_id,
+        }
+        response = requests.post(
+            self.api_url + "/datasets/upload",
+            headers=self._headers,
+            data=data,
+            files=files,
+        )
+        raise_for_status_with_text(response)
+        result = response.json()
+        # TODO: Make this more robust server-side
+        if "detail" in result and "already exists" in result["detail"]:
+            file_name = csv_file if isinstance(csv_file, str) else csv_file[0]
+            file_name = file_name.split("/")[-1]
+            raise ValueError(f"Dataset {file_name} already exists")
+        return Dataset(**result)
+
+    def create_dataset(self, dataset_name: str, description: str) -> Dataset:
+        """Create a dataset in the LangChain+ API."""
+        dataset = DatasetCreate(
+            tenant_id=self.tenant_id,
+            name=dataset_name,
+            description=description,
+        )
+        response = requests.post(
+            self.api_url + "/datasets",
+            headers=self._headers,
+            data=dataset.json(),
+        )
+        raise_for_status_with_text(response)
+        return Dataset(**response.json())
+
+    @xor_args(("dataset_name", "dataset_id"))
+    def read_dataset(
+        self, *, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None
+    ) -> Dataset:
+        path = "/datasets"
+        params: Dict[str, Any] = {"limit": 1, "tenant_id": self.tenant_id}
+        if dataset_id is not None:
+            path += f"/{dataset_id}"
+        elif dataset_name is not None:
+            params["name"] = dataset_name
+        else:
+            raise ValueError("Must provide dataset_name or dataset_id")
+        response = self._get(
+            path,
+            params=params,
+        )
+        raise_for_status_with_text(response)
+        result = response.json()
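+        # The name query above returns a list of matches, while the by-ID
+        # lookup returns a single object; handle both response shapes.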
+        if isinstance(result, list):
+            if len(result) == 0:
+                raise ValueError(f"Dataset {dataset_name} not found")
+            return Dataset(**result[0])
+        return Dataset(**result)
+
+    def list_datasets(self, limit: int = 100) -> Iterable[Dataset]:
+        """List the datasets on the LangChain+ API."""
+        response = self._get("/datasets", params={"limit": limit})
+        raise_for_status_with_text(response)
+        return [Dataset(**dataset) for dataset in response.json()]
+
+    @xor_args(("dataset_id", "dataset_name"))
+    def delete_dataset(
+        self, *, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None
+    ) -> Dataset:
+        """Delete a dataset by ID or name."""
+        if dataset_name is not None:
+            dataset_id = self.read_dataset(dataset_name=dataset_name).id
+        if dataset_id is None:
+            raise ValueError("Must provide either dataset name or ID")
+        response = requests.delete(
+            f"{self.api_url}/datasets/{dataset_id}",
+            headers=self._headers,
+        )
+        raise_for_status_with_text(response)
+        return Dataset(**response.json())
+
+    @xor_args(("dataset_id", "dataset_name"))
+    def create_example(
+        self,
+        inputs: Dict[str, Any],
+        dataset_id: Optional[UUID] = None,
+        dataset_name: Optional[str] = None,
+        created_at: Optional[datetime] = None,
+        outputs: Optional[Dict[str, Any]] = None,
+    ) -> Example:
+        """Create a dataset example in the LangChain+ API."""
+        if dataset_id is None:
+            dataset_id = self.read_dataset(dataset_name=dataset_name).id
+
+        data = {
+            "inputs": inputs,
+            "outputs": outputs,
+            "dataset_id": dataset_id,
+        }
+        if created_at:
+            data["created_at"] = created_at.isoformat()
+        example = ExampleCreate(**data)
+        response = requests.post(
+            f"{self.api_url}/examples", headers=self._headers, data=example.json()
+        )
+        raise_for_status_with_text(response)
+        result = response.json()
+        return Example(**result)
+
+    def read_example(self, example_id: str) -> Example:
+        """Read an example from the LangChain+ API."""
+        response = self._get(f"/examples/{example_id}")
+        raise_for_status_with_text(response)
+        return Example(**response.json())
+
+    def list_examples(
+        self, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None
+    ) -> Iterable[Example]:
+        """List the examples on the LangChain+ API."""
+        params = {}
+        if dataset_id is not None:
+            params["dataset"] = dataset_id
+        elif dataset_name is not None:
+            dataset_id = self.read_dataset(dataset_name=dataset_name).id
+            params["dataset"] = dataset_id
+        response = self._get("/examples", params=params)
+        raise_for_status_with_text(response)
+        return [Example(**example) for example in response.json()]
+
+    @staticmethod
+    async def _arun_llm(
+        llm: BaseLanguageModel,
+        inputs: Dict[str, Any],
+        langchain_tracer: LangChainTracerV2,
+    ) -> Union[LLMResult, ChatResult]:
+        if isinstance(llm, BaseLLM):
+            llm_prompts: List[str] = inputs["prompts"]
+            llm_output = await llm.agenerate(llm_prompts, callbacks=[langchain_tracer])
+        elif isinstance(llm, BaseChatModel):
+            chat_prompts: List[str] = inputs["prompts"]
+            messages = [
+                parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
+            ]
+            llm_output = await llm.agenerate(messages, callbacks=[langchain_tracer])
+        else:
+            raise ValueError(f"Unsupported LLM type {type(llm)}")
+        return llm_output
+
+    @staticmethod
+    async def _arun_llm_or_chain(
+        example: Example,
+        langchain_tracer: LangChainTracerV2,
+        llm_or_chain: Union[Chain, BaseLanguageModel],
+        n_repetitions: int,
+    ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
+        """Run the chain asynchronously."""
+        previous_example_id = langchain_tracer.example_id
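+        # Temporarily point the shared tracer at this example so the runs it
+        # logs are associated with it; restored after all repetitions finish.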
+        langchain_tracer.example_id = example.id
+        outputs = []
+        try:
+            for _ in range(n_repetitions):
+                try:
+                    if isinstance(llm_or_chain, BaseLanguageModel):
+                        output: Any = await LangChainPlusClient._arun_llm(
+                            llm_or_chain, example.inputs, langchain_tracer
+                        )
+                    else:
+                        output = await llm_or_chain.arun(
+                            example.inputs, callbacks=[langchain_tracer]
+                        )
+                    outputs.append(output)
+                except Exception as e:
+                    logger.warning(f"Chain failed for example {example.id}. Error: {e}")
+                    outputs.append({"Error": str(e)})
+        finally:
+            langchain_tracer.example_id = previous_example_id
+        return outputs
+
+    @staticmethod
+    async def _gather_with_concurrency(
+        n: int,
+        initializer: Callable[[], Coroutine[Any, Any, Tuple[LangChainTracerV2, Dict]]],
+        *async_funcs: Callable[[LangChainTracerV2, Dict], Coroutine[Any, Any, Any]],
+    ) -> List[Any]:
+        """
+        Run coroutines with a concurrency limit.
+
+        Args:
+            n: The maximum number of concurrent tasks.
+            initializer: A coroutine that initializes shared resources for the tasks.
+            async_funcs: The async functions to be run concurrently.
+
+        Returns:
+            A list of results from the coroutines.
+        """
+        semaphore = asyncio.Semaphore(n)
+        tracer, job_state = await initializer()
+
+        async def run_coroutine_with_semaphore(
+            async_func: Callable[[LangChainTracerV2, Dict], Coroutine[Any, Any, Any]]
+        ) -> Any:
+            async with semaphore:
+                return await async_func(tracer, job_state)
+
+        return await asyncio.gather(
+            *(run_coroutine_with_semaphore(function) for function in async_funcs)
+        )
+
+    async def _tracer_initializer(
+        self, session_name: str
+    ) -> Tuple[LangChainTracerV2, dict]:
+        """
+        Initialize a tracer to share across tasks.
+
+        Args:
+            session_name: The session name for the tracer.
+
+        Returns:
+            A LangChainTracerV2 instance with an active session, along with a
+            shared job-state dict.
+        """
+        job_state = {"num_processed": 0}
+        with tracing_v2_enabled(session_name=session_name) as session:
+            tracer = LangChainTracerV2()
+            tracer.session = session
+            return tracer, job_state
+
+    async def arun_on_dataset(
+        self,
+        dataset_name: str,
+        llm_or_chain: Union[Chain, BaseLanguageModel],
+        concurrency_level: int = 5,
+        num_repetitions: int = 1,
+        session_name: Optional[str] = None,
+        verbose: bool = False,
+    ) -> Dict[str, Any]:
+        """
+        Run the chain on a dataset and store traces to the specified session name.
+
+        Args:
+            dataset_name: Name of the dataset to run the chain on.
+            llm_or_chain: Chain or language model to run over the dataset.
+            concurrency_level: The number of async tasks to run concurrently.
+            num_repetitions: Number of times to run the model on each example.
+                This is useful when testing success rates or generating confidence
+                intervals.
+            session_name: Name of the session to store the traces in.
+                Defaults to {dataset_name}-{chain class name}-{datetime}.
+            verbose: Whether to print progress.
+
+        Returns:
+            A dictionary mapping example ids to the model outputs.
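+
+        Example (illustrative sketch; assumes a dataset named "my-dataset" has
+        already been uploaded and `chain` is an existing Chain instance, run
+        from inside an async context):
+
+            client = LangChainPlusClient(api_url="http://localhost:8000")
+            results = await client.arun_on_dataset("my-dataset", chain)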
+ """ + if session_name is None: + current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + session_name = ( + f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}" + ) + dataset = self.read_dataset(dataset_name=dataset_name) + examples = self.list_examples(dataset_id=str(dataset.id)) + results: Dict[str, List[Any]] = {} + + async def process_example( + example: Example, tracer: LangChainTracerV2, job_state: dict + ) -> None: + """Process a single example.""" + result = await LangChainPlusClient._arun_llm_or_chain( + example, + tracer, + llm_or_chain, + num_repetitions, + ) + results[str(example.id)] = result + job_state["num_processed"] += 1 + if verbose: + print( + f"Processed examples: {job_state['num_processed']}", + end="\r", + flush=True, + ) + + await self._gather_with_concurrency( + concurrency_level, + functools.partial(self._tracer_initializer, session_name), + *(functools.partial(process_example, e) for e in examples), + ) + return results + + @staticmethod + def run_llm( + llm: BaseLanguageModel, + inputs: Dict[str, Any], + langchain_tracer: LangChainTracerV2, + ) -> Union[LLMResult, ChatResult]: + """Run the language model on the example.""" + if isinstance(llm, BaseLLM): + llm_prompts: List[str] = inputs["prompts"] + llm_output = llm.generate(llm_prompts, callbacks=[langchain_tracer]) + elif isinstance(llm, BaseChatModel): + chat_prompts: List[str] = inputs["prompts"] + messages = [ + parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts + ] + llm_output = llm.generate(messages, callbacks=[langchain_tracer]) + else: + raise ValueError(f"Unsupported LLM type {type(llm)}") + return llm_output + + @staticmethod + def run_llm_or_chain( + example: Example, + langchain_tracer: LangChainTracerV2, + llm_or_chain: Union[Chain, BaseLanguageModel], + n_repetitions: int, + ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]: + """Run the chain synchronously.""" + previous_example_id = langchain_tracer.example_id + langchain_tracer.example_id = example.id + outputs = [] + for _ in range(n_repetitions): + try: + if isinstance(llm_or_chain, BaseLanguageModel): + output: Any = LangChainPlusClient.run_llm( + llm_or_chain, example.inputs, langchain_tracer + ) + else: + output = llm_or_chain.run( + example.inputs, callbacks=[langchain_tracer] + ) + outputs.append(output) + except Exception as e: + logger.warning(f"Chain failed for example {example.id}. Error: {e}") + outputs.append({"Error": str(e)}) + finally: + langchain_tracer.example_id = previous_example_id + return outputs + + def run_on_dataset( + self, + dataset_name: str, + llm_or_chain: Union[Chain, BaseLanguageModel], + num_repetitions: int = 1, + session_name: Optional[str] = None, + verbose: bool = False, + ) -> Dict[str, Any]: + """Run the chain on a dataset and store traces to the specified session name. + + Args: + dataset_name: Name of the dataset to run the chain on. + llm_or_chain: Chain or language model to run over the dataset. + concurrency_level: Number of async workers to run in parallel. + num_repetitions: Number of times to run the model on each example. + This is useful when testing success rates or generating confidence + intervals. + session_name: Name of the session to store the traces in. + Defaults to {dataset_name}-{chain class name}-{datetime}. + verbose: Whether to print progress. + + Returns: + A dictionary mapping example ids to the model outputs. 
+ """ + if session_name is None: + current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + session_name = ( + f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}" + ) + dataset = self.read_dataset(dataset_name=dataset_name) + examples = self.list_examples(dataset_id=str(dataset.id)) + results: Dict[str, Any] = {} + with tracing_v2_enabled(session_name=session_name) as session: + tracer = LangChainTracerV2() + tracer.session = session + + for i, example in enumerate(examples): + result = self.run_llm_or_chain( + example, + tracer, + llm_or_chain, + num_repetitions, + ) + if verbose: + print(f"{i+1} processed", flush=True, end="\r") + results[str(example.id)] = result + return results diff --git a/langchain/client/models.py b/langchain/client/models.py new file mode 100644 index 00000000..a7a19b10 --- /dev/null +++ b/langchain/client/models.py @@ -0,0 +1,54 @@ +from datetime import datetime +from typing import Any, Dict, List, Optional +from uuid import UUID + +from pydantic import BaseModel, Field + +from langchain.callbacks.tracers.schemas import Run + + +class ExampleBase(BaseModel): + """Example base model.""" + + dataset_id: UUID + inputs: Dict[str, Any] + outputs: Optional[Dict[str, Any]] = Field(default=None) + + +class ExampleCreate(ExampleBase): + """Example create model.""" + + id: Optional[UUID] + created_at: datetime = Field(default_factory=datetime.utcnow) + + +class Example(ExampleBase): + """Example model.""" + + id: UUID + created_at: datetime + modified_at: Optional[datetime] = Field(default=None) + runs: List[Run] = Field(default_factory=list) + + +class DatasetBase(BaseModel): + """Dataset base model.""" + + tenant_id: UUID + name: str + description: str + + +class DatasetCreate(DatasetBase): + """Dataset create model.""" + + id: Optional[UUID] + created_at: datetime = Field(default_factory=datetime.utcnow) + + +class Dataset(DatasetBase): + """Dataset ORM model.""" + + id: UUID + created_at: datetime + modified_at: Optional[datetime] = Field(default=None) diff --git a/langchain/client/utils.py b/langchain/client/utils.py new file mode 100644 index 00000000..f7ce264c --- /dev/null +++ b/langchain/client/utils.py @@ -0,0 +1,42 @@ +"""Client Utils.""" +import re +from typing import Dict, List, Optional, Sequence, Type, Union + +from langchain.schema import ( + AIMessage, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) + +_DEFAULT_MESSAGES_T = Union[Type[HumanMessage], Type[SystemMessage], Type[AIMessage]] +_RESOLUTION_MAP: Dict[str, _DEFAULT_MESSAGES_T] = { + "Human": HumanMessage, + "AI": AIMessage, + "System": SystemMessage, +} + + +def parse_chat_messages( + input_text: str, roles: Optional[Sequence[str]] = None +) -> List[BaseMessage]: + """Parse chat messages from a string. 
+
+    Note: this parsing is best-effort and not robust.
+    """
+    roles = roles or ["Human", "AI", "System"]
+    roles_pattern = "|".join(roles)
+    pattern = (
+        rf"(?P<entity>{roles_pattern}): (?P<message>"
+        rf"(?:.*\n?)*?)(?=(?:{roles_pattern}): |\Z)"
+    )
+    matches = re.finditer(pattern, input_text, re.MULTILINE)
+
+    results: List[BaseMessage] = []
+    for match in matches:
+        entity = match.group("entity")
+        message = match.group("message").rstrip("\n")
+        if entity in _RESOLUTION_MAP:
+            results.append(_RESOLUTION_MAP[entity](content=message))
+        else:
+            results.append(ChatMessage(role=entity, content=message))
+
+    return results
diff --git a/tests/unit_tests/client/__init__.py b/tests/unit_tests/client/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit_tests/client/test_langchain.py b/tests/unit_tests/client/test_langchain.py
new file mode 100644
index 00000000..efa04971
--- /dev/null
+++ b/tests/unit_tests/client/test_langchain.py
@@ -0,0 +1,234 @@
+"""Test the LangChain+ client."""
+import uuid
+from datetime import datetime
+from io import BytesIO
+from typing import Any, Dict, List, Optional, Union
+from unittest import mock
+
+import pytest
+
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.tracers.langchain import LangChainTracerV2
+from langchain.callbacks.tracers.schemas import TracerSessionV2
+from langchain.chains.base import Chain
+from langchain.client.langchain import (
+    LangChainPlusClient,
+    _get_link_stem,
+    _is_localhost,
+)
+from langchain.client.models import Dataset, Example
+
+_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
+_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
+
+
+@pytest.mark.parametrize(
+    "api_url, expected_url",
+    [
+        ("http://localhost:8000", "http://localhost"),
+        ("http://www.example.com", "http://www.example.com"),
+        (
+            "https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
+            "https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
+        ),
+        ("https://www.langchain.com/path/to/nowhere", "https://www.langchain.com"),
+    ],
+)
+def test_link_split(api_url: str, expected_url: str) -> None:
+    """Test the link splitting handles both localhost and deployed urls."""
+    assert _get_link_stem(api_url) == expected_url
+
+
+def test_is_localhost() -> None:
+    assert _is_localhost("http://localhost:8000")
+    assert _is_localhost("http://127.0.0.1:8000")
+    assert _is_localhost("http://0.0.0.0:8000")
+    assert not _is_localhost("http://example.com:8000")
+
+
+def test_validate_api_key_if_hosted() -> None:
+    def mock_get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
+        return _TENANT_ID
+
+    with mock.patch.object(
+        LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
+    ):
+        with pytest.raises(ValueError, match="API key must be provided"):
+            LangChainPlusClient(api_url="http://www.example.com")
+
+    with mock.patch.object(
+        LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
+    ):
+        client = LangChainPlusClient(api_url="http://localhost:8000")
+        assert client.api_url == "http://localhost:8000"
+        assert client.api_key is None
+
+
+def test_headers() -> None:
+    def mock_get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
+        return _TENANT_ID
+
+    with mock.patch.object(
+        LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
+    ):
+        client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
+        assert client._headers == {"authorization": "Bearer 123"}
+
+    with mock.patch.object(
+        LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
+    ):
+        client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
+        assert client_no_key._headers == {}
+
+
+@mock.patch("langchain.client.langchain.requests.post")
+def test_upload_csv(mock_post: mock.Mock) -> None:
+    mock_response = mock.Mock()
+    dataset_id = str(uuid.uuid4())
+    example_1 = Example(
+        id=str(uuid.uuid4()),
+        created_at=_CREATED_AT,
+        inputs={"input": "1"},
+        outputs={"output": "2"},
+        dataset_id=dataset_id,
+    )
+    example_2 = Example(
+        id=str(uuid.uuid4()),
+        created_at=_CREATED_AT,
+        inputs={"input": "3"},
+        outputs={"output": "4"},
+        dataset_id=dataset_id,
+    )
+
+    mock_response.json.return_value = {
+        "id": dataset_id,
+        "name": "test.csv",
+        "description": "Test dataset",
+        "owner_id": "the owner",
+        "created_at": _CREATED_AT,
+        "examples": [example_1, example_2],
+        "tenant_id": _TENANT_ID,
+    }
+    mock_post.return_value = mock_response
+
+    client = LangChainPlusClient(
+        api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
+    )
+    csv_file = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n"))
+
+    dataset = client.upload_csv(
+        csv_file, "Test dataset", input_keys=["input"], output_keys=["output"]
+    )
+
+    assert dataset.id == uuid.UUID(dataset_id)
+    assert dataset.name == "test.csv"
+    assert dataset.description == "Test dataset"
+
+
+@pytest.mark.asyncio
+async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
+    dataset = Dataset(
+        id=uuid.uuid4(),
+        name="test",
+        description="Test dataset",
+        owner_id="owner",
+        created_at=_CREATED_AT,
+        tenant_id=_TENANT_ID,
+    )
+    uuids = [
+        "0c193153-2309-4704-9a47-17aee4fb25c8",
+        "0d11b5fd-8e66-4485-b696-4b55155c0c05",
+        "90d696f0-f10d-4fd0-b88b-bfee6df08b84",
+        "4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
+        "7b5a524c-80fa-4960-888e-7d380f9a11ee",
+    ]
+    examples = [
+        Example(
+            id=uuids[0],
+            created_at=_CREATED_AT,
+            inputs={"input": "1"},
+            outputs={"output": "2"},
+            dataset_id=str(uuid.uuid4()),
+        ),
+        Example(
+            id=uuids[1],
+            created_at=_CREATED_AT,
+            inputs={"input": "3"},
+            outputs={"output": "4"},
+            dataset_id=str(uuid.uuid4()),
+        ),
+        Example(
+            id=uuids[2],
+            created_at=_CREATED_AT,
+            inputs={"input": "5"},
+            outputs={"output": "6"},
+            dataset_id=str(uuid.uuid4()),
+        ),
+        Example(
+            id=uuids[3],
+            created_at=_CREATED_AT,
+            inputs={"input": "7"},
+            outputs={"output": "8"},
+            dataset_id=str(uuid.uuid4()),
+        ),
+        Example(
+            id=uuids[4],
+            created_at=_CREATED_AT,
+            inputs={"input": "9"},
+            outputs={"output": "10"},
+            dataset_id=str(uuid.uuid4()),
+        ),
+    ]
+
+    def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
+        return dataset
+
+    def mock_list_examples(*args: Any, **kwargs: Any) -> List[Example]:
+        return examples
+
+    async def mock_arun_chain(
+        example: Example,
+        tracer: Any,
+        llm_or_chain: Union[BaseLanguageModel, Chain],
+        n_repetitions: int,
+    ) -> List[Dict[str, Any]]:
+        return [
+            {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
+        ]
+
+    def mock_load_session(
+        self: Any, name: str, *args: Any, **kwargs: Any
+    ) -> TracerSessionV2:
+        return TracerSessionV2(name=name, tenant_id=_TENANT_ID, id=uuid.uuid4())
+
+    with mock.patch.object(
+        LangChainPlusClient, "read_dataset", new=mock_read_dataset
+    ), mock.patch.object(
+        LangChainPlusClient, "list_examples", new=mock_list_examples
+    ), mock.patch.object(
+        LangChainPlusClient, "_arun_llm_or_chain", new=mock_arun_chain
+    ), mock.patch.object(
+        LangChainTracerV2, "load_session", new=mock_load_session
+    ):
+        monkeypatch.setenv("LANGCHAIN_TENANT_ID", _TENANT_ID)
+        client = LangChainPlusClient(
api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID + ) + chain = mock.MagicMock() + num_repetitions = 3 + results = await client.arun_on_dataset( + dataset_name="test", + llm_or_chain=chain, + concurrency_level=2, + session_name="test_session", + num_repetitions=num_repetitions, + ) + + expected = { + uuid_: [ + {"result": f"Result for example {uuid.UUID(uuid_)}"} + for _ in range(num_repetitions) + ] + for uuid_ in uuids + } + assert results == expected diff --git a/tests/unit_tests/client/test_utils.py b/tests/unit_tests/client/test_utils.py new file mode 100644 index 00000000..7ee405c5 --- /dev/null +++ b/tests/unit_tests/client/test_utils.py @@ -0,0 +1,70 @@ +"""Test LangChain+ Client Utils.""" + +from typing import List + +from langchain.client.utils import parse_chat_messages +from langchain.schema import ( + AIMessage, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) + + +def test_parse_chat_messages() -> None: + """Test that chat messages are parsed correctly.""" + input_text = ( + "Human: I am human roar\nAI: I am AI beep boop\nSystem: I am a system message" + ) + expected = [ + HumanMessage(content="I am human roar"), + AIMessage(content="I am AI beep boop"), + SystemMessage(content="I am a system message"), + ] + assert parse_chat_messages(input_text) == expected + + +def test_parse_chat_messages_empty_input() -> None: + """Test that an empty input string returns an empty list.""" + input_text = "" + expected: List[BaseMessage] = [] + assert parse_chat_messages(input_text) == expected + + +def test_parse_chat_messages_multiline_messages() -> None: + """Test that multiline messages are parsed correctly.""" + input_text = ( + "Human: I am a human\nand I roar\nAI: I am an AI\nand I" + " beep boop\nSystem: I am a system\nand a message" + ) + expected = [ + HumanMessage(content="I am a human\nand I roar"), + AIMessage(content="I am an AI\nand I beep boop"), + SystemMessage(content="I am a system\nand a message"), + ] + assert parse_chat_messages(input_text) == expected + + +def test_parse_chat_messages_custom_roles() -> None: + """Test that custom roles are parsed correctly.""" + input_text = "Client: I need help\nAgent: I'm here to help\nClient: Thank you" + expected = [ + ChatMessage(role="Client", content="I need help"), + ChatMessage(role="Agent", content="I'm here to help"), + ChatMessage(role="Client", content="Thank you"), + ] + assert parse_chat_messages(input_text, roles=["Client", "Agent"]) == expected + + +def test_parse_chat_messages_embedded_roles() -> None: + """Test that messages with embedded role references are parsed correctly.""" + input_text = ( + "Human: Oh ai what if you said AI: foo bar?" + "\nAI: Well, that would be interesting!" + ) + expected = [ + HumanMessage(content="Oh ai what if you said AI: foo bar?"), + AIMessage(content="Well, that would be interesting!"), + ] + assert parse_chat_messages(input_text) == expected