forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
207 lines
5.3 KiB
Python
207 lines
5.3 KiB
Python
from datetime import datetime
|
|
from enum import Enum
|
|
from typing import Any, ClassVar, Dict, List, Mapping, Optional, Sequence, Union
|
|
from uuid import UUID, uuid4
|
|
|
|
from pydantic import BaseModel, Field, root_validator
|
|
|
|
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
|
|
|
|
|
|
class ExampleBase(BaseModel):
|
|
"""Example base model."""
|
|
|
|
dataset_id: UUID
|
|
inputs: Dict[str, Any]
|
|
outputs: Optional[Dict[str, Any]] = Field(default=None)
|
|
|
|
class Config:
|
|
frozen = True
|
|
|
|
|
|
class ExampleCreate(ExampleBase):
|
|
"""Example create model."""
|
|
|
|
id: Optional[UUID]
|
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
|
|
|
|
|
class Example(ExampleBase):
|
|
"""Example model."""
|
|
|
|
id: UUID
|
|
created_at: datetime
|
|
modified_at: Optional[datetime] = Field(default=None)
|
|
runs: List[Run] = Field(default_factory=list)
|
|
|
|
|
|
class ExampleUpdate(BaseModel):
|
|
"""Update class for Example."""
|
|
|
|
dataset_id: Optional[UUID] = None
|
|
inputs: Optional[Dict[str, Any]] = None
|
|
outputs: Optional[Dict[str, Any]] = None
|
|
|
|
class Config:
|
|
frozen = True
|
|
|
|
|
|
class DatasetBase(BaseModel):
|
|
"""Dataset base model."""
|
|
|
|
tenant_id: UUID
|
|
name: str
|
|
description: Optional[str] = None
|
|
|
|
class Config:
|
|
frozen = True
|
|
|
|
|
|
class DatasetCreate(DatasetBase):
|
|
"""Dataset create model."""
|
|
|
|
id: Optional[UUID]
|
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
|
|
|
|
|
class Dataset(DatasetBase):
|
|
"""Dataset ORM model."""
|
|
|
|
id: UUID
|
|
created_at: datetime
|
|
modified_at: Optional[datetime] = Field(default=None)
|
|
|
|
|
|
class ListRunsQueryParams(BaseModel):
|
|
"""Query params for GET /runs endpoint."""
|
|
|
|
id: Optional[List[UUID]]
|
|
"""Filter runs by id."""
|
|
parent_run: Optional[UUID]
|
|
"""Filter runs by parent run."""
|
|
run_type: Optional[RunTypeEnum]
|
|
"""Filter runs by type."""
|
|
session: Optional[UUID] = Field(default=None, alias="session_id")
|
|
"""Only return runs within a session."""
|
|
reference_example: Optional[UUID]
|
|
"""Only return runs that reference the specified dataset example."""
|
|
execution_order: Optional[int]
|
|
"""Filter runs by execution order."""
|
|
error: Optional[bool]
|
|
"""Whether to return only runs that errored."""
|
|
offset: Optional[int]
|
|
"""The offset of the first run to return."""
|
|
limit: Optional[int]
|
|
"""The maximum number of runs to return."""
|
|
start_time: Optional[datetime] = Field(
|
|
default=None,
|
|
alias="start_before",
|
|
description="Query Runs that started <= this time",
|
|
)
|
|
end_time: Optional[datetime] = Field(
|
|
default=None,
|
|
alias="end_after",
|
|
description="Query Runs that ended >= this time",
|
|
)
|
|
|
|
class Config:
|
|
extra = "forbid"
|
|
frozen = True
|
|
|
|
@root_validator(allow_reuse=True)
|
|
def validate_time_range(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Validate that start_time <= end_time."""
|
|
start_time = values.get("start_time")
|
|
end_time = values.get("end_time")
|
|
if start_time and end_time and start_time > end_time:
|
|
raise ValueError("start_time must be <= end_time")
|
|
return values
|
|
|
|
|
|
class FeedbackSourceBase(BaseModel):
|
|
type: ClassVar[str]
|
|
metadata: Optional[Dict[str, Any]] = None
|
|
|
|
class Config:
|
|
frozen = True
|
|
|
|
|
|
class APIFeedbackSource(FeedbackSourceBase):
|
|
"""API feedback source."""
|
|
|
|
type: ClassVar[str] = "api"
|
|
|
|
|
|
class ModelFeedbackSource(FeedbackSourceBase):
|
|
"""Model feedback source."""
|
|
|
|
type: ClassVar[str] = "model"
|
|
|
|
|
|
class FeedbackSourceType(Enum):
|
|
"""Feedback source type."""
|
|
|
|
API = "api"
|
|
"""General feedback submitted from the API."""
|
|
MODEL = "model"
|
|
"""Model-assisted feedback."""
|
|
|
|
|
|
class FeedbackBase(BaseModel):
|
|
"""Feedback schema."""
|
|
|
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
|
"""The time the feedback was created."""
|
|
modified_at: datetime = Field(default_factory=datetime.utcnow)
|
|
"""The time the feedback was last modified."""
|
|
run_id: UUID
|
|
"""The associated run ID this feedback is logged for."""
|
|
key: str
|
|
"""The metric name, tag, or aspect to provide feedback on."""
|
|
score: Union[float, int, bool, None] = None
|
|
"""Value or score to assign the run."""
|
|
value: Union[float, int, bool, str, dict, None] = None
|
|
"""The display value, tag or other value for the feedback if not a metric."""
|
|
comment: Optional[str] = None
|
|
"""Comment or explanation for the feedback."""
|
|
correction: Union[str, dict, None] = None
|
|
"""Correction for the run."""
|
|
feedback_source: Optional[
|
|
Union[APIFeedbackSource, ModelFeedbackSource, Mapping[str, Any]]
|
|
] = None
|
|
"""The source of the feedback."""
|
|
|
|
class Config:
|
|
frozen = True
|
|
|
|
|
|
class FeedbackCreate(FeedbackBase):
|
|
"""Schema used for creating feedback."""
|
|
|
|
id: UUID = Field(default_factory=uuid4)
|
|
|
|
feedback_source: APIFeedbackSource
|
|
"""The source of the feedback."""
|
|
|
|
|
|
class Feedback(FeedbackBase):
|
|
"""Schema for getting feedback."""
|
|
|
|
id: UUID
|
|
feedback_source: Optional[Dict] = None
|
|
"""The source of the feedback. In this case"""
|
|
|
|
|
|
class ListFeedbackQueryParams(BaseModel):
|
|
"""Query Params for listing feedbacks."""
|
|
|
|
run: Optional[Sequence[UUID]] = None
|
|
limit: int = 100
|
|
offset: int = 0
|
|
|
|
class Config:
|
|
"""Config for query params."""
|
|
|
|
extra = "forbid"
|
|
frozen = True
|