|
|
|
@ -10,7 +10,6 @@ from typing import (
|
|
|
|
|
Callable,
|
|
|
|
|
Dict,
|
|
|
|
|
Iterator,
|
|
|
|
|
List,
|
|
|
|
|
Mapping,
|
|
|
|
|
Optional,
|
|
|
|
|
Sequence,
|
|
|
|
@ -26,7 +25,8 @@ from requests import Response
|
|
|
|
|
from tenacity import retry, stop_after_attempt, wait_fixed
|
|
|
|
|
|
|
|
|
|
from langchain.base_language import BaseLanguageModel
|
|
|
|
|
from langchain.callbacks.tracers.schemas import Run, TracerSession
|
|
|
|
|
from langchain.callbacks.tracers.schemas import Run as TracerRun
|
|
|
|
|
from langchain.callbacks.tracers.schemas import TracerSession
|
|
|
|
|
from langchain.chains.base import Chain
|
|
|
|
|
from langchain.client.models import (
|
|
|
|
|
APIFeedbackSource,
|
|
|
|
@ -54,6 +54,10 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Run(TracerRun):
|
|
|
|
|
id: UUID
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_link_stem(url: str) -> str:
|
|
|
|
|
scheme = urlsplit(url).scheme
|
|
|
|
|
netloc_prefix = urlsplit(url).netloc.split(":")[0]
|
|
|
|
@ -75,7 +79,6 @@ class LangChainPlusClient(BaseSettings):
|
|
|
|
|
|
|
|
|
|
api_key: Optional[str] = Field(default=None, env="LANGCHAIN_API_KEY")
|
|
|
|
|
api_url: str = Field(default="http://localhost:1984", env="LANGCHAIN_ENDPOINT")
|
|
|
|
|
tenant_id: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
@root_validator(pre=True)
|
|
|
|
|
def validate_api_key_if_hosted(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
@ -87,31 +90,8 @@ class LangChainPlusClient(BaseSettings):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"API key must be provided when using hosted LangChain+ API"
|
|
|
|
|
)
|
|
|
|
|
tenant_id = values.get("tenant_id")
|
|
|
|
|
if not tenant_id:
|
|
|
|
|
values["tenant_id"] = LangChainPlusClient._get_seeded_tenant_id(
|
|
|
|
|
api_url, api_key
|
|
|
|
|
)
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
|
|
|
|
def _get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
|
|
|
|
|
"""Get the tenant ID from the seeded tenant."""
|
|
|
|
|
url = f"{api_url}/tenants"
|
|
|
|
|
headers = {"x-api-key": api_key} if api_key else {}
|
|
|
|
|
response = requests.get(url, headers=headers)
|
|
|
|
|
try:
|
|
|
|
|
raise_for_status_with_text(response)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Unable to get default tenant ID. Please manually provide."
|
|
|
|
|
) from e
|
|
|
|
|
results: List[dict] = response.json()
|
|
|
|
|
if len(results) == 0:
|
|
|
|
|
raise ValueError("No seeded tenant found")
|
|
|
|
|
return results[0]["id"]
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _get_session_name(
|
|
|
|
|
session_name: Optional[str],
|
|
|
|
@ -149,18 +129,10 @@ class LangChainPlusClient(BaseSettings):
|
|
|
|
|
headers["x-api-key"] = self.api_key
|
|
|
|
|
return headers
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def query_params(self) -> Dict[str, Any]:
|
|
|
|
|
"""Get the headers for the API request."""
|
|
|
|
|
return {"tenant_id": self.tenant_id}
|
|
|
|
|
|
|
|
|
|
def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Response:
|
|
|
|
|
"""Make a GET request."""
|
|
|
|
|
query_params = self.query_params
|
|
|
|
|
if params:
|
|
|
|
|
query_params.update(params)
|
|
|
|
|
return requests.get(
|
|
|
|
|
f"{self.api_url}{path}", headers=self._headers, params=query_params
|
|
|
|
|
f"{self.api_url}{path}", headers=self._headers, params=params
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def upload_dataframe(
|
|
|
|
@ -192,7 +164,6 @@ class LangChainPlusClient(BaseSettings):
|
|
|
|
|
"input_keys": ",".join(input_keys),
|
|
|
|
|
"output_keys": ",".join(output_keys),
|
|
|
|
|
"description": description,
|
|
|
|
|
"tenant_id": self.tenant_id,
|
|
|
|
|
}
|
|
|
|
|
response = requests.post(
|
|
|
|
|
self.api_url + "/datasets/upload",
|
|
|
|
@ -244,7 +215,7 @@ class LangChainPlusClient(BaseSettings):
|
|
|
|
|
) -> TracerSession:
|
|
|
|
|
"""Read a session from the LangChain+ API."""
|
|
|
|
|
path = "/sessions"
|
|
|
|
|
params: Dict[str, Any] = {"limit": 1, "tenant_id": self.tenant_id}
|
|
|
|
|
params: Dict[str, Any] = {"limit": 1}
|
|
|
|
|
if session_id is not None:
|
|
|
|
|
path += f"/{session_id}"
|
|
|
|
|
elif session_name is not None:
|
|
|
|
@ -291,7 +262,6 @@ class LangChainPlusClient(BaseSettings):
|
|
|
|
|
) -> Dataset:
|
|
|
|
|
"""Create a dataset in the LangChain+ API."""
|
|
|
|
|
dataset = DatasetCreate(
|
|
|
|
|
tenant_id=self.tenant_id,
|
|
|
|
|
name=dataset_name,
|
|
|
|
|
description=description,
|
|
|
|
|
)
|
|
|
|
@ -309,7 +279,7 @@ class LangChainPlusClient(BaseSettings):
|
|
|
|
|
self, *, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None
|
|
|
|
|
) -> Dataset:
|
|
|
|
|
path = "/datasets"
|
|
|
|
|
params: Dict[str, Any] = {"limit": 1, "tenant_id": self.tenant_id}
|
|
|
|
|
params: Dict[str, Any] = {"limit": 1}
|
|
|
|
|
if dataset_id is not None:
|
|
|
|
|
path += f"/{dataset_id}"
|
|
|
|
|
elif dataset_name is not None:
|
|
|
|
|