mirror of https://github.com/hwchase17/langchain
Add LCP Client (#4198)
Adding a client to fetch datasets, examples, and runs from an LCP instance and run chains and language models over them.
pull/4204/head
parent
a30f42da4e
commit
1017e5cee2
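For orientation before the diff: a minimal usage sketch of the new client, assuming a local LCP instance at http://localhost:8000, a placeholder tenant ID, and stand-in names ("examples.csv" with `question`/`answer` columns, and `my_chain`) that are not part of this commit.

# Minimal usage sketch (not part of the commit); names below are placeholders.
from langchain.client import LangChainPlusClient

client = LangChainPlusClient(
    api_url="http://localhost:8000",  # local instance, so no API key is required
    tenant_id="00000000-0000-0000-0000-000000000000",  # placeholder tenant ID
)

# Upload a CSV whose columns are split into example inputs and reference outputs.
dataset = client.upload_csv(
    "examples.csv",  # stand-in path; a (name, BytesIO) tuple also works
    description="Toy QA dataset",
    input_keys=["question"],
    output_keys=["answer"],
)

# Run a chain (or language model) over every example, tracing into a named session.
results = client.run_on_dataset(
    dataset_name=dataset.name,
    llm_or_chain=my_chain,  # any Chain or language model; not defined in this commit
    num_repetitions=1,
    session_name="my-first-eval",
    verbose=True,
)
# `results` maps each example ID to the model outputs for that example.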
@@ -0,0 +1,6 @@
"""LangChain+ Client."""

from langchain.client.langchain import LangChainPlusClient

__all__ = ["LangChainPlusClient"]

@@ -0,0 +1,544 @@
from __future__ import annotations

import asyncio
import functools
import logging
import socket
from datetime import datetime
from io import BytesIO
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Coroutine,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Union,
)
from urllib.parse import urlsplit
from uuid import UUID

import requests
from pydantic import BaseSettings, Field, root_validator
from requests import Response

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import tracing_v2_enabled
from langchain.callbacks.tracers.langchain import LangChainTracerV2
from langchain.chains.base import Chain
from langchain.chat_models.base import BaseChatModel
from langchain.client.models import Dataset, DatasetCreate, Example, ExampleCreate
from langchain.client.utils import parse_chat_messages
from langchain.llms.base import BaseLLM
from langchain.schema import ChatResult, LLMResult
from langchain.utils import raise_for_status_with_text, xor_args

if TYPE_CHECKING:
    import pandas as pd

logger = logging.getLogger(__name__)


def _get_link_stem(url: str) -> str:
    scheme = urlsplit(url).scheme
    netloc_prefix = urlsplit(url).netloc.split(":")[0]
    return f"{scheme}://{netloc_prefix}"


def _is_localhost(url: str) -> bool:
    """Check if the URL is localhost."""
    try:
        netloc = urlsplit(url).netloc.split(":")[0]
        ip = socket.gethostbyname(netloc)
        return ip == "127.0.0.1" or ip.startswith("0.0.0.0") or ip.startswith("::")
    except socket.gaierror:
        return False


class LangChainPlusClient(BaseSettings):
    """Client for interacting with the LangChain+ API."""

    api_key: Optional[str] = Field(default=None, env="LANGCHAIN_API_KEY")
    api_url: str = Field(..., env="LANGCHAIN_ENDPOINT")
    tenant_id: str = Field(..., env="LANGCHAIN_TENANT_ID")

    @root_validator(pre=True)
    def validate_api_key_if_hosted(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Verify an API key is provided if the URL is not localhost."""
        api_url: str = values.get("api_url", "http://localhost:8000")
        api_key: Optional[str] = values.get("api_key")
        if not _is_localhost(api_url):
            if not api_key:
                raise ValueError(
                    "API key must be provided when using hosted LangChain+ API"
                )
        else:
            tenant_id = values.get("tenant_id")
            if not tenant_id:
                values["tenant_id"] = LangChainPlusClient._get_seeded_tenant_id(
                    api_url, api_key
                )
        return values

    @staticmethod
    def _get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
        """Get the tenant ID from the seeded tenant."""
        url = f"{api_url}/tenants"
        headers = {"authorization": f"Bearer {api_key}"} if api_key else {}
        response = requests.get(url, headers=headers)
        try:
            raise_for_status_with_text(response)
        except Exception as e:
            raise ValueError(
                "Unable to get seeded tenant ID. Please provide a tenant ID manually."
            ) from e
        results: List[dict] = response.json()
        if len(results) == 0:
            raise ValueError("No seeded tenant found")
        return results[0]["id"]

    def _repr_html_(self) -> str:
        """Return an HTML representation of the instance with a link to the URL."""
        link = _get_link_stem(self.api_url)
        return f'<a href="{link}" target="_blank" rel="noopener">LangChain+ Client</a>'

    def __repr__(self) -> str:
        """Return a string representation of the instance with a link to the URL."""
        return f"LangChainPlusClient (API URL: {self.api_url})"

    @property
    def _headers(self) -> Dict[str, str]:
        """Get the headers for the API request."""
        headers = {}
        if self.api_key:
            headers["authorization"] = f"Bearer {self.api_key}"
        return headers

    @property
    def query_params(self) -> Dict[str, str]:
        """Get the query params for the API request."""
        return {"tenant_id": self.tenant_id}

    def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Response:
        """Make a GET request."""
        query_params = self.query_params
        if params:
            query_params.update(params)
        return requests.get(
            f"{self.api_url}{path}", headers=self._headers, params=query_params
        )

    def upload_dataframe(
        self,
        df: pd.DataFrame,
        name: str,
        description: str,
        input_keys: List[str],
        output_keys: List[str],
    ) -> Dataset:
        """Upload a dataframe as individual examples to the LangChain+ API."""
        dataset = self.create_dataset(dataset_name=name, description=description)
        for row in df.itertuples():
            inputs = {key: getattr(row, key) for key in input_keys}
            outputs = {key: getattr(row, key) for key in output_keys}
            self.create_example(inputs, outputs=outputs, dataset_id=dataset.id)
        return dataset

    def upload_csv(
        self,
        csv_file: Union[str, Tuple[str, BytesIO]],
        description: str,
        input_keys: List[str],
        output_keys: List[str],
    ) -> Dataset:
        """Upload a CSV file to the LangChain+ API."""
        files = {"file": csv_file}
        data = {
            "input_keys": ",".join(input_keys),
            "output_keys": ",".join(output_keys),
            "description": description,
            "tenant_id": self.tenant_id,
        }
        response = requests.post(
            self.api_url + "/datasets/upload",
            headers=self._headers,
            data=data,
            files=files,
        )
        raise_for_status_with_text(response)
        result = response.json()
        # TODO: Make this more robust server-side
        if "detail" in result and "already exists" in result["detail"]:
            file_name = csv_file if isinstance(csv_file, str) else csv_file[0]
            file_name = file_name.split("/")[-1]
            raise ValueError(f"Dataset {file_name} already exists")
        return Dataset(**result)

    def create_dataset(self, dataset_name: str, description: str) -> Dataset:
        """Create a dataset in the LangChain+ API."""
        dataset = DatasetCreate(
            tenant_id=self.tenant_id,
            name=dataset_name,
            description=description,
        )
        response = requests.post(
            self.api_url + "/datasets",
            headers=self._headers,
            data=dataset.json(),
        )
        raise_for_status_with_text(response)
        return Dataset(**response.json())

    @xor_args(("dataset_name", "dataset_id"))
    def read_dataset(
        self, *, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None
    ) -> Dataset:
        path = "/datasets"
        params: Dict[str, Any] = {"limit": 1, "tenant_id": self.tenant_id}
        if dataset_id is not None:
            path += f"/{dataset_id}"
        elif dataset_name is not None:
            params["name"] = dataset_name
        else:
            raise ValueError("Must provide dataset_name or dataset_id")
        response = self._get(
            path,
            params=params,
        )
        raise_for_status_with_text(response)
        result = response.json()
        if isinstance(result, list):
            if len(result) == 0:
                raise ValueError(f"Dataset {dataset_name} not found")
            return Dataset(**result[0])
        return Dataset(**result)

    def list_datasets(self, limit: int = 100) -> Iterable[Dataset]:
        """List the datasets on the LangChain+ API."""
        response = self._get("/datasets", params={"limit": limit})
        raise_for_status_with_text(response)
        return [Dataset(**dataset) for dataset in response.json()]

    @xor_args(("dataset_id", "dataset_name"))
    def delete_dataset(
        self, *, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None
    ) -> Dataset:
        """Delete a dataset by ID or name."""
        if dataset_name is not None:
            dataset_id = self.read_dataset(dataset_name=dataset_name).id
        if dataset_id is None:
            raise ValueError("Must provide either dataset name or ID")
        response = requests.delete(
            f"{self.api_url}/datasets/{dataset_id}",
            headers=self._headers,
        )
        raise_for_status_with_text(response)
        return response.json()

    @xor_args(("dataset_id", "dataset_name"))
    def create_example(
        self,
        inputs: Dict[str, Any],
        dataset_id: Optional[UUID] = None,
        dataset_name: Optional[str] = None,
        created_at: Optional[datetime] = None,
        outputs: Dict[str, Any] | None = None,
    ) -> Example:
        """Create a dataset example in the LangChain+ API."""
        if dataset_id is None:
            dataset_id = self.read_dataset(dataset_name=dataset_name).id

        data = {
            "inputs": inputs,
            "outputs": outputs,
            "dataset_id": dataset_id,
        }
        if created_at:
            data["created_at"] = created_at.isoformat()
        example = ExampleCreate(**data)
        response = requests.post(
            f"{self.api_url}/examples", headers=self._headers, data=example.json()
        )
        raise_for_status_with_text(response)
        result = response.json()
        return Example(**result)

    def read_example(self, example_id: str) -> Example:
        """Read an example from the LangChain+ API."""
        response = self._get(f"/examples/{example_id}")
        raise_for_status_with_text(response)
        return Example(**response.json())

    def list_examples(
        self, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None
    ) -> Iterable[Example]:
        """List the examples on the LangChain+ API."""
        params = {}
        if dataset_id is not None:
            params["dataset"] = dataset_id
        elif dataset_name is not None:
            dataset_id = self.read_dataset(dataset_name=dataset_name).id
            params["dataset"] = dataset_id
        response = self._get("/examples", params=params)
        raise_for_status_with_text(response)
        return [Example(**example) for example in response.json()]

    @staticmethod
    async def _arun_llm(
        llm: BaseLanguageModel,
        inputs: Dict[str, Any],
        langchain_tracer: LangChainTracerV2,
    ) -> Union[LLMResult, ChatResult]:
        if isinstance(llm, BaseLLM):
            llm_prompts: List[str] = inputs["prompts"]
            llm_output = await llm.agenerate(llm_prompts, callbacks=[langchain_tracer])
        elif isinstance(llm, BaseChatModel):
            chat_prompts: List[str] = inputs["prompts"]
            messages = [
                parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
            ]
            llm_output = await llm.agenerate(messages, callbacks=[langchain_tracer])
        else:
            raise ValueError(f"Unsupported LLM type {type(llm)}")
        return llm_output

    @staticmethod
    async def _arun_llm_or_chain(
        example: Example,
        langchain_tracer: LangChainTracerV2,
        llm_or_chain: Union[Chain, BaseLanguageModel],
        n_repetitions: int,
    ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
        """Run the chain asynchronously."""
        previous_example_id = langchain_tracer.example_id
        langchain_tracer.example_id = example.id
        outputs = []
        for _ in range(n_repetitions):
            try:
                if isinstance(llm_or_chain, BaseLanguageModel):
                    output: Any = await LangChainPlusClient._arun_llm(
                        llm_or_chain, example.inputs, langchain_tracer
                    )
                else:
                    output = await llm_or_chain.arun(
                        example.inputs, callbacks=[langchain_tracer]
                    )
                outputs.append(output)
            except Exception as e:
                logger.warning(f"Chain failed for example {example.id}. Error: {e}")
                outputs.append({"Error": str(e)})
            finally:
                langchain_tracer.example_id = previous_example_id
        return outputs

    @staticmethod
    async def _gather_with_concurrency(
        n: int,
        initializer: Callable[[], Coroutine[Any, Any, Tuple[LangChainTracerV2, Dict]]],
        *async_funcs: Callable[[LangChainTracerV2, Dict], Coroutine[Any, Any, Any]],
    ) -> List[Any]:
        """
        Run coroutines with a concurrency limit.

        Args:
            n: The maximum number of concurrent tasks.
            initializer: A coroutine that initializes shared resources for the tasks.
            async_funcs: The async_funcs to be run concurrently.

        Returns:
            A list of results from the coroutines.
        """
        semaphore = asyncio.Semaphore(n)
        tracer, job_state = await initializer()

        async def run_coroutine_with_semaphore(
            async_func: Callable[[LangChainTracerV2, Dict], Coroutine[Any, Any, Any]]
        ) -> Any:
            async with semaphore:
                return await async_func(tracer, job_state)

        return await asyncio.gather(
            *(run_coroutine_with_semaphore(function) for function in async_funcs)
        )

    async def _tracer_initializer(
        self, session_name: str
    ) -> Tuple[LangChainTracerV2, dict]:
        """
        Initialize a tracer to share across tasks.

        Args:
            session_name: The session name for the tracer.

        Returns:
            A LangChainTracerV2 instance with an active session.
        """
        job_state = {"num_processed": 0}
        with tracing_v2_enabled(session_name=session_name) as session:
            tracer = LangChainTracerV2()
            tracer.session = session
            return tracer, job_state

    async def arun_on_dataset(
        self,
        dataset_name: str,
        llm_or_chain: Union[Chain, BaseLanguageModel],
        concurrency_level: int = 5,
        num_repetitions: int = 1,
        session_name: Optional[str] = None,
        verbose: bool = False,
    ) -> Dict[str, Any]:
        """
        Run the chain on a dataset and store traces to the specified session name.

        Args:
            dataset_name: Name of the dataset to run the chain on.
            llm_or_chain: Chain or language model to run over the dataset.
            concurrency_level: The number of async tasks to run concurrently.
            num_repetitions: Number of times to run the model on each example.
                This is useful when testing success rates or generating confidence
                intervals.
            session_name: Name of the session to store the traces in.
                Defaults to {dataset_name}-{chain class name}-{datetime}.
            verbose: Whether to print progress.

        Returns:
            A dictionary mapping example ids to the model outputs.
        """
        if session_name is None:
            current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            session_name = (
                f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}"
            )
        dataset = self.read_dataset(dataset_name=dataset_name)
        examples = self.list_examples(dataset_id=str(dataset.id))
        results: Dict[str, List[Any]] = {}

        async def process_example(
            example: Example, tracer: LangChainTracerV2, job_state: dict
        ) -> None:
            """Process a single example."""
            result = await LangChainPlusClient._arun_llm_or_chain(
                example,
                tracer,
                llm_or_chain,
                num_repetitions,
            )
            results[str(example.id)] = result
            job_state["num_processed"] += 1
            if verbose:
                print(
                    f"Processed examples: {job_state['num_processed']}",
                    end="\r",
                    flush=True,
                )

        await self._gather_with_concurrency(
            concurrency_level,
            functools.partial(self._tracer_initializer, session_name),
            *(functools.partial(process_example, e) for e in examples),
        )
        return results

    @staticmethod
    def run_llm(
        llm: BaseLanguageModel,
        inputs: Dict[str, Any],
        langchain_tracer: LangChainTracerV2,
    ) -> Union[LLMResult, ChatResult]:
        """Run the language model on the example."""
        if isinstance(llm, BaseLLM):
            llm_prompts: List[str] = inputs["prompts"]
            llm_output = llm.generate(llm_prompts, callbacks=[langchain_tracer])
        elif isinstance(llm, BaseChatModel):
            chat_prompts: List[str] = inputs["prompts"]
            messages = [
                parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
            ]
            llm_output = llm.generate(messages, callbacks=[langchain_tracer])
        else:
            raise ValueError(f"Unsupported LLM type {type(llm)}")
        return llm_output

    @staticmethod
    def run_llm_or_chain(
        example: Example,
        langchain_tracer: LangChainTracerV2,
        llm_or_chain: Union[Chain, BaseLanguageModel],
        n_repetitions: int,
    ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
        """Run the chain synchronously."""
        previous_example_id = langchain_tracer.example_id
        langchain_tracer.example_id = example.id
        outputs = []
        for _ in range(n_repetitions):
            try:
                if isinstance(llm_or_chain, BaseLanguageModel):
                    output: Any = LangChainPlusClient.run_llm(
                        llm_or_chain, example.inputs, langchain_tracer
                    )
                else:
                    output = llm_or_chain.run(
                        example.inputs, callbacks=[langchain_tracer]
                    )
                outputs.append(output)
            except Exception as e:
                logger.warning(f"Chain failed for example {example.id}. Error: {e}")
                outputs.append({"Error": str(e)})
            finally:
                langchain_tracer.example_id = previous_example_id
        return outputs

    def run_on_dataset(
        self,
        dataset_name: str,
        llm_or_chain: Union[Chain, BaseLanguageModel],
        num_repetitions: int = 1,
        session_name: Optional[str] = None,
        verbose: bool = False,
    ) -> Dict[str, Any]:
        """Run the chain on a dataset and store traces to the specified session name.

        Args:
            dataset_name: Name of the dataset to run the chain on.
            llm_or_chain: Chain or language model to run over the dataset.
            num_repetitions: Number of times to run the model on each example.
                This is useful when testing success rates or generating confidence
                intervals.
            session_name: Name of the session to store the traces in.
                Defaults to {dataset_name}-{chain class name}-{datetime}.
            verbose: Whether to print progress.

        Returns:
            A dictionary mapping example ids to the model outputs.
        """
        if session_name is None:
            current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            session_name = (
                f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}"
            )
        dataset = self.read_dataset(dataset_name=dataset_name)
        examples = self.list_examples(dataset_id=str(dataset.id))
        results: Dict[str, Any] = {}
        with tracing_v2_enabled(session_name=session_name) as session:
            tracer = LangChainTracerV2()
            tracer.session = session

            for i, example in enumerate(examples):
                result = self.run_llm_or_chain(
                    example,
                    tracer,
                    llm_or_chain,
                    num_repetitions,
                )
                if verbose:
                    print(f"{i+1} processed", flush=True, end="\r")
                results[str(example.id)] = result
        return results
@@ -0,0 +1,54 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from uuid import UUID

from pydantic import BaseModel, Field

from langchain.callbacks.tracers.schemas import Run


class ExampleBase(BaseModel):
    """Example base model."""

    dataset_id: UUID
    inputs: Dict[str, Any]
    outputs: Optional[Dict[str, Any]] = Field(default=None)


class ExampleCreate(ExampleBase):
    """Example create model."""

    id: Optional[UUID]
    created_at: datetime = Field(default_factory=datetime.utcnow)


class Example(ExampleBase):
    """Example model."""

    id: UUID
    created_at: datetime
    modified_at: Optional[datetime] = Field(default=None)
    runs: List[Run] = Field(default_factory=list)


class DatasetBase(BaseModel):
    """Dataset base model."""

    tenant_id: UUID
    name: str
    description: str


class DatasetCreate(DatasetBase):
    """Dataset create model."""

    id: Optional[UUID]
    created_at: datetime = Field(default_factory=datetime.utcnow)


class Dataset(DatasetBase):
    """Dataset ORM model."""

    id: UUID
    created_at: datetime
    modified_at: Optional[datetime] = Field(default=None)
@@ -0,0 +1,42 @@
"""Client Utils."""
import re
from typing import Dict, List, Optional, Sequence, Type, Union

from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)

_DEFAULT_MESSAGES_T = Union[Type[HumanMessage], Type[SystemMessage], Type[AIMessage]]
_RESOLUTION_MAP: Dict[str, _DEFAULT_MESSAGES_T] = {
    "Human": HumanMessage,
    "AI": AIMessage,
    "System": SystemMessage,
}


def parse_chat_messages(
    input_text: str, roles: Optional[Sequence[str]] = None
) -> List[BaseMessage]:
    """Parse chat messages from a string. This is not robust."""
    roles = roles or ["Human", "AI", "System"]
    roles_pattern = "|".join(roles)
    pattern = (
        rf"(?P<entity>{roles_pattern}): (?P<message>"
        rf"(?:.*\n?)*?)(?=(?:{roles_pattern}): |\Z)"
    )
    matches = re.finditer(pattern, input_text, re.MULTILINE)

    results: List[BaseMessage] = []
    for match in matches:
        entity = match.group("entity")
        message = match.group("message").rstrip("\n")
        if entity in _RESOLUTION_MAP:
            results.append(_RESOLUTION_MAP[entity](content=message))
        else:
            results.append(ChatMessage(role=entity, content=message))

    return results
@@ -0,0 +1,234 @@
"""Test the LangChain+ client."""
import uuid
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, List, Optional, Union
from unittest import mock

import pytest

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.tracers.langchain import LangChainTracerV2
from langchain.callbacks.tracers.schemas import TracerSessionV2
from langchain.chains.base import Chain
from langchain.client.langchain import (
    LangChainPlusClient,
    _get_link_stem,
    _is_localhost,
)
from langchain.client.models import Dataset, Example

_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"


@pytest.mark.parametrize(
    "api_url, expected_url",
    [
        ("http://localhost:8000", "http://localhost"),
        ("http://www.example.com", "http://www.example.com"),
        (
            "https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
            "https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
        ),
        ("https://www.langchain.com/path/to/nowhere", "https://www.langchain.com"),
    ],
)
def test_link_split(api_url: str, expected_url: str) -> None:
    """Test the link splitting handles both localhost and deployed urls."""
    assert _get_link_stem(api_url) == expected_url


def test_is_localhost() -> None:
    assert _is_localhost("http://localhost:8000")
    assert _is_localhost("http://127.0.0.1:8000")
    assert _is_localhost("http://0.0.0.0:8000")
    assert not _is_localhost("http://example.com:8000")


def test_validate_api_key_if_hosted() -> None:
    def mock_get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
        return _TENANT_ID

    with mock.patch.object(
        LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
    ):
        with pytest.raises(ValueError, match="API key must be provided"):
            LangChainPlusClient(api_url="http://www.example.com")

    with mock.patch.object(
        LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
    ):
        client = LangChainPlusClient(api_url="http://localhost:8000")
        assert client.api_url == "http://localhost:8000"
        assert client.api_key is None


def test_headers() -> None:
    def mock_get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
        return _TENANT_ID

    with mock.patch.object(
        LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
    ):
        client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
        assert client._headers == {"authorization": "Bearer 123"}

    with mock.patch.object(
        LangChainPlusClient, "_get_seeded_tenant_id", new=mock_get_seeded_tenant_id
    ):
        client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
        assert client_no_key._headers == {}


@mock.patch("langchain.client.langchain.requests.post")
def test_upload_csv(mock_post: mock.Mock) -> None:
    mock_response = mock.Mock()
    dataset_id = str(uuid.uuid4())
    example_1 = Example(
        id=str(uuid.uuid4()),
        created_at=_CREATED_AT,
        inputs={"input": "1"},
        outputs={"output": "2"},
        dataset_id=dataset_id,
    )
    example_2 = Example(
        id=str(uuid.uuid4()),
        created_at=_CREATED_AT,
        inputs={"input": "3"},
        outputs={"output": "4"},
        dataset_id=dataset_id,
    )

    mock_response.json.return_value = {
        "id": dataset_id,
        "name": "test.csv",
        "description": "Test dataset",
        "owner_id": "the owner",
        "created_at": _CREATED_AT,
        "examples": [example_1, example_2],
        "tenant_id": _TENANT_ID,
    }
    mock_post.return_value = mock_response

    client = LangChainPlusClient(
        api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
    )
    csv_file = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n"))

    dataset = client.upload_csv(
        csv_file, "Test dataset", input_keys=["input"], output_keys=["output"]
    )

    assert dataset.id == uuid.UUID(dataset_id)
    assert dataset.name == "test.csv"
    assert dataset.description == "Test dataset"


@pytest.mark.asyncio
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
    dataset = Dataset(
        id=uuid.uuid4(),
        name="test",
        description="Test dataset",
        owner_id="owner",
        created_at=_CREATED_AT,
        tenant_id=_TENANT_ID,
    )
    uuids = [
        "0c193153-2309-4704-9a47-17aee4fb25c8",
        "0d11b5fd-8e66-4485-b696-4b55155c0c05",
        "90d696f0-f10d-4fd0-b88b-bfee6df08b84",
        "4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
        "7b5a524c-80fa-4960-888e-7d380f9a11ee",
    ]
    examples = [
        Example(
            id=uuids[0],
            created_at=_CREATED_AT,
            inputs={"input": "1"},
            outputs={"output": "2"},
            dataset_id=str(uuid.uuid4()),
        ),
        Example(
            id=uuids[1],
            created_at=_CREATED_AT,
            inputs={"input": "3"},
            outputs={"output": "4"},
            dataset_id=str(uuid.uuid4()),
        ),
        Example(
            id=uuids[2],
            created_at=_CREATED_AT,
            inputs={"input": "5"},
            outputs={"output": "6"},
            dataset_id=str(uuid.uuid4()),
        ),
        Example(
            id=uuids[3],
            created_at=_CREATED_AT,
            inputs={"input": "7"},
            outputs={"output": "8"},
            dataset_id=str(uuid.uuid4()),
        ),
        Example(
            id=uuids[4],
            created_at=_CREATED_AT,
            inputs={"input": "9"},
            outputs={"output": "10"},
            dataset_id=str(uuid.uuid4()),
        ),
    ]

    def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
        return dataset

    def mock_list_examples(*args: Any, **kwargs: Any) -> List[Example]:
        return examples

    async def mock_arun_chain(
        example: Example,
        tracer: Any,
        llm_or_chain: Union[BaseLanguageModel, Chain],
        n_repetitions: int,
    ) -> List[Dict[str, Any]]:
        return [
            {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
        ]

    def mock_load_session(
        self: Any, name: str, *args: Any, **kwargs: Any
    ) -> TracerSessionV2:
        return TracerSessionV2(name=name, tenant_id=_TENANT_ID, id=uuid.uuid4())

    with mock.patch.object(
        LangChainPlusClient, "read_dataset", new=mock_read_dataset
    ), mock.patch.object(
        LangChainPlusClient, "list_examples", new=mock_list_examples
    ), mock.patch.object(
        LangChainPlusClient, "_arun_llm_or_chain", new=mock_arun_chain
    ), mock.patch.object(
        LangChainTracerV2, "load_session", new=mock_load_session
    ):
        monkeypatch.setenv("LANGCHAIN_TENANT_ID", _TENANT_ID)
        client = LangChainPlusClient(
            api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
        )
        chain = mock.MagicMock()
        num_repetitions = 3
        results = await client.arun_on_dataset(
            dataset_name="test",
            llm_or_chain=chain,
            concurrency_level=2,
            session_name="test_session",
            num_repetitions=num_repetitions,
        )

        expected = {
            uuid_: [
                {"result": f"Result for example {uuid.UUID(uuid_)}"}
                for _ in range(num_repetitions)
            ]
            for uuid_ in uuids
        }
        assert results == expected
@@ -0,0 +1,70 @@
"""Test LangChain+ Client Utils."""

from typing import List

from langchain.client.utils import parse_chat_messages
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)


def test_parse_chat_messages() -> None:
    """Test that chat messages are parsed correctly."""
    input_text = (
        "Human: I am human roar\nAI: I am AI beep boop\nSystem: I am a system message"
    )
    expected = [
        HumanMessage(content="I am human roar"),
        AIMessage(content="I am AI beep boop"),
        SystemMessage(content="I am a system message"),
    ]
    assert parse_chat_messages(input_text) == expected


def test_parse_chat_messages_empty_input() -> None:
    """Test that an empty input string returns an empty list."""
    input_text = ""
    expected: List[BaseMessage] = []
    assert parse_chat_messages(input_text) == expected


def test_parse_chat_messages_multiline_messages() -> None:
    """Test that multiline messages are parsed correctly."""
    input_text = (
        "Human: I am a human\nand I roar\nAI: I am an AI\nand I"
        " beep boop\nSystem: I am a system\nand a message"
    )
    expected = [
        HumanMessage(content="I am a human\nand I roar"),
        AIMessage(content="I am an AI\nand I beep boop"),
        SystemMessage(content="I am a system\nand a message"),
    ]
    assert parse_chat_messages(input_text) == expected


def test_parse_chat_messages_custom_roles() -> None:
    """Test that custom roles are parsed correctly."""
    input_text = "Client: I need help\nAgent: I'm here to help\nClient: Thank you"
    expected = [
        ChatMessage(role="Client", content="I need help"),
        ChatMessage(role="Agent", content="I'm here to help"),
        ChatMessage(role="Client", content="Thank you"),
    ]
    assert parse_chat_messages(input_text, roles=["Client", "Agent"]) == expected


def test_parse_chat_messages_embedded_roles() -> None:
    """Test that messages with embedded role references are parsed correctly."""
    input_text = (
        "Human: Oh ai what if you said AI: foo bar?"
        "\nAI: Well, that would be interesting!"
    )
    expected = [
        HumanMessage(content="Oh ai what if you said AI: foo bar?"),
        AIMessage(content="Well, that would be interesting!"),
    ]
    assert parse_chat_messages(input_text) == expected