mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
Use client from LCP-SDK (#5695)
- Remove the client implementation (this breaks backwards compatibility for existing testers. I could keep the stub in that file if we want, but not many people are using it yet - Add SDK as dependency - Update the 'run_on_dataset' method to be a function that optionally accepts a client as an argument - Remove the langchain plus server implementation (you get it for free with the SDK now) We could make the SDK optional for now, but the plan is to use w/in the tracer so it would likely become a hard dependency at some point.
This commit is contained in:
parent
08e2352f7b
commit
204a73c1d9
@ -1,16 +0,0 @@
|
|||||||
server {
|
|
||||||
listen 80;
|
|
||||||
server_name localhost;
|
|
||||||
error_log /var/log/nginx/error.log warn;
|
|
||||||
|
|
||||||
location / {
|
|
||||||
root /usr/share/nginx/html;
|
|
||||||
index index.html index.htm;
|
|
||||||
try_files $uri $uri/ /index.html;
|
|
||||||
}
|
|
||||||
|
|
||||||
error_page 500 502 503 504 /50x.html;
|
|
||||||
location = /50x.html {
|
|
||||||
root /usr/share/nginx/html;
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,17 +0,0 @@
|
|||||||
version: '3'
|
|
||||||
services:
|
|
||||||
ngrok:
|
|
||||||
image: ngrok/ngrok:latest
|
|
||||||
restart: unless-stopped
|
|
||||||
command:
|
|
||||||
- "start"
|
|
||||||
- "--all"
|
|
||||||
- "--config"
|
|
||||||
- "/etc/ngrok.yml"
|
|
||||||
volumes:
|
|
||||||
- ./ngrok_config.yaml:/etc/ngrok.yml
|
|
||||||
ports:
|
|
||||||
- 4040:4040
|
|
||||||
langchain-backend:
|
|
||||||
depends_on:
|
|
||||||
- ngrok
|
|
@ -1,49 +0,0 @@
|
|||||||
version: '3'
|
|
||||||
services:
|
|
||||||
langchain-frontend:
|
|
||||||
image: langchain/${_LANGCHAINPLUS_IMAGE_PREFIX-}langchainplus-frontend:latest
|
|
||||||
ports:
|
|
||||||
- 80:80
|
|
||||||
environment:
|
|
||||||
- REACT_APP_BACKEND_URL=http://localhost:1984
|
|
||||||
depends_on:
|
|
||||||
- langchain-backend
|
|
||||||
volumes:
|
|
||||||
- ./conf/nginx.conf:/etc/nginx/default.conf:ro
|
|
||||||
build:
|
|
||||||
context: frontend-react/.
|
|
||||||
dockerfile: Dockerfile
|
|
||||||
langchain-backend:
|
|
||||||
image: langchain/${_LANGCHAINPLUS_IMAGE_PREFIX-}langchainplus-backend:latest
|
|
||||||
environment:
|
|
||||||
- PORT=1984
|
|
||||||
- LANGCHAIN_ENV=local_docker
|
|
||||||
- LOG_LEVEL=warning
|
|
||||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
|
||||||
ports:
|
|
||||||
- 1984:1984
|
|
||||||
depends_on:
|
|
||||||
- langchain-db
|
|
||||||
build:
|
|
||||||
context: backend/.
|
|
||||||
dockerfile: Dockerfile
|
|
||||||
langchain-db:
|
|
||||||
image: postgres:14.1
|
|
||||||
command:
|
|
||||||
[
|
|
||||||
"postgres",
|
|
||||||
"-c",
|
|
||||||
"log_min_messages=WARNING",
|
|
||||||
"-c",
|
|
||||||
"client_min_messages=WARNING"
|
|
||||||
]
|
|
||||||
environment:
|
|
||||||
- POSTGRES_PASSWORD=postgres
|
|
||||||
- POSTGRES_USER=postgres
|
|
||||||
- POSTGRES_DB=postgres
|
|
||||||
volumes:
|
|
||||||
- langchain-db-data:/var/lib/postgresql/data
|
|
||||||
ports:
|
|
||||||
- 5433:5432
|
|
||||||
volumes:
|
|
||||||
langchain-db-data:
|
|
@ -1,377 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import subprocess
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from pathlib import Path
|
|
||||||
from subprocess import CalledProcessError
|
|
||||||
from typing import Dict, Generator, List, Mapping, Optional, Union, cast
|
|
||||||
|
|
||||||
import requests
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from langchain.env import get_runtime_environment
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_DIR = Path(__file__).parent
|
|
||||||
|
|
||||||
|
|
||||||
def pprint_services(services_status: List[Mapping[str, Union[str, List[str]]]]) -> None:
|
|
||||||
# Loop through and collect Service, State, and Publishers["PublishedPorts"]
|
|
||||||
# for each service
|
|
||||||
services = []
|
|
||||||
for service in services_status:
|
|
||||||
service_status: Dict[str, str] = {
|
|
||||||
"Service": str(service["Service"]),
|
|
||||||
"Status": str(service["Status"]),
|
|
||||||
}
|
|
||||||
publishers = cast(List[Dict], service.get("Publishers", []))
|
|
||||||
if publishers:
|
|
||||||
service_status["PublishedPorts"] = ", ".join(
|
|
||||||
[str(publisher["PublishedPort"]) for publisher in publishers]
|
|
||||||
)
|
|
||||||
services.append(service_status)
|
|
||||||
|
|
||||||
max_service_len = max(len(service["Service"]) for service in services)
|
|
||||||
max_state_len = max(len(service["Status"]) for service in services)
|
|
||||||
service_message = [
|
|
||||||
"\n"
|
|
||||||
+ "Service".ljust(max_service_len + 2)
|
|
||||||
+ "Status".ljust(max_state_len + 2)
|
|
||||||
+ "Published Ports"
|
|
||||||
]
|
|
||||||
for service in services:
|
|
||||||
service_str = service["Service"].ljust(max_service_len + 2)
|
|
||||||
state_str = service["Status"].ljust(max_state_len + 2)
|
|
||||||
ports_str = service.get("PublishedPorts", "")
|
|
||||||
service_message.append(service_str + state_str + ports_str)
|
|
||||||
|
|
||||||
langchain_endpoint: str = "http://localhost:1984"
|
|
||||||
used_ngrok = any(["ngrok" in service["Service"] for service in services])
|
|
||||||
if used_ngrok:
|
|
||||||
langchain_endpoint = get_ngrok_url(auth_token=None)
|
|
||||||
|
|
||||||
service_message.append(
|
|
||||||
"\nTo connect, set the following environment variables"
|
|
||||||
" in your LangChain application:"
|
|
||||||
"\nLANGCHAIN_TRACING_V2=true"
|
|
||||||
f"\nLANGCHAIN_ENDPOINT={langchain_endpoint}"
|
|
||||||
)
|
|
||||||
logger.info("\n".join(service_message))
|
|
||||||
|
|
||||||
|
|
||||||
def get_docker_compose_command() -> List[str]:
|
|
||||||
"""Get the correct docker compose command for this system."""
|
|
||||||
try:
|
|
||||||
subprocess.check_call(
|
|
||||||
["docker", "compose", "--version"],
|
|
||||||
stdout=subprocess.DEVNULL,
|
|
||||||
stderr=subprocess.DEVNULL,
|
|
||||||
)
|
|
||||||
return ["docker", "compose"]
|
|
||||||
except (CalledProcessError, FileNotFoundError):
|
|
||||||
try:
|
|
||||||
subprocess.check_call(
|
|
||||||
["docker-compose", "--version"],
|
|
||||||
stdout=subprocess.DEVNULL,
|
|
||||||
stderr=subprocess.DEVNULL,
|
|
||||||
)
|
|
||||||
return ["docker-compose"]
|
|
||||||
except (CalledProcessError, FileNotFoundError):
|
|
||||||
raise ValueError(
|
|
||||||
"Neither 'docker compose' nor 'docker-compose'"
|
|
||||||
" commands are available. Please install the Docker"
|
|
||||||
" server following the instructions for your operating"
|
|
||||||
" system at https://docs.docker.com/engine/install/"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_ngrok_url(auth_token: Optional[str]) -> str:
|
|
||||||
"""Get the ngrok URL for the LangChainPlus server."""
|
|
||||||
ngrok_url = "http://localhost:4040/api/tunnels"
|
|
||||||
try:
|
|
||||||
response = requests.get(ngrok_url)
|
|
||||||
response.raise_for_status()
|
|
||||||
exposed_url = response.json()["tunnels"][0]["public_url"]
|
|
||||||
except requests.exceptions.HTTPError:
|
|
||||||
raise ValueError("Could not connect to ngrok console.")
|
|
||||||
except (KeyError, IndexError):
|
|
||||||
message = "ngrok failed to start correctly. "
|
|
||||||
if auth_token is not None:
|
|
||||||
message += "Please check that your authtoken is correct."
|
|
||||||
raise ValueError(message)
|
|
||||||
return exposed_url
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def create_ngrok_config(
|
|
||||||
auth_token: Optional[str] = None,
|
|
||||||
) -> Generator[Path, None, None]:
|
|
||||||
"""Create the ngrok configuration file."""
|
|
||||||
config_path = _DIR / "ngrok_config.yaml"
|
|
||||||
if config_path.exists():
|
|
||||||
# If there was an error in a prior run, it's possible
|
|
||||||
# Docker made this a directory instead of a file
|
|
||||||
if config_path.is_dir():
|
|
||||||
shutil.rmtree(config_path)
|
|
||||||
else:
|
|
||||||
config_path.unlink()
|
|
||||||
ngrok_config = {
|
|
||||||
"tunnels": {
|
|
||||||
"langchain": {
|
|
||||||
"proto": "http",
|
|
||||||
"addr": "langchain-backend:8000",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"version": "2",
|
|
||||||
"region": "us",
|
|
||||||
}
|
|
||||||
if auth_token is not None:
|
|
||||||
ngrok_config["authtoken"] = auth_token
|
|
||||||
config_path = _DIR / "ngrok_config.yaml"
|
|
||||||
with config_path.open("w") as f:
|
|
||||||
yaml.dump(ngrok_config, f)
|
|
||||||
yield config_path
|
|
||||||
# Delete the config file after use
|
|
||||||
config_path.unlink(missing_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
class PlusCommand:
|
|
||||||
"""Manage the LangChainPlus Tracing server."""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.docker_compose_command = get_docker_compose_command()
|
|
||||||
self.docker_compose_file = (
|
|
||||||
Path(__file__).absolute().parent / "docker-compose.yaml"
|
|
||||||
)
|
|
||||||
self.ngrok_path = Path(__file__).absolute().parent / "docker-compose.ngrok.yaml"
|
|
||||||
|
|
||||||
def _open_browser(self, url: str) -> None:
|
|
||||||
try:
|
|
||||||
subprocess.run(["open", url])
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _start_local(self) -> None:
|
|
||||||
command = [
|
|
||||||
*self.docker_compose_command,
|
|
||||||
"-f",
|
|
||||||
str(self.docker_compose_file),
|
|
||||||
]
|
|
||||||
subprocess.run(
|
|
||||||
[
|
|
||||||
*command,
|
|
||||||
"up",
|
|
||||||
"--pull=always",
|
|
||||||
"--quiet-pull",
|
|
||||||
"--wait",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"langchain plus server is running at http://localhost. To connect"
|
|
||||||
" locally, set the following environment variable"
|
|
||||||
" when running your LangChain application."
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("\tLANGCHAIN_TRACING_V2=true")
|
|
||||||
self._open_browser("http://localhost")
|
|
||||||
|
|
||||||
def _start_and_expose(self, auth_token: Optional[str]) -> None:
|
|
||||||
with create_ngrok_config(auth_token=auth_token):
|
|
||||||
command = [
|
|
||||||
*self.docker_compose_command,
|
|
||||||
"-f",
|
|
||||||
str(self.docker_compose_file),
|
|
||||||
"-f",
|
|
||||||
str(self.ngrok_path),
|
|
||||||
]
|
|
||||||
subprocess.run(
|
|
||||||
[
|
|
||||||
*command,
|
|
||||||
"up",
|
|
||||||
"--pull=always",
|
|
||||||
"--quiet-pull",
|
|
||||||
"--wait",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"ngrok is running. You can view the dashboard at http://0.0.0.0:4040"
|
|
||||||
)
|
|
||||||
ngrok_url = get_ngrok_url(auth_token)
|
|
||||||
logger.info(
|
|
||||||
"langchain plus server is running at http://localhost."
|
|
||||||
" To connect remotely, set the following environment"
|
|
||||||
" variable when running your LangChain application."
|
|
||||||
)
|
|
||||||
logger.info("\tLANGCHAIN_TRACING_V2=true")
|
|
||||||
logger.info(f"\tLANGCHAIN_ENDPOINT={ngrok_url}")
|
|
||||||
self._open_browser("http://0.0.0.0:4040")
|
|
||||||
self._open_browser("http://localhost")
|
|
||||||
|
|
||||||
def start(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
expose: bool = False,
|
|
||||||
auth_token: Optional[str] = None,
|
|
||||||
dev: bool = False,
|
|
||||||
openai_api_key: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
"""Run the LangChainPlus server locally.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
expose: If True, expose the server to the internet using ngrok.
|
|
||||||
auth_token: The ngrok authtoken to use (visible in the ngrok dashboard).
|
|
||||||
If not provided, ngrok server session length will be restricted.
|
|
||||||
dev: If True, use the development (rc) image of LangChainPlus.
|
|
||||||
openai_api_key: The OpenAI API key to use for LangChainPlus
|
|
||||||
If not provided, the OpenAI API Key will be read from the
|
|
||||||
OPENAI_API_KEY environment variable. If neither are provided,
|
|
||||||
some features of LangChainPlus will not be available.
|
|
||||||
"""
|
|
||||||
if dev:
|
|
||||||
os.environ["_LANGCHAINPLUS_IMAGE_PREFIX"] = "rc-"
|
|
||||||
if openai_api_key is not None:
|
|
||||||
os.environ["OPENAI_API_KEY"] = openai_api_key
|
|
||||||
if expose:
|
|
||||||
self._start_and_expose(auth_token=auth_token)
|
|
||||||
else:
|
|
||||||
self._start_local()
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
|
||||||
"""Stop the LangChainPlus server."""
|
|
||||||
subprocess.run(
|
|
||||||
[
|
|
||||||
*self.docker_compose_command,
|
|
||||||
"-f",
|
|
||||||
str(self.docker_compose_file),
|
|
||||||
"-f",
|
|
||||||
str(self.ngrok_path),
|
|
||||||
"down",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def logs(self) -> None:
|
|
||||||
"""Print the logs from the LangChainPlus server."""
|
|
||||||
subprocess.run(
|
|
||||||
[
|
|
||||||
*self.docker_compose_command,
|
|
||||||
"-f",
|
|
||||||
str(self.docker_compose_file),
|
|
||||||
"-f",
|
|
||||||
str(self.ngrok_path),
|
|
||||||
"logs",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def status(self) -> None:
|
|
||||||
"""Provide information about the status LangChainPlus server."""
|
|
||||||
|
|
||||||
command = [
|
|
||||||
*self.docker_compose_command,
|
|
||||||
"-f",
|
|
||||||
str(self.docker_compose_file),
|
|
||||||
"ps",
|
|
||||||
"--format",
|
|
||||||
"json",
|
|
||||||
]
|
|
||||||
|
|
||||||
result = subprocess.run(
|
|
||||||
command,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.PIPE,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
command_stdout = result.stdout.decode("utf-8")
|
|
||||||
services_status = json.loads(command_stdout)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.error("Error checking LangChainPlus server status.")
|
|
||||||
return
|
|
||||||
if services_status:
|
|
||||||
logger.info("The LangChainPlus server is currently running.")
|
|
||||||
pprint_services(services_status)
|
|
||||||
else:
|
|
||||||
logger.info("The LangChainPlus server is not running.")
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def env() -> None:
|
|
||||||
"""Print the runtime environment information."""
|
|
||||||
env = get_runtime_environment()
|
|
||||||
logger.info("LangChain Environment:")
|
|
||||||
logger.info("\n".join(f"{k}:{v}" for k, v in env.items()))
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
"""Main entrypoint for the CLI."""
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
subparsers = parser.add_subparsers(description="LangChainPlus CLI commands")
|
|
||||||
|
|
||||||
server_command = PlusCommand()
|
|
||||||
server_parser = subparsers.add_parser("plus", description=server_command.__doc__)
|
|
||||||
server_subparsers = server_parser.add_subparsers()
|
|
||||||
|
|
||||||
server_start_parser = server_subparsers.add_parser(
|
|
||||||
"start", description="Start the LangChainPlus server."
|
|
||||||
)
|
|
||||||
server_start_parser.add_argument(
|
|
||||||
"--expose",
|
|
||||||
action="store_true",
|
|
||||||
help="Expose the server to the internet using ngrok.",
|
|
||||||
)
|
|
||||||
server_start_parser.add_argument(
|
|
||||||
"--ngrok-authtoken",
|
|
||||||
default=os.getenv("NGROK_AUTHTOKEN"),
|
|
||||||
help="The ngrok authtoken to use (visible in the ngrok dashboard)."
|
|
||||||
" If not provided, ngrok server session length will be restricted.",
|
|
||||||
)
|
|
||||||
server_start_parser.add_argument(
|
|
||||||
"--dev",
|
|
||||||
action="store_true",
|
|
||||||
help="Use the development version of the LangChainPlus image.",
|
|
||||||
)
|
|
||||||
server_start_parser.add_argument(
|
|
||||||
"--openai-api-key",
|
|
||||||
default=os.getenv("OPENAI_API_KEY"),
|
|
||||||
help="The OpenAI API key to use for LangChainPlus."
|
|
||||||
" If not provided, the OpenAI API Key will be read from the"
|
|
||||||
" OPENAI_API_KEY environment variable. If neither are provided,"
|
|
||||||
" some features of LangChainPlus will not be available.",
|
|
||||||
)
|
|
||||||
server_start_parser.set_defaults(
|
|
||||||
func=lambda args: server_command.start(
|
|
||||||
expose=args.expose,
|
|
||||||
auth_token=args.ngrok_authtoken,
|
|
||||||
dev=args.dev,
|
|
||||||
openai_api_key=args.openai_api_key,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
server_stop_parser = server_subparsers.add_parser(
|
|
||||||
"stop", description="Stop the LangChainPlus server."
|
|
||||||
)
|
|
||||||
server_stop_parser.set_defaults(func=lambda args: server_command.stop())
|
|
||||||
|
|
||||||
server_logs_parser = server_subparsers.add_parser(
|
|
||||||
"logs", description="Show the LangChainPlus server logs."
|
|
||||||
)
|
|
||||||
server_logs_parser.set_defaults(func=lambda args: server_command.logs())
|
|
||||||
server_status_parser = server_subparsers.add_parser(
|
|
||||||
"status", description="Show the LangChainPlus server status."
|
|
||||||
)
|
|
||||||
server_status_parser.set_defaults(func=lambda args: server_command.status())
|
|
||||||
env_parser = subparsers.add_parser("env")
|
|
||||||
env_parser.set_defaults(func=lambda args: env())
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
if not hasattr(args, "func"):
|
|
||||||
parser.print_help()
|
|
||||||
return
|
|
||||||
args.func(args)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -1,6 +1,9 @@
|
|||||||
"""LangChain+ Client."""
|
"""LangChain+ Client."""
|
||||||
|
from langchain.client.runner_utils import (
|
||||||
|
arun_on_dataset,
|
||||||
|
arun_on_examples,
|
||||||
|
run_on_dataset,
|
||||||
|
run_on_examples,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["arun_on_dataset", "run_on_dataset", "arun_on_examples", "run_on_examples"]
|
||||||
from langchain.client.langchain import LangChainPlusClient
|
|
||||||
|
|
||||||
__all__ = ["LangChainPlusClient"]
|
|
||||||
|
@ -1,562 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import socket
|
|
||||||
from datetime import datetime
|
|
||||||
from io import BytesIO
|
|
||||||
from typing import (
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
Iterator,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
from urllib.parse import urlsplit
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from pydantic import BaseSettings, Field, root_validator
|
|
||||||
from requests import Response
|
|
||||||
from tenacity import retry, stop_after_attempt, wait_fixed
|
|
||||||
|
|
||||||
from langchain.base_language import BaseLanguageModel
|
|
||||||
from langchain.callbacks.tracers.schemas import Run as TracerRun
|
|
||||||
from langchain.callbacks.tracers.schemas import TracerSession
|
|
||||||
from langchain.chains.base import Chain
|
|
||||||
from langchain.client.models import (
|
|
||||||
APIFeedbackSource,
|
|
||||||
Dataset,
|
|
||||||
DatasetCreate,
|
|
||||||
Example,
|
|
||||||
ExampleCreate,
|
|
||||||
ExampleUpdate,
|
|
||||||
Feedback,
|
|
||||||
FeedbackCreate,
|
|
||||||
FeedbackSourceBase,
|
|
||||||
FeedbackSourceType,
|
|
||||||
ListFeedbackQueryParams,
|
|
||||||
ListRunsQueryParams,
|
|
||||||
ModelFeedbackSource,
|
|
||||||
)
|
|
||||||
from langchain.client.runner_utils import arun_on_examples, run_on_examples
|
|
||||||
from langchain.utils import raise_for_status_with_text, xor_args
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
|
|
||||||
|
|
||||||
|
|
||||||
class Run(TracerRun):
|
|
||||||
id: UUID
|
|
||||||
|
|
||||||
|
|
||||||
def _get_link_stem(url: str) -> str:
|
|
||||||
scheme = urlsplit(url).scheme
|
|
||||||
netloc_prefix = urlsplit(url).netloc.split(":")[0]
|
|
||||||
return f"{scheme}://{netloc_prefix}"
|
|
||||||
|
|
||||||
|
|
||||||
def _is_localhost(url: str) -> bool:
|
|
||||||
"""Check if the URL is localhost."""
|
|
||||||
try:
|
|
||||||
netloc = urlsplit(url).netloc.split(":")[0]
|
|
||||||
ip = socket.gethostbyname(netloc)
|
|
||||||
return ip == "127.0.0.1" or ip.startswith("0.0.0.0") or ip.startswith("::")
|
|
||||||
except socket.gaierror:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class LangChainPlusClient(BaseSettings):
|
|
||||||
"""Client for interacting with the LangChain+ API."""
|
|
||||||
|
|
||||||
api_key: Optional[str] = Field(default=None, env="LANGCHAIN_API_KEY")
|
|
||||||
api_url: str = Field(default="http://localhost:1984", env="LANGCHAIN_ENDPOINT")
|
|
||||||
|
|
||||||
@root_validator(pre=True)
|
|
||||||
def validate_api_key_if_hosted(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""Verify API key is provided if url not localhost."""
|
|
||||||
api_url: str = values.get("api_url", "http://localhost:1984")
|
|
||||||
api_key: Optional[str] = values.get("api_key")
|
|
||||||
if not _is_localhost(api_url):
|
|
||||||
if not api_key:
|
|
||||||
raise ValueError(
|
|
||||||
"API key must be provided when using hosted LangChain+ API"
|
|
||||||
)
|
|
||||||
return values
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_session_name(
|
|
||||||
session_name: Optional[str],
|
|
||||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
|
||||||
dataset_name: str,
|
|
||||||
) -> str:
|
|
||||||
if session_name is not None:
|
|
||||||
return session_name
|
|
||||||
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
|
||||||
if isinstance(llm_or_chain_factory, BaseLanguageModel):
|
|
||||||
model_name = llm_or_chain_factory.__class__.__name__
|
|
||||||
else:
|
|
||||||
model_name = llm_or_chain_factory().__class__.__name__
|
|
||||||
return f"{dataset_name}-{model_name}-{current_time}"
|
|
||||||
|
|
||||||
def _repr_html_(self) -> str:
|
|
||||||
"""Return an HTML representation of the instance with a link to the URL."""
|
|
||||||
if _is_localhost(self.api_url):
|
|
||||||
link = "http://localhost"
|
|
||||||
elif "dev" in self.api_url:
|
|
||||||
link = "https://dev.langchain.plus"
|
|
||||||
else:
|
|
||||||
link = "https://www.langchain.plus"
|
|
||||||
return f'<a href="{link}", target="_blank" rel="noopener">LangChain+ Client</a>'
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
"""Return a string representation of the instance with a link to the URL."""
|
|
||||||
return f"LangChainPlusClient (API URL: {self.api_url})"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _headers(self) -> Dict[str, str]:
|
|
||||||
"""Get the headers for the API request."""
|
|
||||||
headers = {}
|
|
||||||
if self.api_key:
|
|
||||||
headers["x-api-key"] = self.api_key
|
|
||||||
return headers
|
|
||||||
|
|
||||||
def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Response:
|
|
||||||
"""Make a GET request."""
|
|
||||||
return requests.get(
|
|
||||||
f"{self.api_url}{path}", headers=self._headers, params=params
|
|
||||||
)
|
|
||||||
|
|
||||||
def upload_dataframe(
|
|
||||||
self,
|
|
||||||
df: pd.DataFrame,
|
|
||||||
name: str,
|
|
||||||
description: str,
|
|
||||||
input_keys: Sequence[str],
|
|
||||||
output_keys: Sequence[str],
|
|
||||||
) -> Dataset:
|
|
||||||
"""Upload a dataframe as individual examples to the LangChain+ API."""
|
|
||||||
dataset = self.create_dataset(dataset_name=name, description=description)
|
|
||||||
for row in df.itertuples():
|
|
||||||
inputs = {key: getattr(row, key) for key in input_keys}
|
|
||||||
outputs = {key: getattr(row, key) for key in output_keys}
|
|
||||||
self.create_example(inputs, outputs=outputs, dataset_id=dataset.id)
|
|
||||||
return dataset
|
|
||||||
|
|
||||||
def upload_csv(
|
|
||||||
self,
|
|
||||||
csv_file: Union[str, Tuple[str, BytesIO]],
|
|
||||||
description: str,
|
|
||||||
input_keys: Sequence[str],
|
|
||||||
output_keys: Sequence[str],
|
|
||||||
) -> Dataset:
|
|
||||||
"""Upload a CSV file to the LangChain+ API."""
|
|
||||||
files = {"file": csv_file}
|
|
||||||
data = {
|
|
||||||
"input_keys": ",".join(input_keys),
|
|
||||||
"output_keys": ",".join(output_keys),
|
|
||||||
"description": description,
|
|
||||||
}
|
|
||||||
response = requests.post(
|
|
||||||
self.api_url + "/datasets/upload",
|
|
||||||
headers=self._headers,
|
|
||||||
data=data,
|
|
||||||
files=files,
|
|
||||||
)
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
result = response.json()
|
|
||||||
# TODO: Make this more robust server-side
|
|
||||||
if "detail" in result and "already exists" in result["detail"]:
|
|
||||||
file_name = csv_file if isinstance(csv_file, str) else csv_file[0]
|
|
||||||
file_name = file_name.split("/")[-1]
|
|
||||||
raise ValueError(f"Dataset {file_name} already exists")
|
|
||||||
return Dataset(**result)
|
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
|
||||||
def read_run(self, run_id: Union[str, UUID]) -> Run:
|
|
||||||
"""Read a run from the LangChain+ API."""
|
|
||||||
response = self._get(f"/runs/{run_id}")
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
return Run(**response.json())
|
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
|
||||||
def list_runs(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
session_id: Optional[str] = None,
|
|
||||||
session_name: Optional[str] = None,
|
|
||||||
run_type: Optional[str] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Iterator[Run]:
|
|
||||||
"""List runs from the LangChain+ API."""
|
|
||||||
if session_name is not None:
|
|
||||||
if session_id is not None:
|
|
||||||
raise ValueError("Only one of session_id or session_name may be given")
|
|
||||||
session_id = self.read_session(session_name=session_name).id
|
|
||||||
query_params = ListRunsQueryParams(
|
|
||||||
session_id=session_id, run_type=run_type, **kwargs
|
|
||||||
)
|
|
||||||
response = self._get("/runs", params=query_params.dict(exclude_none=True))
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
yield from [Run(**run) for run in response.json()]
|
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
|
||||||
@xor_args(("session_id", "session_name"))
|
|
||||||
def read_session(
|
|
||||||
self, *, session_id: Optional[str] = None, session_name: Optional[str] = None
|
|
||||||
) -> TracerSession:
|
|
||||||
"""Read a session from the LangChain+ API."""
|
|
||||||
path = "/sessions"
|
|
||||||
params: Dict[str, Any] = {"limit": 1}
|
|
||||||
if session_id is not None:
|
|
||||||
path += f"/{session_id}"
|
|
||||||
elif session_name is not None:
|
|
||||||
params["name"] = session_name
|
|
||||||
else:
|
|
||||||
raise ValueError("Must provide dataset_name or dataset_id")
|
|
||||||
response = self._get(
|
|
||||||
path,
|
|
||||||
params=params,
|
|
||||||
)
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
result = response.json()
|
|
||||||
if isinstance(result, list):
|
|
||||||
if len(result) == 0:
|
|
||||||
raise ValueError(f"Dataset {session_name} not found")
|
|
||||||
return TracerSession(**result[0])
|
|
||||||
return TracerSession(**response.json())
|
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
|
||||||
def list_sessions(self) -> Iterator[TracerSession]:
|
|
||||||
"""List sessions from the LangChain+ API."""
|
|
||||||
response = self._get("/sessions")
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
yield from [TracerSession(**session) for session in response.json()]
|
|
||||||
|
|
||||||
@xor_args(("session_name", "session_id"))
|
|
||||||
def delete_session(
|
|
||||||
self, *, session_name: Optional[str] = None, session_id: Optional[str] = None
|
|
||||||
) -> None:
|
|
||||||
"""Delete a session from the LangChain+ API."""
|
|
||||||
if session_name is not None:
|
|
||||||
session_id = self.read_session(session_name=session_name).id
|
|
||||||
elif session_id is None:
|
|
||||||
raise ValueError("Must provide session_name or session_id")
|
|
||||||
response = requests.delete(
|
|
||||||
self.api_url + f"/sessions/{session_id}",
|
|
||||||
headers=self._headers,
|
|
||||||
)
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def create_dataset(
|
|
||||||
self, dataset_name: str, *, description: Optional[str] = None
|
|
||||||
) -> Dataset:
|
|
||||||
"""Create a dataset in the LangChain+ API."""
|
|
||||||
dataset = DatasetCreate(
|
|
||||||
name=dataset_name,
|
|
||||||
description=description,
|
|
||||||
)
|
|
||||||
response = requests.post(
|
|
||||||
self.api_url + "/datasets",
|
|
||||||
headers=self._headers,
|
|
||||||
data=dataset.json(),
|
|
||||||
)
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
return Dataset(**response.json())
|
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
|
||||||
@xor_args(("dataset_name", "dataset_id"))
|
|
||||||
def read_dataset(
|
|
||||||
self, *, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None
|
|
||||||
) -> Dataset:
|
|
||||||
path = "/datasets"
|
|
||||||
params: Dict[str, Any] = {"limit": 1}
|
|
||||||
if dataset_id is not None:
|
|
||||||
path += f"/{dataset_id}"
|
|
||||||
elif dataset_name is not None:
|
|
||||||
params["name"] = dataset_name
|
|
||||||
else:
|
|
||||||
raise ValueError("Must provide dataset_name or dataset_id")
|
|
||||||
response = self._get(
|
|
||||||
path,
|
|
||||||
params=params,
|
|
||||||
)
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
result = response.json()
|
|
||||||
if isinstance(result, list):
|
|
||||||
if len(result) == 0:
|
|
||||||
raise ValueError(f"Dataset {dataset_name} not found")
|
|
||||||
return Dataset(**result[0])
|
|
||||||
return Dataset(**result)
|
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
|
||||||
def list_datasets(self, limit: int = 100) -> Iterator[Dataset]:
|
|
||||||
"""List the datasets on the LangChain+ API."""
|
|
||||||
response = self._get("/datasets", params={"limit": limit})
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
yield from [Dataset(**dataset) for dataset in response.json()]
|
|
||||||
|
|
||||||
@xor_args(("dataset_id", "dataset_name"))
|
|
||||||
def delete_dataset(
|
|
||||||
self, *, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None
|
|
||||||
) -> Dataset:
|
|
||||||
"""Delete a dataset by ID or name."""
|
|
||||||
if dataset_name is not None:
|
|
||||||
dataset_id = self.read_dataset(dataset_name=dataset_name).id
|
|
||||||
if dataset_id is None:
|
|
||||||
raise ValueError("Must provide either dataset name or ID")
|
|
||||||
response = requests.delete(
|
|
||||||
f"{self.api_url}/datasets/{dataset_id}",
|
|
||||||
headers=self._headers,
|
|
||||||
)
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
return Dataset(**response.json())
|
|
||||||
|
|
||||||
@xor_args(("dataset_id", "dataset_name"))
|
|
||||||
def create_example(
|
|
||||||
self,
|
|
||||||
inputs: Dict[str, Any],
|
|
||||||
dataset_id: Optional[UUID] = None,
|
|
||||||
dataset_name: Optional[str] = None,
|
|
||||||
created_at: Optional[datetime] = None,
|
|
||||||
outputs: Dict[str, Any] | None = None,
|
|
||||||
) -> Example:
|
|
||||||
"""Create a dataset example in the LangChain+ API."""
|
|
||||||
if dataset_id is None:
|
|
||||||
dataset_id = self.read_dataset(dataset_name=dataset_name).id
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"inputs": inputs,
|
|
||||||
"outputs": outputs,
|
|
||||||
"dataset_id": dataset_id,
|
|
||||||
}
|
|
||||||
if created_at:
|
|
||||||
data["created_at"] = created_at.isoformat()
|
|
||||||
example = ExampleCreate(**data)
|
|
||||||
response = requests.post(
|
|
||||||
f"{self.api_url}/examples", headers=self._headers, data=example.json()
|
|
||||||
)
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
result = response.json()
|
|
||||||
return Example(**result)
|
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
|
||||||
def read_example(self, example_id: Union[str, UUID]) -> Example:
|
|
||||||
"""Read an example from the LangChain+ API."""
|
|
||||||
response = self._get(f"/examples/{example_id}")
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
return Example(**response.json())
|
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
|
||||||
def list_examples(
|
|
||||||
self, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None
|
|
||||||
) -> Iterator[Example]:
|
|
||||||
"""List the datasets on the LangChain+ API."""
|
|
||||||
params = {}
|
|
||||||
if dataset_id is not None:
|
|
||||||
params["dataset"] = dataset_id
|
|
||||||
elif dataset_name is not None:
|
|
||||||
dataset_id = self.read_dataset(dataset_name=dataset_name).id
|
|
||||||
params["dataset"] = dataset_id
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
response = self._get("/examples", params=params)
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
yield from [Example(**dataset) for dataset in response.json()]
|
|
||||||
|
|
||||||
def update_example(
|
|
||||||
self,
|
|
||||||
example_id: str,
|
|
||||||
*,
|
|
||||||
inputs: Optional[Dict[str, Any]] = None,
|
|
||||||
outputs: Optional[Mapping[str, Any]] = None,
|
|
||||||
dataset_id: Optional[str] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Update a specific example."""
|
|
||||||
example = ExampleUpdate(
|
|
||||||
inputs=inputs,
|
|
||||||
outputs=outputs,
|
|
||||||
dataset_id=dataset_id,
|
|
||||||
)
|
|
||||||
response = requests.patch(
|
|
||||||
f"{self.api_url}/examples/{example_id}",
|
|
||||||
headers=self._headers,
|
|
||||||
data=example.json(exclude_none=True),
|
|
||||||
)
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
def create_feedback(
|
|
||||||
self,
|
|
||||||
run_id: str,
|
|
||||||
key: str,
|
|
||||||
*,
|
|
||||||
score: Union[float, int, bool, None] = None,
|
|
||||||
value: Union[float, int, bool, str, dict, None] = None,
|
|
||||||
correction: Union[str, dict, None] = None,
|
|
||||||
comment: Union[str, None] = None,
|
|
||||||
source_info: Optional[Dict[str, Any]] = None,
|
|
||||||
feedback_source_type: FeedbackSourceType = FeedbackSourceType.API,
|
|
||||||
) -> Feedback:
|
|
||||||
"""Create a feedback in the LangChain+ API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
run_id: The ID of the run to provide feedback on.
|
|
||||||
key: The name of the metric, tag, or 'aspect' this
|
|
||||||
feedback is about.
|
|
||||||
score: The score to rate this run on the metric
|
|
||||||
or aspect.
|
|
||||||
value: The display value or non-numeric value for this feedback.
|
|
||||||
correction: The proper ground truth for this run.
|
|
||||||
comment: A comment about this feedback.
|
|
||||||
source_info: Information about the source of this feedback.
|
|
||||||
feedback_source_type: The type of feedback source.
|
|
||||||
"""
|
|
||||||
if feedback_source_type == FeedbackSourceType.API:
|
|
||||||
feedback_source: FeedbackSourceBase = APIFeedbackSource(
|
|
||||||
metadata=source_info
|
|
||||||
)
|
|
||||||
elif feedback_source_type == FeedbackSourceType.MODEL:
|
|
||||||
feedback_source = ModelFeedbackSource(metadata=source_info)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown feedback source type {feedback_source_type}")
|
|
||||||
feedback = FeedbackCreate(
|
|
||||||
run_id=run_id,
|
|
||||||
key=key,
|
|
||||||
score=score,
|
|
||||||
value=value,
|
|
||||||
correction=correction,
|
|
||||||
comment=comment,
|
|
||||||
feedback_source=feedback_source,
|
|
||||||
)
|
|
||||||
response = requests.post(
|
|
||||||
self.api_url + "/feedback",
|
|
||||||
headers=self._headers,
|
|
||||||
data=feedback.json(),
|
|
||||||
)
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
return Feedback(**feedback.dict())
|
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
|
||||||
def read_feedback(self, feedback_id: str) -> Feedback:
|
|
||||||
"""Read a feedback from the LangChain+ API."""
|
|
||||||
response = self._get(f"/feedback/{feedback_id}")
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
return Feedback(**response.json())
|
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
|
||||||
def list_feedback(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
run_ids: Optional[Sequence[Union[str, UUID]]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Iterator[Feedback]:
|
|
||||||
"""List the feedback objects on the LangChain+ API."""
|
|
||||||
params = ListFeedbackQueryParams(
|
|
||||||
run=run_ids,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
response = self._get("/feedback", params=params.dict(exclude_none=True))
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
yield from [Feedback(**feedback) for feedback in response.json()]
|
|
||||||
|
|
||||||
def delete_feedback(self, feedback_id: str) -> None:
|
|
||||||
"""Delete a feedback by ID."""
|
|
||||||
response = requests.delete(
|
|
||||||
f"{self.api_url}/feedback/{feedback_id}",
|
|
||||||
headers=self._headers,
|
|
||||||
)
|
|
||||||
raise_for_status_with_text(response)
|
|
||||||
|
|
||||||
async def arun_on_dataset(
|
|
||||||
self,
|
|
||||||
dataset_name: str,
|
|
||||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
|
||||||
*,
|
|
||||||
concurrency_level: int = 5,
|
|
||||||
num_repetitions: int = 1,
|
|
||||||
session_name: Optional[str] = None,
|
|
||||||
verbose: bool = False,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Run the chain on a dataset and store traces to the specified session name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset_name: Name of the dataset to run the chain on.
|
|
||||||
llm_or_chain_factory: Language model or Chain constructor to run
|
|
||||||
over the dataset. The Chain constructor is used to permit
|
|
||||||
independent calls on each example without carrying over state.
|
|
||||||
concurrency_level: The number of async tasks to run concurrently.
|
|
||||||
num_repetitions: Number of times to run the model on each example.
|
|
||||||
This is useful when testing success rates or generating confidence
|
|
||||||
intervals.
|
|
||||||
session_name: Name of the session to store the traces in.
|
|
||||||
Defaults to {dataset_name}-{chain class name}-{datetime}.
|
|
||||||
verbose: Whether to print progress.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary mapping example ids to the model outputs.
|
|
||||||
"""
|
|
||||||
session_name = LangChainPlusClient._get_session_name(
|
|
||||||
session_name, llm_or_chain_factory, dataset_name
|
|
||||||
)
|
|
||||||
dataset = self.read_dataset(dataset_name=dataset_name)
|
|
||||||
examples = self.list_examples(dataset_id=str(dataset.id))
|
|
||||||
|
|
||||||
return await arun_on_examples(
|
|
||||||
examples,
|
|
||||||
llm_or_chain_factory,
|
|
||||||
concurrency_level=concurrency_level,
|
|
||||||
num_repetitions=num_repetitions,
|
|
||||||
session_name=session_name,
|
|
||||||
verbose=verbose,
|
|
||||||
)
|
|
||||||
|
|
||||||
def run_on_dataset(
|
|
||||||
self,
|
|
||||||
dataset_name: str,
|
|
||||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
|
||||||
*,
|
|
||||||
num_repetitions: int = 1,
|
|
||||||
session_name: Optional[str] = None,
|
|
||||||
verbose: bool = False,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Run the chain on a dataset and store traces to the specified session name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset_name: Name of the dataset to run the chain on.
|
|
||||||
llm_or_chain_factory: Language model or Chain constructor to run
|
|
||||||
over the dataset. The Chain constructor is used to permit
|
|
||||||
independent calls on each example without carrying over state.
|
|
||||||
concurrency_level: Number of async workers to run in parallel.
|
|
||||||
num_repetitions: Number of times to run the model on each example.
|
|
||||||
This is useful when testing success rates or generating confidence
|
|
||||||
intervals.
|
|
||||||
session_name: Name of the session to store the traces in.
|
|
||||||
Defaults to {dataset_name}-{chain class name}-{datetime}.
|
|
||||||
verbose: Whether to print progress.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary mapping example ids to the model outputs.
|
|
||||||
"""
|
|
||||||
session_name = LangChainPlusClient._get_session_name(
|
|
||||||
session_name, llm_or_chain_factory, dataset_name
|
|
||||||
)
|
|
||||||
dataset = self.read_dataset(dataset_name=dataset_name)
|
|
||||||
examples = self.list_examples(dataset_id=str(dataset.id))
|
|
||||||
return run_on_examples(
|
|
||||||
examples,
|
|
||||||
llm_or_chain_factory,
|
|
||||||
num_repetitions=num_repetitions,
|
|
||||||
session_name=session_name,
|
|
||||||
verbose=verbose,
|
|
||||||
)
|
|
@ -1,206 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, ClassVar, Dict, List, Mapping, Optional, Sequence, Union
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, root_validator
|
|
||||||
|
|
||||||
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
|
|
||||||
|
|
||||||
|
|
||||||
class ExampleBase(BaseModel):
|
|
||||||
"""Example base model."""
|
|
||||||
|
|
||||||
dataset_id: UUID
|
|
||||||
inputs: Dict[str, Any]
|
|
||||||
outputs: Optional[Dict[str, Any]] = Field(default=None)
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
frozen = True
|
|
||||||
|
|
||||||
|
|
||||||
class ExampleCreate(ExampleBase):
|
|
||||||
"""Example create model."""
|
|
||||||
|
|
||||||
id: Optional[UUID]
|
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
|
||||||
|
|
||||||
|
|
||||||
class Example(ExampleBase):
|
|
||||||
"""Example model."""
|
|
||||||
|
|
||||||
id: UUID
|
|
||||||
created_at: datetime
|
|
||||||
modified_at: Optional[datetime] = Field(default=None)
|
|
||||||
runs: List[Run] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class ExampleUpdate(BaseModel):
|
|
||||||
"""Update class for Example."""
|
|
||||||
|
|
||||||
dataset_id: Optional[UUID] = None
|
|
||||||
inputs: Optional[Dict[str, Any]] = None
|
|
||||||
outputs: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
frozen = True
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetBase(BaseModel):
|
|
||||||
"""Dataset base model."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
description: Optional[str] = None
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
frozen = True
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetCreate(DatasetBase):
|
|
||||||
"""Dataset create model."""
|
|
||||||
|
|
||||||
id: Optional[UUID]
|
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
|
||||||
|
|
||||||
|
|
||||||
class Dataset(DatasetBase):
|
|
||||||
"""Dataset ORM model."""
|
|
||||||
|
|
||||||
id: UUID
|
|
||||||
tenant_id: UUID
|
|
||||||
created_at: datetime
|
|
||||||
modified_at: Optional[datetime] = Field(default=None)
|
|
||||||
|
|
||||||
|
|
||||||
class ListRunsQueryParams(BaseModel):
|
|
||||||
"""Query params for GET /runs endpoint."""
|
|
||||||
|
|
||||||
id: Optional[List[UUID]]
|
|
||||||
"""Filter runs by id."""
|
|
||||||
parent_run: Optional[UUID]
|
|
||||||
"""Filter runs by parent run."""
|
|
||||||
run_type: Optional[RunTypeEnum]
|
|
||||||
"""Filter runs by type."""
|
|
||||||
session: Optional[UUID] = Field(default=None, alias="session_id")
|
|
||||||
"""Only return runs within a session."""
|
|
||||||
reference_example: Optional[UUID]
|
|
||||||
"""Only return runs that reference the specified dataset example."""
|
|
||||||
execution_order: Optional[int]
|
|
||||||
"""Filter runs by execution order."""
|
|
||||||
error: Optional[bool]
|
|
||||||
"""Whether to return only runs that errored."""
|
|
||||||
offset: Optional[int]
|
|
||||||
"""The offset of the first run to return."""
|
|
||||||
limit: Optional[int]
|
|
||||||
"""The maximum number of runs to return."""
|
|
||||||
start_time: Optional[datetime] = Field(
|
|
||||||
default=None,
|
|
||||||
alias="start_before",
|
|
||||||
description="Query Runs that started <= this time",
|
|
||||||
)
|
|
||||||
end_time: Optional[datetime] = Field(
|
|
||||||
default=None,
|
|
||||||
alias="end_after",
|
|
||||||
description="Query Runs that ended >= this time",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
extra = "forbid"
|
|
||||||
frozen = True
|
|
||||||
|
|
||||||
@root_validator(allow_reuse=True)
|
|
||||||
def validate_time_range(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""Validate that start_time <= end_time."""
|
|
||||||
start_time = values.get("start_time")
|
|
||||||
end_time = values.get("end_time")
|
|
||||||
if start_time and end_time and start_time > end_time:
|
|
||||||
raise ValueError("start_time must be <= end_time")
|
|
||||||
return values
|
|
||||||
|
|
||||||
|
|
||||||
class FeedbackSourceBase(BaseModel):
|
|
||||||
type: ClassVar[str]
|
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
frozen = True
|
|
||||||
|
|
||||||
|
|
||||||
class APIFeedbackSource(FeedbackSourceBase):
|
|
||||||
"""API feedback source."""
|
|
||||||
|
|
||||||
type: ClassVar[str] = "api"
|
|
||||||
|
|
||||||
|
|
||||||
class ModelFeedbackSource(FeedbackSourceBase):
|
|
||||||
"""Model feedback source."""
|
|
||||||
|
|
||||||
type: ClassVar[str] = "model"
|
|
||||||
|
|
||||||
|
|
||||||
class FeedbackSourceType(Enum):
|
|
||||||
"""Feedback source type."""
|
|
||||||
|
|
||||||
API = "api"
|
|
||||||
"""General feedback submitted from the API."""
|
|
||||||
MODEL = "model"
|
|
||||||
"""Model-assisted feedback."""
|
|
||||||
|
|
||||||
|
|
||||||
class FeedbackBase(BaseModel):
|
|
||||||
"""Feedback schema."""
|
|
||||||
|
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
|
||||||
"""The time the feedback was created."""
|
|
||||||
modified_at: datetime = Field(default_factory=datetime.utcnow)
|
|
||||||
"""The time the feedback was last modified."""
|
|
||||||
run_id: UUID
|
|
||||||
"""The associated run ID this feedback is logged for."""
|
|
||||||
key: str
|
|
||||||
"""The metric name, tag, or aspect to provide feedback on."""
|
|
||||||
score: Union[float, int, bool, None] = None
|
|
||||||
"""Value or score to assign the run."""
|
|
||||||
value: Union[float, int, bool, str, dict, None] = None
|
|
||||||
"""The display value, tag or other value for the feedback if not a metric."""
|
|
||||||
comment: Optional[str] = None
|
|
||||||
"""Comment or explanation for the feedback."""
|
|
||||||
correction: Union[str, dict, None] = None
|
|
||||||
"""Correction for the run."""
|
|
||||||
feedback_source: Optional[
|
|
||||||
Union[APIFeedbackSource, ModelFeedbackSource, Mapping[str, Any]]
|
|
||||||
] = None
|
|
||||||
"""The source of the feedback."""
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
frozen = True
|
|
||||||
|
|
||||||
|
|
||||||
class FeedbackCreate(FeedbackBase):
|
|
||||||
"""Schema used for creating feedback."""
|
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4)
|
|
||||||
|
|
||||||
feedback_source: APIFeedbackSource
|
|
||||||
"""The source of the feedback."""
|
|
||||||
|
|
||||||
|
|
||||||
class Feedback(FeedbackBase):
|
|
||||||
"""Schema for getting feedback."""
|
|
||||||
|
|
||||||
id: UUID
|
|
||||||
feedback_source: Optional[Dict] = None
|
|
||||||
"""The source of the feedback. In this case"""
|
|
||||||
|
|
||||||
|
|
||||||
class ListFeedbackQueryParams(BaseModel):
|
|
||||||
"""Query Params for listing feedbacks."""
|
|
||||||
|
|
||||||
run: Optional[Sequence[UUID]] = None
|
|
||||||
limit: int = 100
|
|
||||||
offset: int = 0
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Config for query params."""
|
|
||||||
|
|
||||||
extra = "forbid"
|
|
||||||
frozen = True
|
|
@ -4,15 +4,18 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
from typing import Any, Callable, Coroutine, Dict, Iterator, List, Optional, Union
|
from typing import Any, Callable, Coroutine, Dict, Iterator, List, Optional, Union
|
||||||
|
|
||||||
|
from langchainplus_sdk import LangChainPlusClient
|
||||||
|
from langchainplus_sdk.schemas import Example
|
||||||
|
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.callbacks.manager import Callbacks
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chat_models.base import BaseChatModel
|
from langchain.chat_models.base import BaseChatModel
|
||||||
from langchain.client.models import Example
|
|
||||||
from langchain.llms.base import BaseLLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
@ -372,3 +375,107 @@ def run_on_examples(
|
|||||||
print(f"{i+1} processed", flush=True, end="\r")
|
print(f"{i+1} processed", flush=True, end="\r")
|
||||||
results[str(example.id)] = result
|
results[str(example.id)] = result
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _get_session_name(
|
||||||
|
session_name: Optional[str],
|
||||||
|
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||||
|
dataset_name: str,
|
||||||
|
) -> str:
|
||||||
|
if session_name is not None:
|
||||||
|
return session_name
|
||||||
|
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
||||||
|
if isinstance(llm_or_chain_factory, BaseLanguageModel):
|
||||||
|
model_name = llm_or_chain_factory.__class__.__name__
|
||||||
|
else:
|
||||||
|
model_name = llm_or_chain_factory().__class__.__name__
|
||||||
|
return f"{dataset_name}-{model_name}-{current_time}"
|
||||||
|
|
||||||
|
|
||||||
|
async def arun_on_dataset(
|
||||||
|
dataset_name: str,
|
||||||
|
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||||
|
*,
|
||||||
|
concurrency_level: int = 5,
|
||||||
|
num_repetitions: int = 1,
|
||||||
|
session_name: Optional[str] = None,
|
||||||
|
verbose: bool = False,
|
||||||
|
client: Optional[LangChainPlusClient] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Run the chain on a dataset and store traces to the specified session name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: Client to use to read the dataset.
|
||||||
|
dataset_name: Name of the dataset to run the chain on.
|
||||||
|
llm_or_chain_factory: Language model or Chain constructor to run
|
||||||
|
over the dataset. The Chain constructor is used to permit
|
||||||
|
independent calls on each example without carrying over state.
|
||||||
|
concurrency_level: The number of async tasks to run concurrently.
|
||||||
|
num_repetitions: Number of times to run the model on each example.
|
||||||
|
This is useful when testing success rates or generating confidence
|
||||||
|
intervals.
|
||||||
|
session_name: Name of the session to store the traces in.
|
||||||
|
Defaults to {dataset_name}-{chain class name}-{datetime}.
|
||||||
|
verbose: Whether to print progress.
|
||||||
|
client: Client to use to read the dataset. If not provided, a new
|
||||||
|
client will be created using the credentials in the environment.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary mapping example ids to the model outputs.
|
||||||
|
"""
|
||||||
|
client_ = client or LangChainPlusClient()
|
||||||
|
session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name)
|
||||||
|
dataset = client_.read_dataset(dataset_name=dataset_name)
|
||||||
|
examples = client_.list_examples(dataset_id=str(dataset.id))
|
||||||
|
|
||||||
|
return await arun_on_examples(
|
||||||
|
examples,
|
||||||
|
llm_or_chain_factory,
|
||||||
|
concurrency_level=concurrency_level,
|
||||||
|
num_repetitions=num_repetitions,
|
||||||
|
session_name=session_name,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_on_dataset(
|
||||||
|
dataset_name: str,
|
||||||
|
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||||
|
*,
|
||||||
|
num_repetitions: int = 1,
|
||||||
|
session_name: Optional[str] = None,
|
||||||
|
verbose: bool = False,
|
||||||
|
client: Optional[LangChainPlusClient] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Run the chain on a dataset and store traces to the specified session name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_name: Name of the dataset to run the chain on.
|
||||||
|
llm_or_chain_factory: Language model or Chain constructor to run
|
||||||
|
over the dataset. The Chain constructor is used to permit
|
||||||
|
independent calls on each example without carrying over state.
|
||||||
|
concurrency_level: Number of async workers to run in parallel.
|
||||||
|
num_repetitions: Number of times to run the model on each example.
|
||||||
|
This is useful when testing success rates or generating confidence
|
||||||
|
intervals.
|
||||||
|
session_name: Name of the session to store the traces in.
|
||||||
|
Defaults to {dataset_name}-{chain class name}-{datetime}.
|
||||||
|
verbose: Whether to print progress.
|
||||||
|
client: Client to use to access the dataset. If None, a new client
|
||||||
|
will be created using the credentials in the environment.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary mapping example ids to the model outputs.
|
||||||
|
"""
|
||||||
|
client_ = client or LangChainPlusClient()
|
||||||
|
session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name)
|
||||||
|
dataset = client_.read_dataset(dataset_name=dataset_name)
|
||||||
|
examples = client_.list_examples(dataset_id=str(dataset.id))
|
||||||
|
return run_on_examples(
|
||||||
|
examples,
|
||||||
|
llm_or_chain_factory,
|
||||||
|
num_repetitions=num_repetitions,
|
||||||
|
session_name=session_name,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
@ -99,7 +99,8 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"from langchain.client import LangChainPlusClient\n",
|
"from langchainplus_sdk import LangChainPlusClient\n",
|
||||||
|
"from langchain.client import arun_on_dataset, run_on_dataset\n",
|
||||||
"\n",
|
"\n",
|
||||||
"os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
|
"os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
|
||||||
"os.environ[\"LANGCHAIN_SESSION\"] = \"Tracing Walkthrough\"\n",
|
"os.environ[\"LANGCHAIN_SESSION\"] = \"Tracing Walkthrough\"\n",
|
||||||
@ -125,8 +126,10 @@
|
|||||||
"from langchain.agents import AgentType\n",
|
"from langchain.agents import AgentType\n",
|
||||||
"\n",
|
"\n",
|
||||||
"llm = ChatOpenAI(temperature=0)\n",
|
"llm = ChatOpenAI(temperature=0)\n",
|
||||||
"tools = load_tools(['serpapi', 'llm-math'], llm=llm)\n",
|
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)\n",
|
||||||
"agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)"
|
"agent = initialize_agent(\n",
|
||||||
|
" tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False\n",
|
||||||
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -142,8 +145,9 @@
|
|||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"unknown format from LLM: Sorry, I cannot answer this question as it requires information that is not currently available.\n",
|
"unknown format from LLM: Sorry, I cannot answer this question as it requires information that is not currently available.\n",
|
||||||
"unknown format from LLM: Sorry, as an AI language model, I do not have access to personal information such as someone's age. Please provide a different math problem.\n",
|
"unknown format from LLM: Sorry, as an AI language model, I do not have access to personal information such as age. Please provide a valid math problem.\n",
|
||||||
"unknown format from LLM: As an AI language model, I do not have information on future events such as the 2023 super bowl. Therefore, I cannot provide a solution to this question.\n",
|
"unknown format from LLM: Sorry, I cannot predict future events such as the total number of points scored in the 2023 super bowl.\n",
|
||||||
|
"This model's maximum context length is 4097 tokens. However, your messages resulted in 4097 tokens. Please reduce the length of the messages.\n",
|
||||||
"unknown format from LLM: This is not a math problem and cannot be translated into a mathematical expression.\n"
|
"unknown format from LLM: This is not a math problem and cannot be translated into a mathematical expression.\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -151,12 +155,12 @@
|
|||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"['The population of Canada as of 2023 is estimated to be 39,566,248.',\n",
|
"['The population of Canada as of 2023 is estimated to be 39,566,248.',\n",
|
||||||
" \"Anwar Hadid's age raised to the 0.43 power is approximately 3.87.\",\n",
|
" \"Anwar Hadid is Dua Lipa's boyfriend and his age raised to the 0.43 power is approximately 3.87.\",\n",
|
||||||
" ValueError(\"unknown format from LLM: Sorry, as an AI language model, I do not have access to personal information such as someone's age. Please provide a different math problem.\"),\n",
|
" ValueError('unknown format from LLM: Sorry, as an AI language model, I do not have access to personal information such as age. Please provide a valid math problem.'),\n",
|
||||||
" 'The distance between Paris and Boston is 3448 miles.',\n",
|
" 'The distance between Paris and Boston is 3448 miles.',\n",
|
||||||
" ValueError('unknown format from LLM: Sorry, I cannot answer this question as it requires information that is not currently available.'),\n",
|
" ValueError('unknown format from LLM: Sorry, I cannot answer this question as it requires information that is not currently available.'),\n",
|
||||||
" ValueError('unknown format from LLM: As an AI language model, I do not have information on future events such as the 2023 super bowl. Therefore, I cannot provide a solution to this question.'),\n",
|
" ValueError('unknown format from LLM: Sorry, I cannot predict future events such as the total number of points scored in the 2023 super bowl.'),\n",
|
||||||
" '15 points were scored more in the 2023 Super Bowl than in the 2022 Super Bowl.',\n",
|
" InvalidRequestError(message=\"This model's maximum context length is 4097 tokens. However, your messages resulted in 4097 tokens. Please reduce the length of the messages.\", param='messages', code='context_length_exceeded', http_status=400, request_id=None),\n",
|
||||||
" '1.9347796717823205',\n",
|
" '1.9347796717823205',\n",
|
||||||
" ValueError('unknown format from LLM: This is not a math problem and cannot be translated into a mathematical expression.'),\n",
|
" ValueError('unknown format from LLM: This is not a math problem and cannot be translated into a mathematical expression.'),\n",
|
||||||
" '0.2791714614499425']"
|
" '0.2791714614499425']"
|
||||||
@ -184,6 +188,7 @@
|
|||||||
"]\n",
|
"]\n",
|
||||||
"results = []\n",
|
"results = []\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
"async def arun(agent, input_example):\n",
|
"async def arun(agent, input_example):\n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
" return await agent.arun(input_example)\n",
|
" return await agent.arun(input_example)\n",
|
||||||
@ -191,9 +196,11 @@
|
|||||||
" # The agent sometimes makes mistakes! These will be captured by the tracing.\n",
|
" # The agent sometimes makes mistakes! These will be captured by the tracing.\n",
|
||||||
" print(e)\n",
|
" print(e)\n",
|
||||||
" return e\n",
|
" return e\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
"for input_example in inputs:\n",
|
"for input_example in inputs:\n",
|
||||||
" results.append(arun(agent, input_example))\n",
|
" results.append(arun(agent, input_example))\n",
|
||||||
"await asyncio.gather(*results) "
|
"await asyncio.gather(*results)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -229,15 +236,19 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"if dataset_name in set([dataset.name for dataset in client.list_datasets()]):\n",
|
"if dataset_name in set([dataset.name for dataset in client.list_datasets()]):\n",
|
||||||
" client.delete_dataset(dataset_name=dataset_name)\n",
|
" client.delete_dataset(dataset_name=dataset_name)\n",
|
||||||
"dataset = client.create_dataset(dataset_name, description=\"A calculator example dataset\")\n",
|
"dataset = client.create_dataset(\n",
|
||||||
|
" dataset_name, description=\"A calculator example dataset\"\n",
|
||||||
|
")\n",
|
||||||
"runs = client.list_runs(\n",
|
"runs = client.list_runs(\n",
|
||||||
" session_name=os.environ[\"LANGCHAIN_SESSION\"],\n",
|
" session_name=os.environ[\"LANGCHAIN_SESSION\"],\n",
|
||||||
" execution_order=1, # Only return the top-level runs\n",
|
" execution_order=1, # Only return the top-level runs\n",
|
||||||
" error=False, # Only runs that succeed\n",
|
" error=False, # Only runs that succeed\n",
|
||||||
")\n",
|
")\n",
|
||||||
"for run in runs:\n",
|
"for run in runs:\n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
" client.create_example(inputs=run.inputs, outputs=run.outputs, dataset_id=dataset.id)\n",
|
" client.create_example(\n",
|
||||||
|
" inputs=run.inputs, outputs=run.outputs, dataset_id=dataset.id\n",
|
||||||
|
" )\n",
|
||||||
" except:\n",
|
" except:\n",
|
||||||
" pass"
|
" pass"
|
||||||
]
|
]
|
||||||
@ -298,7 +309,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"# dataset = load_dataset(\"agent-search-calculator\")\n",
|
"# dataset = load_dataset(\"agent-search-calculator\")\n",
|
||||||
"# df = pd.DataFrame(dataset, columns=[\"question\", \"answer\"])\n",
|
"# df = pd.DataFrame(dataset, columns=[\"question\", \"answer\"])\n",
|
||||||
"# df.columns = [\"input\", \"output\"] # The chain we want to evaluate below expects inputs with the \"input\" key \n",
|
"# df.columns = [\"input\", \"output\"] # The chain we want to evaluate below expects inputs with the \"input\" key\n",
|
||||||
"# df.head()"
|
"# df.head()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -314,7 +325,7 @@
|
|||||||
"# dataset_name = \"calculator-example-dataset\"\n",
|
"# dataset_name = \"calculator-example-dataset\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# if dataset_name not in set([dataset.name for dataset in client.list_datasets()]):\n",
|
"# if dataset_name not in set([dataset.name for dataset in client.list_datasets()]):\n",
|
||||||
"# dataset = client.upload_dataframe(df, \n",
|
"# dataset = client.upload_dataframe(df,\n",
|
||||||
"# name=dataset_name,\n",
|
"# name=dataset_name,\n",
|
||||||
"# description=\"A calculator example dataset\",\n",
|
"# description=\"A calculator example dataset\",\n",
|
||||||
"# input_keys=[\"input\"],\n",
|
"# input_keys=[\"input\"],\n",
|
||||||
@ -352,8 +363,10 @@
|
|||||||
"from langchain.agents import AgentType\n",
|
"from langchain.agents import AgentType\n",
|
||||||
"\n",
|
"\n",
|
||||||
"llm = ChatOpenAI(temperature=0)\n",
|
"llm = ChatOpenAI(temperature=0)\n",
|
||||||
"tools = load_tools(['serpapi', 'llm-math'], llm=llm)\n",
|
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)\n",
|
||||||
"agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)"
|
"agent = initialize_agent(\n",
|
||||||
|
" tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False\n",
|
||||||
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -378,7 +391,7 @@
|
|||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"\u001b[0;31mSignature:\u001b[0m\n",
|
"\u001b[0;31mSignature:\u001b[0m\n",
|
||||||
"\u001b[0mclient\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marun_on_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
|
"\u001b[0marun_on_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||||
"\u001b[0;34m\u001b[0m \u001b[0mdataset_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
"\u001b[0;34m\u001b[0m \u001b[0mdataset_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||||
"\u001b[0;34m\u001b[0m \u001b[0mllm_or_chain_factory\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'MODEL_OR_CHAIN_FACTORY'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
"\u001b[0;34m\u001b[0m \u001b[0mllm_or_chain_factory\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'MODEL_OR_CHAIN_FACTORY'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||||
"\u001b[0;34m\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
"\u001b[0;34m\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||||
@ -386,11 +399,13 @@
|
|||||||
"\u001b[0;34m\u001b[0m \u001b[0mnum_repetitions\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
"\u001b[0;34m\u001b[0m \u001b[0mnum_repetitions\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||||
"\u001b[0;34m\u001b[0m \u001b[0msession_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[str]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
"\u001b[0;34m\u001b[0m \u001b[0msession_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[str]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||||
"\u001b[0;34m\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
"\u001b[0;34m\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||||
|
"\u001b[0;34m\u001b[0m \u001b[0mclient\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[LangChainPlusClient]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||||
"\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;34m'Dict[str, Any]'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
"\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;34m'Dict[str, Any]'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
"\u001b[0;31mDocstring:\u001b[0m\n",
|
"\u001b[0;31mDocstring:\u001b[0m\n",
|
||||||
"Run the chain on a dataset and store traces to the specified session name.\n",
|
"Run the chain on a dataset and store traces to the specified session name.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Args:\n",
|
"Args:\n",
|
||||||
|
" client: Client to use to read the dataset.\n",
|
||||||
" dataset_name: Name of the dataset to run the chain on.\n",
|
" dataset_name: Name of the dataset to run the chain on.\n",
|
||||||
" llm_or_chain_factory: Language model or Chain constructor to run\n",
|
" llm_or_chain_factory: Language model or Chain constructor to run\n",
|
||||||
" over the dataset. The Chain constructor is used to permit\n",
|
" over the dataset. The Chain constructor is used to permit\n",
|
||||||
@ -402,11 +417,13 @@
|
|||||||
" session_name: Name of the session to store the traces in.\n",
|
" session_name: Name of the session to store the traces in.\n",
|
||||||
" Defaults to {dataset_name}-{chain class name}-{datetime}.\n",
|
" Defaults to {dataset_name}-{chain class name}-{datetime}.\n",
|
||||||
" verbose: Whether to print progress.\n",
|
" verbose: Whether to print progress.\n",
|
||||||
|
" client: Client to use to read the dataset. If not provided, a new\n",
|
||||||
|
" client will be created using the credentials in the environment.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Returns:\n",
|
"Returns:\n",
|
||||||
" A dictionary mapping example ids to the model outputs.\n",
|
" A dictionary mapping example ids to the model outputs.\n",
|
||||||
"\u001b[0;31mFile:\u001b[0m ~/code/lc/lckg/langchain/client/langchain.py\n",
|
"\u001b[0;31mFile:\u001b[0m ~/code/lc/lckg/langchain/client/runner_utils.py\n",
|
||||||
"\u001b[0;31mType:\u001b[0m method"
|
"\u001b[0;31mType:\u001b[0m function"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -414,7 +431,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"?client.arun_on_dataset"
|
"?arun_on_dataset"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -453,62 +470,32 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Processed examples: 4\r"
|
"Processed examples: 1\r"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Chain failed for example 898af6aa-ea39-4959-9ecd-9b9f1ffee31c. Error: LLMMathChain._evaluate(\"\n",
|
"Chain failed for example c6bb978e-b393-4f70-b63b-b0fb03a32dc2. Error: This model's maximum context length is 4097 tokens. However, your messages resulted in 4097 tokens. Please reduce the length of the messages.\n"
|
||||||
"round(0.2791714614499425, 2)\n",
|
|
||||||
"\") raised error: 'VariableNode' object is not callable. Please try again with a valid numerical expression\n"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Processed examples: 5\r"
|
"Processed examples: 9\r"
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Chain failed for example ffb8071d-60e4-49ca-aa9f-5ec03ea78f2d. Error: unknown format from LLM: This is not a math problem and cannot be translated into a mathematical expression.\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Processed examples: 6\r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Retrying langchain.chat_models.openai.acompletion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised RateLimitError: That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 29fc448d09a0f240719eb1dbb95db18d in your message.).\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Processed examples: 7\r"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"evaluation_session_name = \"Search + Calculator Agent Evaluation\"\n",
|
"evaluation_session_name = \"Search + Calculator Agent Evaluation\"\n",
|
||||||
"chain_results = await client.arun_on_dataset(\n",
|
"chain_results = await arun_on_dataset(\n",
|
||||||
" dataset_name=dataset_name,\n",
|
" dataset_name=dataset_name,\n",
|
||||||
" llm_or_chain_factory=chain_factory,\n",
|
" llm_or_chain_factory=chain_factory,\n",
|
||||||
" concurrency_level=5, # Optional, sets the number of examples to run at a time\n",
|
" concurrency_level=5, # Optional, sets the number of examples to run at a time\n",
|
||||||
" verbose=True,\n",
|
" verbose=True,\n",
|
||||||
" session_name=evaluation_session_name # Optional, a unique session name will be generated if not provided\n",
|
" session_name=evaluation_session_name, # Optional, a unique session name will be generated if not provided\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Sometimes, the agent will error due to parsing issues, incompatible tool inputs, etc.\n",
|
"# Sometimes, the agent will error due to parsing issues, incompatible tool inputs, etc.\n",
|
||||||
@ -593,22 +580,22 @@
|
|||||||
"examples = []\n",
|
"examples = []\n",
|
||||||
"predictions = []\n",
|
"predictions = []\n",
|
||||||
"run_ids = []\n",
|
"run_ids = []\n",
|
||||||
"for run in client.list_runs(session_name=evaluation_session_name, execution_order=1, error=False):\n",
|
"for run in client.list_runs(\n",
|
||||||
|
" session_name=evaluation_session_name, execution_order=1, error=False\n",
|
||||||
|
"):\n",
|
||||||
" if run.reference_example_id is None or not run.outputs:\n",
|
" if run.reference_example_id is None or not run.outputs:\n",
|
||||||
" continue\n",
|
" continue\n",
|
||||||
" run_ids.append(run.id)\n",
|
" run_ids.append(run.id)\n",
|
||||||
" example = client.read_example(run.reference_example_id)\n",
|
" example = client.read_example(run.reference_example_id)\n",
|
||||||
" examples.append({**run.inputs, **example.outputs})\n",
|
" examples.append({**run.inputs, **example.outputs})\n",
|
||||||
" predictions.append(\n",
|
" predictions.append(run.outputs)\n",
|
||||||
" run.outputs\n",
|
"\n",
|
||||||
" )\n",
|
|
||||||
" \n",
|
|
||||||
"evaluation_results = chain.evaluate(\n",
|
"evaluation_results = chain.evaluate(\n",
|
||||||
" examples,\n",
|
" examples,\n",
|
||||||
" predictions,\n",
|
" predictions,\n",
|
||||||
" question_key=\"input\",\n",
|
" question_key=\"input\",\n",
|
||||||
" answer_key=\"output\",\n",
|
" answer_key=\"output\",\n",
|
||||||
" prediction_key=\"output\"\n",
|
" prediction_key=\"output\",\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -668,7 +655,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.2"
|
"version": "3.11.3"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
2072
poetry.lock
generated
2072
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -9,7 +9,6 @@ repository = "https://www.github.com/hwchase17/langchain"
|
|||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
langchain-server = "langchain.server:main"
|
langchain-server = "langchain.server:main"
|
||||||
langchain = "langchain.cli.main:main"
|
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.8.1,<4.0"
|
python = ">=3.8.1,<4.0"
|
||||||
@ -104,6 +103,7 @@ momento = {version = "^1.5.0", optional = true}
|
|||||||
bibtexparser = {version = "^1.4.0", optional = true}
|
bibtexparser = {version = "^1.4.0", optional = true}
|
||||||
pyspark = {version = "^3.4.0", optional = true}
|
pyspark = {version = "^3.4.0", optional = true}
|
||||||
tigrisdb = {version = "^1.0.0b6", optional = true}
|
tigrisdb = {version = "^1.0.0b6", optional = true}
|
||||||
|
langchainplus-sdk = "^0.0.4"
|
||||||
|
|
||||||
[tool.poetry.group.docs.dependencies]
|
[tool.poetry.group.docs.dependencies]
|
||||||
autodoc_pydantic = "^1.8.0"
|
autodoc_pydantic = "^1.8.0"
|
||||||
|
@ -1,116 +0,0 @@
|
|||||||
"""LangChain+ langchain_client Integration Tests."""
|
|
||||||
import os
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from tenacity import RetryError
|
|
||||||
|
|
||||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
|
||||||
from langchain.callbacks.manager import tracing_v2_enabled
|
|
||||||
from langchain.chat_models import ChatOpenAI
|
|
||||||
from langchain.client import LangChainPlusClient
|
|
||||||
from langchain.tools.base import tool
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def langchain_client(monkeypatch: pytest.MonkeyPatch) -> LangChainPlusClient:
|
|
||||||
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
|
|
||||||
return LangChainPlusClient()
|
|
||||||
|
|
||||||
|
|
||||||
def test_sessions(
|
|
||||||
langchain_client: LangChainPlusClient, monkeypatch: pytest.MonkeyPatch
|
|
||||||
) -> None:
|
|
||||||
"""Test sessions."""
|
|
||||||
session_names = set([session.name for session in langchain_client.list_sessions()])
|
|
||||||
new_session = f"Session {uuid4()}"
|
|
||||||
assert new_session not in session_names
|
|
||||||
|
|
||||||
@tool
|
|
||||||
def example_tool() -> str:
|
|
||||||
"""Call me, maybe."""
|
|
||||||
return "test_tool"
|
|
||||||
|
|
||||||
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
|
|
||||||
with tracing_v2_enabled(session_name=new_session):
|
|
||||||
example_tool({})
|
|
||||||
session = langchain_client.read_session(session_name=new_session)
|
|
||||||
assert session.name == new_session
|
|
||||||
session_names = set([sess.name for sess in langchain_client.list_sessions()])
|
|
||||||
assert new_session in session_names
|
|
||||||
runs = list(langchain_client.list_runs(session_name=new_session))
|
|
||||||
session_id_runs = list(langchain_client.list_runs(session_id=session.id))
|
|
||||||
assert len(runs) == len(session_id_runs) == 1
|
|
||||||
assert runs[0].id == session_id_runs[0].id
|
|
||||||
langchain_client.delete_session(session_name=new_session)
|
|
||||||
|
|
||||||
with pytest.raises(RetryError):
|
|
||||||
langchain_client.read_session(session_name=new_session)
|
|
||||||
assert new_session not in set(
|
|
||||||
[sess.name for sess in langchain_client.list_sessions()]
|
|
||||||
)
|
|
||||||
with pytest.raises(RetryError):
|
|
||||||
langchain_client.delete_session(session_name=new_session)
|
|
||||||
with pytest.raises(RetryError):
|
|
||||||
langchain_client.read_run(run_id=str(runs[0].id))
|
|
||||||
|
|
||||||
|
|
||||||
def test_feedback_cycle(
|
|
||||||
monkeypatch: pytest.MonkeyPatch, langchain_client: LangChainPlusClient
|
|
||||||
) -> None:
|
|
||||||
"""Test that feedback is correctly created and updated."""
|
|
||||||
monkeypatch.setenv("LANGCHAIN_TRACING_V2", "true")
|
|
||||||
monkeypatch.setenv("LANGCHAIN_SESSION", f"Feedback Testing {uuid4()}")
|
|
||||||
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
|
|
||||||
llm = ChatOpenAI(temperature=0)
|
|
||||||
tools = load_tools(["serpapi", "llm-math"], llm=llm)
|
|
||||||
agent = initialize_agent(
|
|
||||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False
|
|
||||||
)
|
|
||||||
|
|
||||||
agent.run(
|
|
||||||
"What is the population of Kuala Lumpur as of January, 2023?"
|
|
||||||
" What is it's square root?"
|
|
||||||
)
|
|
||||||
other_session_name = f"Feedback Testing {uuid4()}"
|
|
||||||
with tracing_v2_enabled(session_name=other_session_name):
|
|
||||||
try:
|
|
||||||
agent.run("What is the square root of 3?")
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
runs = list(
|
|
||||||
langchain_client.list_runs(
|
|
||||||
session_name=os.environ["LANGCHAIN_SESSION"], error=False, execution_order=1
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert len(runs) == 1
|
|
||||||
order_2 = list(
|
|
||||||
langchain_client.list_runs(
|
|
||||||
session_name=os.environ["LANGCHAIN_SESSION"], execution_order=2
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert len(order_2) > 0
|
|
||||||
langchain_client.create_feedback(str(order_2[0].id), "test score", score=0)
|
|
||||||
feedback = langchain_client.create_feedback(str(runs[0].id), "test score", score=1)
|
|
||||||
feedbacks = list(langchain_client.list_feedback(run_ids=[str(runs[0].id)]))
|
|
||||||
assert len(feedbacks) == 1
|
|
||||||
assert feedbacks[0].id == feedback.id
|
|
||||||
|
|
||||||
# Add feedback to other session
|
|
||||||
other_runs = list(
|
|
||||||
langchain_client.list_runs(session_name=other_session_name, execution_order=1)
|
|
||||||
)
|
|
||||||
assert len(other_runs) == 1
|
|
||||||
langchain_client.create_feedback(
|
|
||||||
run_id=str(other_runs[0].id), key="test score", score=0
|
|
||||||
)
|
|
||||||
all_runs = list(
|
|
||||||
langchain_client.list_runs(session_name=os.environ["LANGCHAIN_SESSION"])
|
|
||||||
) + list(langchain_client.list_runs(session_name=other_session_name))
|
|
||||||
test_run_ids = [str(run.id) for run in all_runs]
|
|
||||||
all_feedback = list(langchain_client.list_feedback(run_ids=test_run_ids))
|
|
||||||
assert len(all_feedback) == 3
|
|
||||||
for feedback in all_feedback:
|
|
||||||
langchain_client.delete_feedback(str(feedback.id))
|
|
||||||
feedbacks = list(langchain_client.list_feedback(run_ids=test_run_ids))
|
|
||||||
assert len(feedbacks) == 0
|
|
@ -1,207 +0,0 @@
|
|||||||
"""Test the LangChain+ client."""
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime
|
|
||||||
from io import BytesIO
|
|
||||||
from typing import Any, Dict, List, Union
|
|
||||||
from unittest import mock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from langchain.base_language import BaseLanguageModel
|
|
||||||
from langchain.chains.base import Chain
|
|
||||||
from langchain.client.langchain import (
|
|
||||||
LangChainPlusClient,
|
|
||||||
_get_link_stem,
|
|
||||||
_is_localhost,
|
|
||||||
)
|
|
||||||
from langchain.client.models import Dataset, Example
|
|
||||||
|
|
||||||
_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
|
|
||||||
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"api_url, expected_url",
|
|
||||||
[
|
|
||||||
("http://localhost:8000", "http://localhost"),
|
|
||||||
("http://www.example.com", "http://www.example.com"),
|
|
||||||
(
|
|
||||||
"https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
|
|
||||||
"https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
|
|
||||||
),
|
|
||||||
("https://www.langchain.com/path/to/nowhere", "https://www.langchain.com"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_link_split(api_url: str, expected_url: str) -> None:
|
|
||||||
"""Test the link splitting handles both localhost and deployed urls."""
|
|
||||||
assert _get_link_stem(api_url) == expected_url
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_localhost() -> None:
|
|
||||||
assert _is_localhost("http://localhost:8000")
|
|
||||||
assert _is_localhost("http://127.0.0.1:8000")
|
|
||||||
assert _is_localhost("http://0.0.0.0:8000")
|
|
||||||
assert not _is_localhost("http://example.com:8000")
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_api_key_if_hosted(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
|
|
||||||
with pytest.raises(ValueError, match="API key must be provided"):
|
|
||||||
LangChainPlusClient(api_url="http://www.example.com")
|
|
||||||
|
|
||||||
client = LangChainPlusClient(api_url="http://localhost:8000")
|
|
||||||
assert client.api_url == "http://localhost:8000"
|
|
||||||
assert client.api_key is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_headers(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
|
|
||||||
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
|
|
||||||
assert client._headers == {"x-api-key": "123"}
|
|
||||||
|
|
||||||
client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
|
|
||||||
assert client_no_key._headers == {}
|
|
||||||
|
|
||||||
|
|
||||||
@mock.patch("langchain.client.langchain.requests.post")
|
|
||||||
def test_upload_csv(mock_post: mock.Mock) -> None:
|
|
||||||
mock_response = mock.Mock()
|
|
||||||
dataset_id = str(uuid.uuid4())
|
|
||||||
example_1 = Example(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
created_at=_CREATED_AT,
|
|
||||||
inputs={"input": "1"},
|
|
||||||
outputs={"output": "2"},
|
|
||||||
dataset_id=dataset_id,
|
|
||||||
)
|
|
||||||
example_2 = Example(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
created_at=_CREATED_AT,
|
|
||||||
inputs={"input": "3"},
|
|
||||||
outputs={"output": "4"},
|
|
||||||
dataset_id=dataset_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"id": dataset_id,
|
|
||||||
"name": "test.csv",
|
|
||||||
"description": "Test dataset",
|
|
||||||
"owner_id": "the owner",
|
|
||||||
"created_at": _CREATED_AT,
|
|
||||||
"examples": [example_1, example_2],
|
|
||||||
"tenant_id": _TENANT_ID,
|
|
||||||
}
|
|
||||||
mock_post.return_value = mock_response
|
|
||||||
|
|
||||||
client = LangChainPlusClient(
|
|
||||||
api_url="http://localhost:8000",
|
|
||||||
api_key="123",
|
|
||||||
)
|
|
||||||
csv_file = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n"))
|
|
||||||
|
|
||||||
dataset = client.upload_csv(
|
|
||||||
csv_file, "Test dataset", input_keys=["input"], output_keys=["output"]
|
|
||||||
)
|
|
||||||
|
|
||||||
assert dataset.id == uuid.UUID(dataset_id)
|
|
||||||
assert dataset.name == "test.csv"
|
|
||||||
assert dataset.description == "Test dataset"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
dataset = Dataset(
|
|
||||||
id=uuid.uuid4(),
|
|
||||||
name="test",
|
|
||||||
description="Test dataset",
|
|
||||||
owner_id="owner",
|
|
||||||
created_at=_CREATED_AT,
|
|
||||||
tenant_id=_TENANT_ID,
|
|
||||||
)
|
|
||||||
uuids = [
|
|
||||||
"0c193153-2309-4704-9a47-17aee4fb25c8",
|
|
||||||
"0d11b5fd-8e66-4485-b696-4b55155c0c05",
|
|
||||||
"90d696f0-f10d-4fd0-b88b-bfee6df08b84",
|
|
||||||
"4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
|
|
||||||
"7b5a524c-80fa-4960-888e-7d380f9a11ee",
|
|
||||||
]
|
|
||||||
examples = [
|
|
||||||
Example(
|
|
||||||
id=uuids[0],
|
|
||||||
created_at=_CREATED_AT,
|
|
||||||
inputs={"input": "1"},
|
|
||||||
outputs={"output": "2"},
|
|
||||||
dataset_id=str(uuid.uuid4()),
|
|
||||||
),
|
|
||||||
Example(
|
|
||||||
id=uuids[1],
|
|
||||||
created_at=_CREATED_AT,
|
|
||||||
inputs={"input": "3"},
|
|
||||||
outputs={"output": "4"},
|
|
||||||
dataset_id=str(uuid.uuid4()),
|
|
||||||
),
|
|
||||||
Example(
|
|
||||||
id=uuids[2],
|
|
||||||
created_at=_CREATED_AT,
|
|
||||||
inputs={"input": "5"},
|
|
||||||
outputs={"output": "6"},
|
|
||||||
dataset_id=str(uuid.uuid4()),
|
|
||||||
),
|
|
||||||
Example(
|
|
||||||
id=uuids[3],
|
|
||||||
created_at=_CREATED_AT,
|
|
||||||
inputs={"input": "7"},
|
|
||||||
outputs={"output": "8"},
|
|
||||||
dataset_id=str(uuid.uuid4()),
|
|
||||||
),
|
|
||||||
Example(
|
|
||||||
id=uuids[4],
|
|
||||||
created_at=_CREATED_AT,
|
|
||||||
inputs={"input": "9"},
|
|
||||||
outputs={"output": "10"},
|
|
||||||
dataset_id=str(uuid.uuid4()),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
|
|
||||||
return dataset
|
|
||||||
|
|
||||||
def mock_list_examples(*args: Any, **kwargs: Any) -> List[Example]:
|
|
||||||
return examples
|
|
||||||
|
|
||||||
async def mock_arun_chain(
|
|
||||||
example: Example,
|
|
||||||
llm_or_chain: Union[BaseLanguageModel, Chain],
|
|
||||||
n_repetitions: int,
|
|
||||||
tracer: Any,
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
return [
|
|
||||||
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
|
|
||||||
]
|
|
||||||
|
|
||||||
with mock.patch.object(
|
|
||||||
LangChainPlusClient, "read_dataset", new=mock_read_dataset
|
|
||||||
), mock.patch.object(
|
|
||||||
LangChainPlusClient, "list_examples", new=mock_list_examples
|
|
||||||
), mock.patch(
|
|
||||||
"langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain
|
|
||||||
):
|
|
||||||
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
|
|
||||||
chain = mock.MagicMock()
|
|
||||||
num_repetitions = 3
|
|
||||||
results = await client.arun_on_dataset(
|
|
||||||
dataset_name="test",
|
|
||||||
llm_or_chain_factory=lambda: chain,
|
|
||||||
concurrency_level=2,
|
|
||||||
session_name="test_session",
|
|
||||||
num_repetitions=num_repetitions,
|
|
||||||
)
|
|
||||||
|
|
||||||
expected = {
|
|
||||||
uuid_: [
|
|
||||||
{"result": f"Result for example {uuid.UUID(uuid_)}"}
|
|
||||||
for _ in range(num_repetitions)
|
|
||||||
]
|
|
||||||
for uuid_ in uuids
|
|
||||||
}
|
|
||||||
assert results == expected
|
|
@ -1,18 +1,27 @@
|
|||||||
"""Test the LangChain+ client."""
|
"""Test the LangChain+ client."""
|
||||||
from typing import Any, Dict
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from langchainplus_sdk.client import LangChainPlusClient
|
||||||
|
from langchainplus_sdk.schemas import Dataset, Example
|
||||||
|
|
||||||
|
from langchain.base_language import BaseLanguageModel
|
||||||
|
from langchain.chains.base import Chain
|
||||||
from langchain.client.runner_utils import (
|
from langchain.client.runner_utils import (
|
||||||
InputFormatError,
|
InputFormatError,
|
||||||
_get_messages,
|
_get_messages,
|
||||||
_get_prompts,
|
_get_prompts,
|
||||||
|
arun_on_dataset,
|
||||||
run_llm,
|
run_llm,
|
||||||
)
|
)
|
||||||
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
|
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
|
||||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||||
|
|
||||||
|
_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
|
||||||
|
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
|
||||||
_EXAMPLE_MESSAGE = {
|
_EXAMPLE_MESSAGE = {
|
||||||
"data": {"content": "Foo", "example": False, "additional_kwargs": {}},
|
"data": {"content": "Foo", "example": False, "additional_kwargs": {}},
|
||||||
"type": "human",
|
"type": "human",
|
||||||
@ -93,3 +102,103 @@ def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None:
|
|||||||
def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
|
def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
|
||||||
llm = FakeChatModel()
|
llm = FakeChatModel()
|
||||||
run_llm(llm, inputs, mock.MagicMock())
|
run_llm(llm, inputs, mock.MagicMock())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
dataset = Dataset(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="test",
|
||||||
|
description="Test dataset",
|
||||||
|
owner_id="owner",
|
||||||
|
created_at=_CREATED_AT,
|
||||||
|
tenant_id=_TENANT_ID,
|
||||||
|
)
|
||||||
|
uuids = [
|
||||||
|
"0c193153-2309-4704-9a47-17aee4fb25c8",
|
||||||
|
"0d11b5fd-8e66-4485-b696-4b55155c0c05",
|
||||||
|
"90d696f0-f10d-4fd0-b88b-bfee6df08b84",
|
||||||
|
"4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
|
||||||
|
"7b5a524c-80fa-4960-888e-7d380f9a11ee",
|
||||||
|
]
|
||||||
|
examples = [
|
||||||
|
Example(
|
||||||
|
id=uuids[0],
|
||||||
|
created_at=_CREATED_AT,
|
||||||
|
inputs={"input": "1"},
|
||||||
|
outputs={"output": "2"},
|
||||||
|
dataset_id=str(uuid.uuid4()),
|
||||||
|
),
|
||||||
|
Example(
|
||||||
|
id=uuids[1],
|
||||||
|
created_at=_CREATED_AT,
|
||||||
|
inputs={"input": "3"},
|
||||||
|
outputs={"output": "4"},
|
||||||
|
dataset_id=str(uuid.uuid4()),
|
||||||
|
),
|
||||||
|
Example(
|
||||||
|
id=uuids[2],
|
||||||
|
created_at=_CREATED_AT,
|
||||||
|
inputs={"input": "5"},
|
||||||
|
outputs={"output": "6"},
|
||||||
|
dataset_id=str(uuid.uuid4()),
|
||||||
|
),
|
||||||
|
Example(
|
||||||
|
id=uuids[3],
|
||||||
|
created_at=_CREATED_AT,
|
||||||
|
inputs={"input": "7"},
|
||||||
|
outputs={"output": "8"},
|
||||||
|
dataset_id=str(uuid.uuid4()),
|
||||||
|
),
|
||||||
|
Example(
|
||||||
|
id=uuids[4],
|
||||||
|
created_at=_CREATED_AT,
|
||||||
|
inputs={"input": "9"},
|
||||||
|
outputs={"output": "10"},
|
||||||
|
dataset_id=str(uuid.uuid4()),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def mock_list_examples(*args: Any, **kwargs: Any) -> List[Example]:
|
||||||
|
return examples
|
||||||
|
|
||||||
|
async def mock_arun_chain(
|
||||||
|
example: Example,
|
||||||
|
llm_or_chain: Union[BaseLanguageModel, Chain],
|
||||||
|
n_repetitions: int,
|
||||||
|
tracer: Any,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
return [
|
||||||
|
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
|
||||||
|
]
|
||||||
|
|
||||||
|
with mock.patch.object(
|
||||||
|
LangChainPlusClient, "read_dataset", new=mock_read_dataset
|
||||||
|
), mock.patch.object(
|
||||||
|
LangChainPlusClient, "list_examples", new=mock_list_examples
|
||||||
|
), mock.patch(
|
||||||
|
"langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain
|
||||||
|
):
|
||||||
|
client = LangChainPlusClient(api_url="http://localhost:1984", api_key="123")
|
||||||
|
chain = mock.MagicMock()
|
||||||
|
num_repetitions = 3
|
||||||
|
results = await arun_on_dataset(
|
||||||
|
dataset_name="test",
|
||||||
|
llm_or_chain_factory=lambda: chain,
|
||||||
|
concurrency_level=2,
|
||||||
|
session_name="test_session",
|
||||||
|
num_repetitions=num_repetitions,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
uuid_: [
|
||||||
|
{"result": f"Result for example {uuid.UUID(uuid_)}"}
|
||||||
|
for _ in range(num_repetitions)
|
||||||
|
]
|
||||||
|
for uuid_ in uuids
|
||||||
|
}
|
||||||
|
assert results == expected
|
||||||
|
@ -38,6 +38,7 @@ def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None:
|
|||||||
"aiohttp",
|
"aiohttp",
|
||||||
"async-timeout",
|
"async-timeout",
|
||||||
"dataclasses-json",
|
"dataclasses-json",
|
||||||
|
"langchainplus-sdk",
|
||||||
"numexpr",
|
"numexpr",
|
||||||
"numpy",
|
"numpy",
|
||||||
"openapi-schema-pydantic",
|
"openapi-schema-pydantic",
|
Loading…
Reference in New Issue
Block a user