mirror of https://github.com/hwchase17/langchain
Use client from LCP-SDK (#5695)
- Remove the client implementation (this breaks backwards compatibility for existing testers; I could keep the stub in that file if we want, but not many people are using it yet)
- Add the SDK as a dependency
- Update the 'run_on_dataset' method to be a function that optionally accepts a client as an argument
- Remove the LangChain Plus server implementation (you get it for free with the SDK now)

We could make the SDK optional for now, but the plan is to use it within the tracer, so it would likely become a hard dependency at some point.
pull/5789/head^2
parent 08e2352f7b
commit 204a73c1d9
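For context, a minimal sketch of how the new function-style API might be called, assuming the SDK exposes the client as langchainplus_sdk.LangChainPlusClient and that run_on_dataset keeps the keyword arguments of the old method; the dataset name and model below are illustrative only:

from langchain.chat_models import ChatOpenAI
from langchain.client import run_on_dataset

# Assumption: the client now ships with the SDK package rather than with
# langchain.client.langchain, which this commit removes.
from langchainplus_sdk import LangChainPlusClient

client = LangChainPlusClient()  # reads LANGCHAIN_ENDPOINT / LANGCHAIN_API_KEY
results = run_on_dataset(
    dataset_name="my-eval-dataset",  # hypothetical dataset name
    llm_or_chain_factory=ChatOpenAI(temperature=0),
    num_repetitions=1,
    client=client,  # optional, per the description above
)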
@@ -1,16 +0,0 @@
server {
    listen 80;
    server_name localhost;
    error_log /var/log/nginx/error.log warn;

    location / {
        root /usr/share/nginx/html;
        index index.html index.htm;
        try_files $uri $uri/ /index.html;
    }

    error_page 500 502 503 504 /50x.html;
    location = /50x.html {
        root /usr/share/nginx/html;
    }
}
@@ -1,17 +0,0 @@
version: '3'
services:
  ngrok:
    image: ngrok/ngrok:latest
    restart: unless-stopped
    command:
      - "start"
      - "--all"
      - "--config"
      - "/etc/ngrok.yml"
    volumes:
      - ./ngrok_config.yaml:/etc/ngrok.yml
    ports:
      - 4040:4040
  langchain-backend:
    depends_on:
      - ngrok
@@ -1,49 +0,0 @@
version: '3'
services:
  langchain-frontend:
    image: langchain/${_LANGCHAINPLUS_IMAGE_PREFIX-}langchainplus-frontend:latest
    ports:
      - 80:80
    environment:
      - REACT_APP_BACKEND_URL=http://localhost:1984
    depends_on:
      - langchain-backend
    volumes:
      - ./conf/nginx.conf:/etc/nginx/default.conf:ro
    build:
      context: frontend-react/.
      dockerfile: Dockerfile
  langchain-backend:
    image: langchain/${_LANGCHAINPLUS_IMAGE_PREFIX-}langchainplus-backend:latest
    environment:
      - PORT=1984
      - LANGCHAIN_ENV=local_docker
      - LOG_LEVEL=warning
      - OPENAI_API_KEY=${OPENAI_API_KEY}
    ports:
      - 1984:1984
    depends_on:
      - langchain-db
    build:
      context: backend/.
      dockerfile: Dockerfile
  langchain-db:
    image: postgres:14.1
    command:
      [
        "postgres",
        "-c",
        "log_min_messages=WARNING",
        "-c",
        "client_min_messages=WARNING"
      ]
    environment:
      - POSTGRES_PASSWORD=postgres
      - POSTGRES_USER=postgres
      - POSTGRES_DB=postgres
    volumes:
      - langchain-db-data:/var/lib/postgresql/data
    ports:
      - 5433:5432
volumes:
  langchain-db-data:
@@ -1,377 +0,0 @@
import argparse
import json
import logging
import os
import shutil
import subprocess
from contextlib import contextmanager
from pathlib import Path
from subprocess import CalledProcessError
from typing import Dict, Generator, List, Mapping, Optional, Union, cast

import requests
import yaml

from langchain.env import get_runtime_environment

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

_DIR = Path(__file__).parent


def pprint_services(services_status: List[Mapping[str, Union[str, List[str]]]]) -> None:
    # Loop through and collect Service, Status, and Publishers["PublishedPorts"]
    # for each service
    services = []
    for service in services_status:
        service_status: Dict[str, str] = {
            "Service": str(service["Service"]),
            "Status": str(service["Status"]),
        }
        publishers = cast(List[Dict], service.get("Publishers", []))
        if publishers:
            service_status["PublishedPorts"] = ", ".join(
                [str(publisher["PublishedPort"]) for publisher in publishers]
            )
        services.append(service_status)

    max_service_len = max(len(service["Service"]) for service in services)
    max_state_len = max(len(service["Status"]) for service in services)
    service_message = [
        "\n"
        + "Service".ljust(max_service_len + 2)
        + "Status".ljust(max_state_len + 2)
        + "Published Ports"
    ]
    for service in services:
        service_str = service["Service"].ljust(max_service_len + 2)
        state_str = service["Status"].ljust(max_state_len + 2)
        ports_str = service.get("PublishedPorts", "")
        service_message.append(service_str + state_str + ports_str)

    langchain_endpoint: str = "http://localhost:1984"
    used_ngrok = any(["ngrok" in service["Service"] for service in services])
    if used_ngrok:
        langchain_endpoint = get_ngrok_url(auth_token=None)

    service_message.append(
        "\nTo connect, set the following environment variables"
        " in your LangChain application:"
        "\nLANGCHAIN_TRACING_V2=true"
        f"\nLANGCHAIN_ENDPOINT={langchain_endpoint}"
    )
    logger.info("\n".join(service_message))


def get_docker_compose_command() -> List[str]:
    """Get the correct docker compose command for this system."""
    try:
        subprocess.check_call(
            ["docker", "compose", "--version"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        return ["docker", "compose"]
    except (CalledProcessError, FileNotFoundError):
        try:
            subprocess.check_call(
                ["docker-compose", "--version"],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            return ["docker-compose"]
        except (CalledProcessError, FileNotFoundError):
            raise ValueError(
                "Neither 'docker compose' nor 'docker-compose'"
                " commands are available. Please install the Docker"
                " server following the instructions for your operating"
                " system at https://docs.docker.com/engine/install/"
            )


def get_ngrok_url(auth_token: Optional[str]) -> str:
    """Get the ngrok URL for the LangChainPlus server."""
    ngrok_url = "http://localhost:4040/api/tunnels"
    try:
        response = requests.get(ngrok_url)
        response.raise_for_status()
        exposed_url = response.json()["tunnels"][0]["public_url"]
    except requests.exceptions.HTTPError:
        raise ValueError("Could not connect to ngrok console.")
    except (KeyError, IndexError):
        message = "ngrok failed to start correctly. "
        if auth_token is not None:
            message += "Please check that your authtoken is correct."
        raise ValueError(message)
    return exposed_url


@contextmanager
def create_ngrok_config(
    auth_token: Optional[str] = None,
) -> Generator[Path, None, None]:
    """Create the ngrok configuration file."""
    config_path = _DIR / "ngrok_config.yaml"
    if config_path.exists():
        # If there was an error in a prior run, it's possible
        # Docker made this a directory instead of a file
        if config_path.is_dir():
            shutil.rmtree(config_path)
        else:
            config_path.unlink()
    ngrok_config = {
        "tunnels": {
            "langchain": {
                "proto": "http",
                "addr": "langchain-backend:8000",
            }
        },
        "version": "2",
        "region": "us",
    }
    if auth_token is not None:
        ngrok_config["authtoken"] = auth_token
    with config_path.open("w") as f:
        yaml.dump(ngrok_config, f)
    yield config_path
    # Delete the config file after use
    config_path.unlink(missing_ok=True)


class PlusCommand:
    """Manage the LangChainPlus Tracing server."""

    def __init__(self) -> None:
        self.docker_compose_command = get_docker_compose_command()
        self.docker_compose_file = (
            Path(__file__).absolute().parent / "docker-compose.yaml"
        )
        self.ngrok_path = Path(__file__).absolute().parent / "docker-compose.ngrok.yaml"

    def _open_browser(self, url: str) -> None:
        try:
            subprocess.run(["open", url])
        except FileNotFoundError:
            pass

    def _start_local(self) -> None:
        command = [
            *self.docker_compose_command,
            "-f",
            str(self.docker_compose_file),
        ]
        subprocess.run(
            [
                *command,
                "up",
                "--pull=always",
                "--quiet-pull",
                "--wait",
            ]
        )
        logger.info(
            "langchain plus server is running at http://localhost. To connect"
            " locally, set the following environment variable"
            " when running your LangChain application."
        )

        logger.info("\tLANGCHAIN_TRACING_V2=true")
        self._open_browser("http://localhost")

    def _start_and_expose(self, auth_token: Optional[str]) -> None:
        with create_ngrok_config(auth_token=auth_token):
            command = [
                *self.docker_compose_command,
                "-f",
                str(self.docker_compose_file),
                "-f",
                str(self.ngrok_path),
            ]
            subprocess.run(
                [
                    *command,
                    "up",
                    "--pull=always",
                    "--quiet-pull",
                    "--wait",
                ]
            )
        logger.info(
            "ngrok is running. You can view the dashboard at http://0.0.0.0:4040"
        )
        ngrok_url = get_ngrok_url(auth_token)
        logger.info(
            "langchain plus server is running at http://localhost."
            " To connect remotely, set the following environment"
            " variable when running your LangChain application."
        )
        logger.info("\tLANGCHAIN_TRACING_V2=true")
        logger.info(f"\tLANGCHAIN_ENDPOINT={ngrok_url}")
        self._open_browser("http://0.0.0.0:4040")
        self._open_browser("http://localhost")

    def start(
        self,
        *,
        expose: bool = False,
        auth_token: Optional[str] = None,
        dev: bool = False,
        openai_api_key: Optional[str] = None,
    ) -> None:
        """Run the LangChainPlus server locally.

        Args:
            expose: If True, expose the server to the internet using ngrok.
            auth_token: The ngrok authtoken to use (visible in the ngrok dashboard).
                If not provided, ngrok server session length will be restricted.
            dev: If True, use the development (rc) image of LangChainPlus.
            openai_api_key: The OpenAI API key to use for LangChainPlus.
                If not provided, the OpenAI API Key will be read from the
                OPENAI_API_KEY environment variable. If neither are provided,
                some features of LangChainPlus will not be available.
        """
        if dev:
            os.environ["_LANGCHAINPLUS_IMAGE_PREFIX"] = "rc-"
        if openai_api_key is not None:
            os.environ["OPENAI_API_KEY"] = openai_api_key
        if expose:
            self._start_and_expose(auth_token=auth_token)
        else:
            self._start_local()

    def stop(self) -> None:
        """Stop the LangChainPlus server."""
        subprocess.run(
            [
                *self.docker_compose_command,
                "-f",
                str(self.docker_compose_file),
                "-f",
                str(self.ngrok_path),
                "down",
            ]
        )

    def logs(self) -> None:
        """Print the logs from the LangChainPlus server."""
        subprocess.run(
            [
                *self.docker_compose_command,
                "-f",
                str(self.docker_compose_file),
                "-f",
                str(self.ngrok_path),
                "logs",
            ]
        )

    def status(self) -> None:
        """Provide information about the status of the LangChainPlus server."""
        command = [
            *self.docker_compose_command,
            "-f",
            str(self.docker_compose_file),
            "ps",
            "--format",
            "json",
        ]

        result = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        try:
            command_stdout = result.stdout.decode("utf-8")
            services_status = json.loads(command_stdout)
        except json.JSONDecodeError:
            logger.error("Error checking LangChainPlus server status.")
            return
        if services_status:
            logger.info("The LangChainPlus server is currently running.")
            pprint_services(services_status)
        else:
            logger.info("The LangChainPlus server is not running.")


def env() -> None:
    """Print the runtime environment information."""
    env = get_runtime_environment()
    logger.info("LangChain Environment:")
    logger.info("\n".join(f"{k}:{v}" for k, v in env.items()))


def main() -> None:
    """Main entrypoint for the CLI."""
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(description="LangChainPlus CLI commands")

    server_command = PlusCommand()
    server_parser = subparsers.add_parser("plus", description=server_command.__doc__)
    server_subparsers = server_parser.add_subparsers()

    server_start_parser = server_subparsers.add_parser(
        "start", description="Start the LangChainPlus server."
    )
    server_start_parser.add_argument(
        "--expose",
        action="store_true",
        help="Expose the server to the internet using ngrok.",
    )
    server_start_parser.add_argument(
        "--ngrok-authtoken",
        default=os.getenv("NGROK_AUTHTOKEN"),
        help="The ngrok authtoken to use (visible in the ngrok dashboard)."
        " If not provided, ngrok server session length will be restricted.",
    )
    server_start_parser.add_argument(
        "--dev",
        action="store_true",
        help="Use the development version of the LangChainPlus image.",
    )
    server_start_parser.add_argument(
        "--openai-api-key",
        default=os.getenv("OPENAI_API_KEY"),
        help="The OpenAI API key to use for LangChainPlus."
        " If not provided, the OpenAI API Key will be read from the"
        " OPENAI_API_KEY environment variable. If neither are provided,"
        " some features of LangChainPlus will not be available.",
    )
    server_start_parser.set_defaults(
        func=lambda args: server_command.start(
            expose=args.expose,
            auth_token=args.ngrok_authtoken,
            dev=args.dev,
            openai_api_key=args.openai_api_key,
        )
    )

    server_stop_parser = server_subparsers.add_parser(
        "stop", description="Stop the LangChainPlus server."
    )
    server_stop_parser.set_defaults(func=lambda args: server_command.stop())

    server_logs_parser = server_subparsers.add_parser(
        "logs", description="Show the LangChainPlus server logs."
    )
    server_logs_parser.set_defaults(func=lambda args: server_command.logs())
    server_status_parser = server_subparsers.add_parser(
        "status", description="Show the LangChainPlus server status."
    )
    server_status_parser.set_defaults(func=lambda args: server_command.status())
    env_parser = subparsers.add_parser("env")
    env_parser.set_defaults(func=lambda args: env())

    args = parser.parse_args()
    if not hasattr(args, "func"):
        parser.print_help()
        return
    args.func(args)


if __name__ == "__main__":
    main()
@@ -1,6 +1,9 @@
 """LangChain+ Client."""
+from langchain.client.runner_utils import (
+    arun_on_dataset,
+    arun_on_examples,
+    run_on_dataset,
+    run_on_examples,
+)
-from langchain.client.langchain import LangChainPlusClient
 
-__all__ = ["LangChainPlusClient"]
+__all__ = ["arun_on_dataset", "run_on_dataset", "arun_on_examples", "run_on_examples"]
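The net effect on the import surface: the evaluation helpers now come from langchain.client, while the client class itself is expected to come from the SDK. A before/after sketch (the SDK import path is an assumption, not shown in this diff):

# Before this commit:
# from langchain.client import LangChainPlusClient

# After this commit, only the evaluation helpers are re-exported here:
from langchain.client import arun_on_dataset, run_on_dataset

# Assumed SDK import path for the replacement client:
from langchainplus_sdk import LangChainPlusClient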
@@ -1,562 +0,0 @@
from __future__ import annotations

import logging
import socket
from datetime import datetime
from io import BytesIO
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
)
from urllib.parse import urlsplit
from uuid import UUID

import requests
from pydantic import BaseSettings, Field, root_validator
from requests import Response
from tenacity import retry, stop_after_attempt, wait_fixed

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.tracers.schemas import Run as TracerRun
from langchain.callbacks.tracers.schemas import TracerSession
from langchain.chains.base import Chain
from langchain.client.models import (
    APIFeedbackSource,
    Dataset,
    DatasetCreate,
    Example,
    ExampleCreate,
    ExampleUpdate,
    Feedback,
    FeedbackCreate,
    FeedbackSourceBase,
    FeedbackSourceType,
    ListFeedbackQueryParams,
    ListRunsQueryParams,
    ModelFeedbackSource,
)
from langchain.client.runner_utils import arun_on_examples, run_on_examples
from langchain.utils import raise_for_status_with_text, xor_args

if TYPE_CHECKING:
    import pandas as pd

logger = logging.getLogger(__name__)

MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]


class Run(TracerRun):
    id: UUID


def _get_link_stem(url: str) -> str:
    scheme = urlsplit(url).scheme
    netloc_prefix = urlsplit(url).netloc.split(":")[0]
    return f"{scheme}://{netloc_prefix}"


def _is_localhost(url: str) -> bool:
    """Check if the URL is localhost."""
    try:
        netloc = urlsplit(url).netloc.split(":")[0]
        ip = socket.gethostbyname(netloc)
        return ip == "127.0.0.1" or ip.startswith("0.0.0.0") or ip.startswith("::")
    except socket.gaierror:
        return False


class LangChainPlusClient(BaseSettings):
    """Client for interacting with the LangChain+ API."""

    api_key: Optional[str] = Field(default=None, env="LANGCHAIN_API_KEY")
    api_url: str = Field(default="http://localhost:1984", env="LANGCHAIN_ENDPOINT")

    @root_validator(pre=True)
    def validate_api_key_if_hosted(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Verify API key is provided if url not localhost."""
        api_url: str = values.get("api_url", "http://localhost:1984")
        api_key: Optional[str] = values.get("api_key")
        if not _is_localhost(api_url):
            if not api_key:
                raise ValueError(
                    "API key must be provided when using hosted LangChain+ API"
                )
        return values

    @staticmethod
    def _get_session_name(
        session_name: Optional[str],
        llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
        dataset_name: str,
    ) -> str:
        if session_name is not None:
            return session_name
        current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        if isinstance(llm_or_chain_factory, BaseLanguageModel):
            model_name = llm_or_chain_factory.__class__.__name__
        else:
            model_name = llm_or_chain_factory().__class__.__name__
        return f"{dataset_name}-{model_name}-{current_time}"

    def _repr_html_(self) -> str:
        """Return an HTML representation of the instance with a link to the URL."""
        if _is_localhost(self.api_url):
            link = "http://localhost"
        elif "dev" in self.api_url:
            link = "https://dev.langchain.plus"
        else:
            link = "https://www.langchain.plus"
        return f'<a href="{link}" target="_blank" rel="noopener">LangChain+ Client</a>'

    def __repr__(self) -> str:
        """Return a string representation of the instance with a link to the URL."""
        return f"LangChainPlusClient (API URL: {self.api_url})"

    @property
    def _headers(self) -> Dict[str, str]:
        """Get the headers for the API request."""
        headers = {}
        if self.api_key:
            headers["x-api-key"] = self.api_key
        return headers

    def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Response:
        """Make a GET request."""
        return requests.get(
            f"{self.api_url}{path}", headers=self._headers, params=params
        )

    def upload_dataframe(
        self,
        df: pd.DataFrame,
        name: str,
        description: str,
        input_keys: Sequence[str],
        output_keys: Sequence[str],
    ) -> Dataset:
        """Upload a dataframe as individual examples to the LangChain+ API."""
        dataset = self.create_dataset(dataset_name=name, description=description)
        for row in df.itertuples():
            inputs = {key: getattr(row, key) for key in input_keys}
            outputs = {key: getattr(row, key) for key in output_keys}
            self.create_example(inputs, outputs=outputs, dataset_id=dataset.id)
        return dataset

    def upload_csv(
        self,
        csv_file: Union[str, Tuple[str, BytesIO]],
        description: str,
        input_keys: Sequence[str],
        output_keys: Sequence[str],
    ) -> Dataset:
        """Upload a CSV file to the LangChain+ API."""
        files = {"file": csv_file}
        data = {
            "input_keys": ",".join(input_keys),
            "output_keys": ",".join(output_keys),
            "description": description,
        }
        response = requests.post(
            self.api_url + "/datasets/upload",
            headers=self._headers,
            data=data,
            files=files,
        )
        raise_for_status_with_text(response)
        result = response.json()
        # TODO: Make this more robust server-side
        if "detail" in result and "already exists" in result["detail"]:
            file_name = csv_file if isinstance(csv_file, str) else csv_file[0]
            file_name = file_name.split("/")[-1]
            raise ValueError(f"Dataset {file_name} already exists")
        return Dataset(**result)

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
    def read_run(self, run_id: Union[str, UUID]) -> Run:
        """Read a run from the LangChain+ API."""
        response = self._get(f"/runs/{run_id}")
        raise_for_status_with_text(response)
        return Run(**response.json())

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
    def list_runs(
        self,
        *,
        session_id: Optional[str] = None,
        session_name: Optional[str] = None,
        run_type: Optional[str] = None,
        **kwargs: Any,
    ) -> Iterator[Run]:
        """List runs from the LangChain+ API."""
        if session_name is not None:
            if session_id is not None:
                raise ValueError("Only one of session_id or session_name may be given")
            session_id = self.read_session(session_name=session_name).id
        query_params = ListRunsQueryParams(
            session_id=session_id, run_type=run_type, **kwargs
        )
        response = self._get("/runs", params=query_params.dict(exclude_none=True))
        raise_for_status_with_text(response)
        yield from [Run(**run) for run in response.json()]

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
    @xor_args(("session_id", "session_name"))
    def read_session(
        self, *, session_id: Optional[str] = None, session_name: Optional[str] = None
    ) -> TracerSession:
        """Read a session from the LangChain+ API."""
        path = "/sessions"
        params: Dict[str, Any] = {"limit": 1}
        if session_id is not None:
            path += f"/{session_id}"
        elif session_name is not None:
            params["name"] = session_name
        else:
            raise ValueError("Must provide session_name or session_id")
        response = self._get(
            path,
            params=params,
        )
        raise_for_status_with_text(response)
        result = response.json()
        if isinstance(result, list):
            if len(result) == 0:
                raise ValueError(f"Session {session_name} not found")
            return TracerSession(**result[0])
        return TracerSession(**response.json())

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
    def list_sessions(self) -> Iterator[TracerSession]:
        """List sessions from the LangChain+ API."""
        response = self._get("/sessions")
        raise_for_status_with_text(response)
        yield from [TracerSession(**session) for session in response.json()]

    @xor_args(("session_name", "session_id"))
    def delete_session(
        self, *, session_name: Optional[str] = None, session_id: Optional[str] = None
    ) -> None:
        """Delete a session from the LangChain+ API."""
        if session_name is not None:
            session_id = self.read_session(session_name=session_name).id
        elif session_id is None:
            raise ValueError("Must provide session_name or session_id")
        response = requests.delete(
            self.api_url + f"/sessions/{session_id}",
            headers=self._headers,
        )
        raise_for_status_with_text(response)
        return None

    def create_dataset(
        self, dataset_name: str, *, description: Optional[str] = None
    ) -> Dataset:
        """Create a dataset in the LangChain+ API."""
        dataset = DatasetCreate(
            name=dataset_name,
            description=description,
        )
        response = requests.post(
            self.api_url + "/datasets",
            headers=self._headers,
            data=dataset.json(),
        )
        raise_for_status_with_text(response)
        return Dataset(**response.json())

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
    @xor_args(("dataset_name", "dataset_id"))
    def read_dataset(
        self, *, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None
    ) -> Dataset:
        """Read a dataset from the LangChain+ API."""
        path = "/datasets"
        params: Dict[str, Any] = {"limit": 1}
        if dataset_id is not None:
            path += f"/{dataset_id}"
        elif dataset_name is not None:
            params["name"] = dataset_name
        else:
            raise ValueError("Must provide dataset_name or dataset_id")
        response = self._get(
            path,
            params=params,
        )
        raise_for_status_with_text(response)
        result = response.json()
        if isinstance(result, list):
            if len(result) == 0:
                raise ValueError(f"Dataset {dataset_name} not found")
            return Dataset(**result[0])
        return Dataset(**result)

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
    def list_datasets(self, limit: int = 100) -> Iterator[Dataset]:
        """List the datasets on the LangChain+ API."""
        response = self._get("/datasets", params={"limit": limit})
        raise_for_status_with_text(response)
        yield from [Dataset(**dataset) for dataset in response.json()]

    @xor_args(("dataset_id", "dataset_name"))
    def delete_dataset(
        self, *, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None
    ) -> Dataset:
        """Delete a dataset by ID or name."""
        if dataset_name is not None:
            dataset_id = self.read_dataset(dataset_name=dataset_name).id
        if dataset_id is None:
            raise ValueError("Must provide either dataset name or ID")
        response = requests.delete(
            f"{self.api_url}/datasets/{dataset_id}",
            headers=self._headers,
        )
        raise_for_status_with_text(response)
        return Dataset(**response.json())

    @xor_args(("dataset_id", "dataset_name"))
    def create_example(
        self,
        inputs: Dict[str, Any],
        dataset_id: Optional[UUID] = None,
        dataset_name: Optional[str] = None,
        created_at: Optional[datetime] = None,
        outputs: Dict[str, Any] | None = None,
    ) -> Example:
        """Create a dataset example in the LangChain+ API."""
        if dataset_id is None:
            dataset_id = self.read_dataset(dataset_name=dataset_name).id

        data = {
            "inputs": inputs,
            "outputs": outputs,
            "dataset_id": dataset_id,
        }
        if created_at:
            data["created_at"] = created_at.isoformat()
        example = ExampleCreate(**data)
        response = requests.post(
            f"{self.api_url}/examples", headers=self._headers, data=example.json()
        )
        raise_for_status_with_text(response)
        result = response.json()
        return Example(**result)

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
    def read_example(self, example_id: Union[str, UUID]) -> Example:
        """Read an example from the LangChain+ API."""
        response = self._get(f"/examples/{example_id}")
        raise_for_status_with_text(response)
        return Example(**response.json())

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
    def list_examples(
        self, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None
    ) -> Iterator[Example]:
        """List the examples on the LangChain+ API."""
        params = {}
        if dataset_id is not None:
            params["dataset"] = dataset_id
        elif dataset_name is not None:
            dataset_id = self.read_dataset(dataset_name=dataset_name).id
            params["dataset"] = dataset_id
        response = self._get("/examples", params=params)
        raise_for_status_with_text(response)
        yield from [Example(**example) for example in response.json()]

    def update_example(
        self,
        example_id: str,
        *,
        inputs: Optional[Dict[str, Any]] = None,
        outputs: Optional[Mapping[str, Any]] = None,
        dataset_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Update a specific example."""
        example = ExampleUpdate(
            inputs=inputs,
            outputs=outputs,
            dataset_id=dataset_id,
        )
        response = requests.patch(
            f"{self.api_url}/examples/{example_id}",
            headers=self._headers,
            data=example.json(exclude_none=True),
        )
        raise_for_status_with_text(response)
        return response.json()

    def create_feedback(
        self,
        run_id: str,
        key: str,
        *,
        score: Union[float, int, bool, None] = None,
        value: Union[float, int, bool, str, dict, None] = None,
        correction: Union[str, dict, None] = None,
        comment: Union[str, None] = None,
        source_info: Optional[Dict[str, Any]] = None,
        feedback_source_type: FeedbackSourceType = FeedbackSourceType.API,
    ) -> Feedback:
        """Create feedback for a run in the LangChain+ API.

        Args:
            run_id: The ID of the run to provide feedback on.
            key: The name of the metric, tag, or 'aspect' this
                feedback is about.
            score: The score to rate this run on the metric
                or aspect.
            value: The display value or non-numeric value for this feedback.
            correction: The proper ground truth for this run.
            comment: A comment about this feedback.
            source_info: Information about the source of this feedback.
            feedback_source_type: The type of feedback source.
        """
        if feedback_source_type == FeedbackSourceType.API:
            feedback_source: FeedbackSourceBase = APIFeedbackSource(
                metadata=source_info
            )
        elif feedback_source_type == FeedbackSourceType.MODEL:
            feedback_source = ModelFeedbackSource(metadata=source_info)
        else:
            raise ValueError(f"Unknown feedback source type {feedback_source_type}")
        feedback = FeedbackCreate(
            run_id=run_id,
            key=key,
            score=score,
            value=value,
            correction=correction,
            comment=comment,
            feedback_source=feedback_source,
        )
        response = requests.post(
            self.api_url + "/feedback",
            headers=self._headers,
            data=feedback.json(),
        )
        raise_for_status_with_text(response)
        return Feedback(**feedback.dict())

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
    def read_feedback(self, feedback_id: str) -> Feedback:
        """Read a feedback entry from the LangChain+ API."""
        response = self._get(f"/feedback/{feedback_id}")
        raise_for_status_with_text(response)
        return Feedback(**response.json())

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
    def list_feedback(
        self,
        *,
        run_ids: Optional[Sequence[Union[str, UUID]]] = None,
        **kwargs: Any,
    ) -> Iterator[Feedback]:
        """List the feedback objects on the LangChain+ API."""
        params = ListFeedbackQueryParams(
            run=run_ids,
            **kwargs,
        )
        response = self._get("/feedback", params=params.dict(exclude_none=True))
        raise_for_status_with_text(response)
        yield from [Feedback(**feedback) for feedback in response.json()]

    def delete_feedback(self, feedback_id: str) -> None:
        """Delete a feedback entry by ID."""
        response = requests.delete(
            f"{self.api_url}/feedback/{feedback_id}",
            headers=self._headers,
        )
        raise_for_status_with_text(response)

    async def arun_on_dataset(
        self,
        dataset_name: str,
        llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
        *,
        concurrency_level: int = 5,
        num_repetitions: int = 1,
        session_name: Optional[str] = None,
        verbose: bool = False,
    ) -> Dict[str, Any]:
        """
        Run the chain on a dataset and store traces to the specified session name.

        Args:
            dataset_name: Name of the dataset to run the chain on.
            llm_or_chain_factory: Language model or Chain constructor to run
                over the dataset. The Chain constructor is used to permit
                independent calls on each example without carrying over state.
            concurrency_level: The number of async tasks to run concurrently.
            num_repetitions: Number of times to run the model on each example.
                This is useful when testing success rates or generating confidence
                intervals.
            session_name: Name of the session to store the traces in.
                Defaults to {dataset_name}-{chain class name}-{datetime}.
            verbose: Whether to print progress.

        Returns:
            A dictionary mapping example ids to the model outputs.
        """
        session_name = LangChainPlusClient._get_session_name(
            session_name, llm_or_chain_factory, dataset_name
        )
        dataset = self.read_dataset(dataset_name=dataset_name)
        examples = self.list_examples(dataset_id=str(dataset.id))

        return await arun_on_examples(
            examples,
            llm_or_chain_factory,
            concurrency_level=concurrency_level,
            num_repetitions=num_repetitions,
            session_name=session_name,
            verbose=verbose,
        )

    def run_on_dataset(
        self,
        dataset_name: str,
        llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
        *,
        num_repetitions: int = 1,
        session_name: Optional[str] = None,
        verbose: bool = False,
    ) -> Dict[str, Any]:
        """Run the chain on a dataset and store traces to the specified session name.

        Args:
            dataset_name: Name of the dataset to run the chain on.
            llm_or_chain_factory: Language model or Chain constructor to run
                over the dataset. The Chain constructor is used to permit
                independent calls on each example without carrying over state.
            num_repetitions: Number of times to run the model on each example.
                This is useful when testing success rates or generating confidence
                intervals.
            session_name: Name of the session to store the traces in.
                Defaults to {dataset_name}-{chain class name}-{datetime}.
            verbose: Whether to print progress.

        Returns:
            A dictionary mapping example ids to the model outputs.
        """
        session_name = LangChainPlusClient._get_session_name(
            session_name, llm_or_chain_factory, dataset_name
        )
        dataset = self.read_dataset(dataset_name=dataset_name)
        examples = self.list_examples(dataset_id=str(dataset.id))
        return run_on_examples(
            examples,
            llm_or_chain_factory,
            num_repetitions=num_repetitions,
            session_name=session_name,
            verbose=verbose,
        )
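The deleted module above bundled dataset, session, and feedback CRUD together with the evaluation entry points. For anyone migrating, a rough sketch of the same dataset operations through the SDK client, assuming it keeps the method names of the implementation above (an assumption, not verified by this diff):

from langchainplus_sdk import LangChainPlusClient  # assumed SDK import path

client = LangChainPlusClient()
# Assumption: the SDK mirrors the method names of the deleted implementation.
dataset = client.create_dataset("kv-pairs", description="Toy migration example")
client.create_example({"input": "1"}, outputs={"output": "2"}, dataset_id=dataset.id)
for example in client.list_examples(dataset_id=str(dataset.id)):
    print(example.inputs, example.outputs)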
@@ -1,206 +0,0 @@
from datetime import datetime
from enum import Enum
from typing import Any, ClassVar, Dict, List, Mapping, Optional, Sequence, Union
from uuid import UUID, uuid4

from pydantic import BaseModel, Field, root_validator

from langchain.callbacks.tracers.schemas import Run, RunTypeEnum


class ExampleBase(BaseModel):
    """Example base model."""

    dataset_id: UUID
    inputs: Dict[str, Any]
    outputs: Optional[Dict[str, Any]] = Field(default=None)

    class Config:
        frozen = True


class ExampleCreate(ExampleBase):
    """Example create model."""

    id: Optional[UUID]
    created_at: datetime = Field(default_factory=datetime.utcnow)


class Example(ExampleBase):
    """Example model."""

    id: UUID
    created_at: datetime
    modified_at: Optional[datetime] = Field(default=None)
    runs: List[Run] = Field(default_factory=list)


class ExampleUpdate(BaseModel):
    """Update class for Example."""

    dataset_id: Optional[UUID] = None
    inputs: Optional[Dict[str, Any]] = None
    outputs: Optional[Dict[str, Any]] = None

    class Config:
        frozen = True


class DatasetBase(BaseModel):
    """Dataset base model."""

    name: str
    description: Optional[str] = None

    class Config:
        frozen = True


class DatasetCreate(DatasetBase):
    """Dataset create model."""

    id: Optional[UUID]
    created_at: datetime = Field(default_factory=datetime.utcnow)


class Dataset(DatasetBase):
    """Dataset ORM model."""

    id: UUID
    tenant_id: UUID
    created_at: datetime
    modified_at: Optional[datetime] = Field(default=None)


class ListRunsQueryParams(BaseModel):
    """Query params for GET /runs endpoint."""

    id: Optional[List[UUID]]
    """Filter runs by id."""
    parent_run: Optional[UUID]
    """Filter runs by parent run."""
    run_type: Optional[RunTypeEnum]
    """Filter runs by type."""
    session: Optional[UUID] = Field(default=None, alias="session_id")
    """Only return runs within a session."""
    reference_example: Optional[UUID]
    """Only return runs that reference the specified dataset example."""
    execution_order: Optional[int]
    """Filter runs by execution order."""
    error: Optional[bool]
    """Whether to return only runs that errored."""
    offset: Optional[int]
    """The offset of the first run to return."""
    limit: Optional[int]
    """The maximum number of runs to return."""
    start_time: Optional[datetime] = Field(
        default=None,
        alias="start_before",
        description="Query Runs that started <= this time",
    )
    end_time: Optional[datetime] = Field(
        default=None,
        alias="end_after",
        description="Query Runs that ended >= this time",
    )

    class Config:
        extra = "forbid"
        frozen = True

    @root_validator(allow_reuse=True)
    def validate_time_range(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate that start_time <= end_time."""
        start_time = values.get("start_time")
        end_time = values.get("end_time")
        if start_time and end_time and start_time > end_time:
            raise ValueError("start_time must be <= end_time")
        return values


class FeedbackSourceBase(BaseModel):
    type: ClassVar[str]
    metadata: Optional[Dict[str, Any]] = None

    class Config:
        frozen = True


class APIFeedbackSource(FeedbackSourceBase):
    """API feedback source."""

    type: ClassVar[str] = "api"


class ModelFeedbackSource(FeedbackSourceBase):
    """Model feedback source."""

    type: ClassVar[str] = "model"


class FeedbackSourceType(Enum):
    """Feedback source type."""

    API = "api"
    """General feedback submitted from the API."""
    MODEL = "model"
    """Model-assisted feedback."""


class FeedbackBase(BaseModel):
    """Feedback schema."""

    created_at: datetime = Field(default_factory=datetime.utcnow)
    """The time the feedback was created."""
    modified_at: datetime = Field(default_factory=datetime.utcnow)
    """The time the feedback was last modified."""
    run_id: UUID
    """The associated run ID this feedback is logged for."""
    key: str
    """The metric name, tag, or aspect to provide feedback on."""
    score: Union[float, int, bool, None] = None
    """Value or score to assign the run."""
    value: Union[float, int, bool, str, dict, None] = None
    """The display value, tag or other value for the feedback if not a metric."""
    comment: Optional[str] = None
    """Comment or explanation for the feedback."""
    correction: Union[str, dict, None] = None
    """Correction for the run."""
    feedback_source: Optional[
        Union[APIFeedbackSource, ModelFeedbackSource, Mapping[str, Any]]
    ] = None
    """The source of the feedback."""

    class Config:
        frozen = True


class FeedbackCreate(FeedbackBase):
    """Schema used for creating feedback."""

    id: UUID = Field(default_factory=uuid4)

    feedback_source: APIFeedbackSource
    """The source of the feedback."""


class Feedback(FeedbackBase):
    """Schema for getting feedback."""

    id: UUID
    feedback_source: Optional[Dict] = None
    """The source of the feedback."""


class ListFeedbackQueryParams(BaseModel):
    """Query params for listing feedback."""

    run: Optional[Sequence[UUID]] = None
    limit: int = 100
    offset: int = 0

    class Config:
        """Config for query params."""

        extra = "forbid"
        frozen = True
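One subtlety in ListRunsQueryParams above: the time filters are declared under pydantic aliases (start_before, end_after), and with pydantic v1's default config an aliased field must be populated by its alias, not its field name. A small illustrative check against the pre-commit module path:

from datetime import datetime

from langchain.client.models import ListRunsQueryParams  # pre-commit path

params = ListRunsQueryParams(
    session_id="b31f44ac-6ff6-4a97-8d3f-2bb693e2ac29",  # alias for `session`; made-up UUID
    start_before=datetime(2023, 1, 1),  # alias for `start_time`
    end_after=datetime(2023, 1, 2),  # alias for `end_time`
)
# exclude_none drops unset filters before they become query-string params,
# matching how LangChainPlusClient.list_runs serializes them:
print(params.dict(exclude_none=True))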
File diff suppressed because it is too large
@@ -1,116 +0,0 @@
"""LangChain+ langchain_client Integration Tests."""
import os
from uuid import uuid4

import pytest
from tenacity import RetryError

from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.callbacks.manager import tracing_v2_enabled
from langchain.chat_models import ChatOpenAI
from langchain.client import LangChainPlusClient
from langchain.tools.base import tool


@pytest.fixture
def langchain_client(monkeypatch: pytest.MonkeyPatch) -> LangChainPlusClient:
    monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
    return LangChainPlusClient()


def test_sessions(
    langchain_client: LangChainPlusClient, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Test sessions."""
    session_names = set([session.name for session in langchain_client.list_sessions()])
    new_session = f"Session {uuid4()}"
    assert new_session not in session_names

    @tool
    def example_tool() -> str:
        """Call me, maybe."""
        return "test_tool"

    monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
    with tracing_v2_enabled(session_name=new_session):
        example_tool({})
    session = langchain_client.read_session(session_name=new_session)
    assert session.name == new_session
    session_names = set([sess.name for sess in langchain_client.list_sessions()])
    assert new_session in session_names
    runs = list(langchain_client.list_runs(session_name=new_session))
    session_id_runs = list(langchain_client.list_runs(session_id=session.id))
    assert len(runs) == len(session_id_runs) == 1
    assert runs[0].id == session_id_runs[0].id
    langchain_client.delete_session(session_name=new_session)

    with pytest.raises(RetryError):
        langchain_client.read_session(session_name=new_session)
    assert new_session not in set(
        [sess.name for sess in langchain_client.list_sessions()]
    )
    with pytest.raises(RetryError):
        langchain_client.delete_session(session_name=new_session)
    with pytest.raises(RetryError):
        langchain_client.read_run(run_id=str(runs[0].id))


def test_feedback_cycle(
    monkeypatch: pytest.MonkeyPatch, langchain_client: LangChainPlusClient
) -> None:
    """Test that feedback is correctly created and updated."""
    monkeypatch.setenv("LANGCHAIN_TRACING_V2", "true")
    monkeypatch.setenv("LANGCHAIN_SESSION", f"Feedback Testing {uuid4()}")
    monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
    llm = ChatOpenAI(temperature=0)
    tools = load_tools(["serpapi", "llm-math"], llm=llm)
    agent = initialize_agent(
        tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False
    )

    agent.run(
        "What is the population of Kuala Lumpur as of January, 2023?"
        " What is its square root?"
    )
    other_session_name = f"Feedback Testing {uuid4()}"
    with tracing_v2_enabled(session_name=other_session_name):
        try:
            agent.run("What is the square root of 3?")
        except Exception as e:
            print(e)
    runs = list(
        langchain_client.list_runs(
            session_name=os.environ["LANGCHAIN_SESSION"], error=False, execution_order=1
        )
    )
    assert len(runs) == 1
    order_2 = list(
        langchain_client.list_runs(
            session_name=os.environ["LANGCHAIN_SESSION"], execution_order=2
        )
    )
    assert len(order_2) > 0
    langchain_client.create_feedback(str(order_2[0].id), "test score", score=0)
    feedback = langchain_client.create_feedback(str(runs[0].id), "test score", score=1)
    feedbacks = list(langchain_client.list_feedback(run_ids=[str(runs[0].id)]))
    assert len(feedbacks) == 1
    assert feedbacks[0].id == feedback.id

    # Add feedback to the other session
    other_runs = list(
        langchain_client.list_runs(session_name=other_session_name, execution_order=1)
    )
    assert len(other_runs) == 1
    langchain_client.create_feedback(
        run_id=str(other_runs[0].id), key="test score", score=0
    )
    all_runs = list(
        langchain_client.list_runs(session_name=os.environ["LANGCHAIN_SESSION"])
    ) + list(langchain_client.list_runs(session_name=other_session_name))
    test_run_ids = [str(run.id) for run in all_runs]
    all_feedback = list(langchain_client.list_feedback(run_ids=test_run_ids))
    assert len(all_feedback) == 3
    for feedback in all_feedback:
        langchain_client.delete_feedback(str(feedback.id))
    feedbacks = list(langchain_client.list_feedback(run_ids=test_run_ids))
    assert len(feedbacks) == 0
@@ -1,207 +0,0 @@
"""Test the LangChain+ client."""
import uuid
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, List, Union
from unittest import mock

import pytest

from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.client.langchain import (
    LangChainPlusClient,
    _get_link_stem,
    _is_localhost,
)
from langchain.client.models import Dataset, Example

_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"


@pytest.mark.parametrize(
    "api_url, expected_url",
    [
        ("http://localhost:8000", "http://localhost"),
        ("http://www.example.com", "http://www.example.com"),
        (
            "https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
            "https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
        ),
        ("https://www.langchain.com/path/to/nowhere", "https://www.langchain.com"),
    ],
)
def test_link_split(api_url: str, expected_url: str) -> None:
    """Test the link splitting handles both localhost and deployed urls."""
    assert _get_link_stem(api_url) == expected_url


def test_is_localhost() -> None:
    assert _is_localhost("http://localhost:8000")
    assert _is_localhost("http://127.0.0.1:8000")
    assert _is_localhost("http://0.0.0.0:8000")
    assert not _is_localhost("http://example.com:8000")


def test_validate_api_key_if_hosted(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
    with pytest.raises(ValueError, match="API key must be provided"):
        LangChainPlusClient(api_url="http://www.example.com")

    client = LangChainPlusClient(api_url="http://localhost:8000")
    assert client.api_url == "http://localhost:8000"
    assert client.api_key is None


def test_headers(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
    client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
    assert client._headers == {"x-api-key": "123"}

    client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
    assert client_no_key._headers == {}


@mock.patch("langchain.client.langchain.requests.post")
def test_upload_csv(mock_post: mock.Mock) -> None:
    mock_response = mock.Mock()
    dataset_id = str(uuid.uuid4())
    example_1 = Example(
        id=str(uuid.uuid4()),
        created_at=_CREATED_AT,
        inputs={"input": "1"},
        outputs={"output": "2"},
        dataset_id=dataset_id,
    )
    example_2 = Example(
        id=str(uuid.uuid4()),
        created_at=_CREATED_AT,
        inputs={"input": "3"},
        outputs={"output": "4"},
        dataset_id=dataset_id,
    )

    mock_response.json.return_value = {
        "id": dataset_id,
        "name": "test.csv",
        "description": "Test dataset",
        "owner_id": "the owner",
        "created_at": _CREATED_AT,
        "examples": [example_1, example_2],
        "tenant_id": _TENANT_ID,
    }
    mock_post.return_value = mock_response

    client = LangChainPlusClient(
        api_url="http://localhost:8000",
        api_key="123",
    )
    csv_file = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n"))

    dataset = client.upload_csv(
        csv_file, "Test dataset", input_keys=["input"], output_keys=["output"]
    )

    assert dataset.id == uuid.UUID(dataset_id)
    assert dataset.name == "test.csv"
    assert dataset.description == "Test dataset"


@pytest.mark.asyncio
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
    dataset = Dataset(
        id=uuid.uuid4(),
        name="test",
        description="Test dataset",
        owner_id="owner",
        created_at=_CREATED_AT,
        tenant_id=_TENANT_ID,
    )
    uuids = [
        "0c193153-2309-4704-9a47-17aee4fb25c8",
        "0d11b5fd-8e66-4485-b696-4b55155c0c05",
        "90d696f0-f10d-4fd0-b88b-bfee6df08b84",
        "4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
        "7b5a524c-80fa-4960-888e-7d380f9a11ee",
    ]
    examples = [
        Example(
            id=uuids[0],
            created_at=_CREATED_AT,
            inputs={"input": "1"},
            outputs={"output": "2"},
            dataset_id=str(uuid.uuid4()),
        ),
        Example(
            id=uuids[1],
            created_at=_CREATED_AT,
            inputs={"input": "3"},
            outputs={"output": "4"},
            dataset_id=str(uuid.uuid4()),
        ),
        Example(
            id=uuids[2],
            created_at=_CREATED_AT,
            inputs={"input": "5"},
            outputs={"output": "6"},
            dataset_id=str(uuid.uuid4()),
        ),
        Example(
            id=uuids[3],
            created_at=_CREATED_AT,
            inputs={"input": "7"},
            outputs={"output": "8"},
            dataset_id=str(uuid.uuid4()),
        ),
        Example(
            id=uuids[4],
            created_at=_CREATED_AT,
            inputs={"input": "9"},
            outputs={"output": "10"},
            dataset_id=str(uuid.uuid4()),
        ),
    ]

    def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
        return dataset

    def mock_list_examples(*args: Any, **kwargs: Any) -> List[Example]:
        return examples

    async def mock_arun_chain(
        example: Example,
        llm_or_chain: Union[BaseLanguageModel, Chain],
        n_repetitions: int,
        tracer: Any,
    ) -> List[Dict[str, Any]]:
        return [
            {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
        ]

    with mock.patch.object(
        LangChainPlusClient, "read_dataset", new=mock_read_dataset
    ), mock.patch.object(
        LangChainPlusClient, "list_examples", new=mock_list_examples
    ), mock.patch(
        "langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain
    ):
        client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
        chain = mock.MagicMock()
        num_repetitions = 3
        results = await client.arun_on_dataset(
            dataset_name="test",
            llm_or_chain_factory=lambda: chain,
            concurrency_level=2,
            session_name="test_session",
            num_repetitions=num_repetitions,
        )

        expected = {
            uuid_: [
                {"result": f"Result for example {uuid.UUID(uuid_)}"}
                for _ in range(num_repetitions)
            ]
            for uuid_ in uuids
        }
        assert results == expected