You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/client/models.py

207 lines
5.3 KiB
Python

from datetime import datetime
from enum import Enum
from typing import Any, ClassVar, Dict, List, Mapping, Optional, Sequence, Union
from uuid import UUID, uuid4
from pydantic import BaseModel, Field, root_validator
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
class ExampleBase(BaseModel):
"""Example base model."""
dataset_id: UUID
inputs: Dict[str, Any]
outputs: Optional[Dict[str, Any]] = Field(default=None)
class Config:
frozen = True
class ExampleCreate(ExampleBase):
"""Example create model."""
id: Optional[UUID]
created_at: datetime = Field(default_factory=datetime.utcnow)
class Example(ExampleBase):
"""Example model."""
id: UUID
created_at: datetime
modified_at: Optional[datetime] = Field(default=None)
runs: List[Run] = Field(default_factory=list)
class ExampleUpdate(BaseModel):
"""Update class for Example."""
dataset_id: Optional[UUID] = None
inputs: Optional[Dict[str, Any]] = None
outputs: Optional[Dict[str, Any]] = None
class Config:
frozen = True
class DatasetBase(BaseModel):
"""Dataset base model."""
tenant_id: UUID
name: str
description: Optional[str] = None
class Config:
frozen = True
class DatasetCreate(DatasetBase):
"""Dataset create model."""
id: Optional[UUID]
created_at: datetime = Field(default_factory=datetime.utcnow)
class Dataset(DatasetBase):
"""Dataset ORM model."""
id: UUID
created_at: datetime
modified_at: Optional[datetime] = Field(default=None)
class ListRunsQueryParams(BaseModel):
"""Query params for GET /runs endpoint."""
id: Optional[List[UUID]]
"""Filter runs by id."""
parent_run: Optional[UUID]
"""Filter runs by parent run."""
run_type: Optional[RunTypeEnum]
"""Filter runs by type."""
session: Optional[UUID] = Field(default=None, alias="session_id")
"""Only return runs within a session."""
reference_example: Optional[UUID]
"""Only return runs that reference the specified dataset example."""
execution_order: Optional[int]
"""Filter runs by execution order."""
error: Optional[bool]
"""Whether to return only runs that errored."""
offset: Optional[int]
"""The offset of the first run to return."""
limit: Optional[int]
"""The maximum number of runs to return."""
start_time: Optional[datetime] = Field(
default=None,
alias="start_before",
description="Query Runs that started <= this time",
)
end_time: Optional[datetime] = Field(
default=None,
alias="end_after",
description="Query Runs that ended >= this time",
)
class Config:
extra = "forbid"
frozen = True
@root_validator(allow_reuse=True)
def validate_time_range(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that start_time <= end_time."""
start_time = values.get("start_time")
end_time = values.get("end_time")
if start_time and end_time and start_time > end_time:
raise ValueError("start_time must be <= end_time")
return values
class FeedbackSourceBase(BaseModel):
type: ClassVar[str]
metadata: Optional[Dict[str, Any]] = None
class Config:
frozen = True
class APIFeedbackSource(FeedbackSourceBase):
"""API feedback source."""
type: ClassVar[str] = "api"
class ModelFeedbackSource(FeedbackSourceBase):
"""Model feedback source."""
type: ClassVar[str] = "model"
class FeedbackSourceType(Enum):
"""Feedback source type."""
API = "api"
"""General feedback submitted from the API."""
MODEL = "model"
"""Model-assisted feedback."""
class FeedbackBase(BaseModel):
"""Feedback schema."""
created_at: datetime = Field(default_factory=datetime.utcnow)
"""The time the feedback was created."""
modified_at: datetime = Field(default_factory=datetime.utcnow)
"""The time the feedback was last modified."""
run_id: UUID
"""The associated run ID this feedback is logged for."""
key: str
"""The metric name, tag, or aspect to provide feedback on."""
score: Union[float, int, bool, None] = None
"""Value or score to assign the run."""
value: Union[float, int, bool, str, dict, None] = None
"""The display value, tag or other value for the feedback if not a metric."""
comment: Optional[str] = None
"""Comment or explanation for the feedback."""
correction: Union[str, dict, None] = None
"""Correction for the run."""
feedback_source: Optional[
Union[APIFeedbackSource, ModelFeedbackSource, Mapping[str, Any]]
] = None
"""The source of the feedback."""
class Config:
frozen = True
class FeedbackCreate(FeedbackBase):
"""Schema used for creating feedback."""
id: UUID = Field(default_factory=uuid4)
feedback_source: APIFeedbackSource
"""The source of the feedback."""
class Feedback(FeedbackBase):
"""Schema for getting feedback."""
id: UUID
feedback_source: Optional[Dict] = None
"""The source of the feedback. In this case"""
class ListFeedbackQueryParams(BaseModel):
"""Query Params for listing feedbacks."""
run: Optional[Sequence[UUID]] = None
limit: int = 100
offset: int = 0
class Config:
"""Config for query params."""
extra = "forbid"
frozen = True