mirror of https://github.com/hwchase17/langchain
core[patch]: remove requests (#19891)

Removes the required usage of `requests` from `langchain-core`, all of which had already been deprecated:

- removes the Tracer V1 implementations
- removes the old `try_load_from_hub` github-based hub implementation

Removal is done in a way where imports will still succeed, and usage will fail with a `RuntimeError`.

pull/19700/head
parent
d5a2ff58e9
commit
f0d5b59962
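
To illustrate the shim behavior described in the commit message — a minimal sketch, not part of the commit itself; the module path is taken from the tests removed further down:

```python
# Imports still succeed, so downstream code fails at call time with a clear
# message rather than at import time with an ImportError.
from langchain_core.tracers.langchain_v1 import LangChainTracerV1

try:
    LangChainTracerV1()  # any construction/usage now raises
except RuntimeError as e:
    print(e)  # "LangChainTracerV1 is no longer supported. Please use LangChainTracer instead."
```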
@@ -1,187 +1,14 @@
-from __future__ import annotations
-
-import logging
-import os
-from typing import Any, Dict, Optional, Union
-
-import requests
-
-from langchain_core._api import deprecated
-from langchain_core.messages import get_buffer_string
-from langchain_core.tracers.base import BaseTracer
-from langchain_core.tracers.schemas import (
-    ChainRun,
-    LLMRun,
-    Run,
-    ToolRun,
-    TracerSession,
-    TracerSessionV1,
-    TracerSessionV1Base,
-)
-from langchain_core.utils import raise_for_status_with_text
-
-logger = logging.getLogger(__name__)
-
-
-def get_headers() -> Dict[str, Any]:
-    """Get the headers for the LangChain API."""
-    headers: Dict[str, Any] = {"Content-Type": "application/json"}
-    if os.getenv("LANGCHAIN_API_KEY"):
-        headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
-    return headers
-
-
-def _get_endpoint() -> str:
-    return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
-
-
-@deprecated("0.1.0", alternative="LangChainTracer", removal="0.2.0")
-class LangChainTracerV1(BaseTracer):
-    """Implementation of the SharedTracer that POSTS to the langchain endpoint."""
-
-    def __init__(self, **kwargs: Any) -> None:
-        """Initialize the LangChain tracer."""
-        super().__init__(**kwargs)
-        self.session: Optional[TracerSessionV1] = None
-        self._endpoint = _get_endpoint()
-        self._headers = get_headers()
-
-    def _convert_to_v1_run(self, run: Run) -> Union[LLMRun, ChainRun, ToolRun]:
-        session = self.session or self.load_default_session()
-        if not isinstance(session, TracerSessionV1):
-            raise ValueError(
-                "LangChainTracerV1 is not compatible with"
-                f" session of type {type(session)}"
-            )
-
-        if run.run_type == "llm":
-            if "prompts" in run.inputs:
-                prompts = run.inputs["prompts"]
-            elif "messages" in run.inputs:
-                prompts = [get_buffer_string(batch) for batch in run.inputs["messages"]]
-            else:
-                raise ValueError("No prompts found in LLM run inputs")
-            return LLMRun(
-                uuid=str(run.id) if run.id else None,  # type: ignore[arg-type]
-                parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
-                start_time=run.start_time,
-                end_time=run.end_time,  # type: ignore[arg-type]
-                extra=run.extra,
-                execution_order=run.execution_order,
-                child_execution_order=run.child_execution_order,
-                serialized=run.serialized,  # type: ignore[arg-type]
-                session_id=session.id,
-                error=run.error,
-                prompts=prompts,
-                response=run.outputs if run.outputs else None,  # type: ignore[arg-type]
-            )
-        if run.run_type == "chain":
-            child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
-            return ChainRun(
-                uuid=str(run.id) if run.id else None,  # type: ignore[arg-type]
-                parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
-                start_time=run.start_time,
-                end_time=run.end_time,  # type: ignore[arg-type]
-                execution_order=run.execution_order,
-                child_execution_order=run.child_execution_order,
-                serialized=run.serialized,  # type: ignore[arg-type]
-                session_id=session.id,
-                inputs=run.inputs,
-                outputs=run.outputs,
-                error=run.error,
-                extra=run.extra,
-                child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)],
-                child_chain_runs=[
-                    run for run in child_runs if isinstance(run, ChainRun)
-                ],
-                child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)],
-            )
-        if run.run_type == "tool":
-            child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
-            return ToolRun(
-                uuid=str(run.id) if run.id else None,  # type: ignore[arg-type]
-                parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
-                start_time=run.start_time,
-                end_time=run.end_time,  # type: ignore[arg-type]
-                execution_order=run.execution_order,
-                child_execution_order=run.child_execution_order,
-                serialized=run.serialized,  # type: ignore[arg-type]
-                session_id=session.id,
-                action=str(run.serialized),
-                tool_input=run.inputs.get("input", ""),
-                output=None if run.outputs is None else run.outputs.get("output"),
-                error=run.error,
-                extra=run.extra,
-                child_chain_runs=[
-                    run for run in child_runs if isinstance(run, ChainRun)
-                ],
-                child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)],
-                child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)],
-            )
-        raise ValueError(f"Unknown run type: {run.run_type}")
-
-    def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None:
-        """Persist a run."""
-        if isinstance(run, Run):
-            v1_run = self._convert_to_v1_run(run)
-        else:
-            v1_run = run
-        if isinstance(v1_run, LLMRun):
-            endpoint = f"{self._endpoint}/llm-runs"
-        elif isinstance(v1_run, ChainRun):
-            endpoint = f"{self._endpoint}/chain-runs"
-        else:
-            endpoint = f"{self._endpoint}/tool-runs"
-
-        try:
-            response = requests.post(
-                endpoint,
-                data=v1_run.json(),
-                headers=self._headers,
-            )
-            raise_for_status_with_text(response)
-        except Exception as e:
-            logger.warning(f"Failed to persist run: {e}")
-
-    def _persist_session(
-        self, session_create: TracerSessionV1Base
-    ) -> Union[TracerSessionV1, TracerSession]:
-        """Persist a session."""
-        try:
-            r = requests.post(
-                f"{self._endpoint}/sessions",
-                data=session_create.json(),
-                headers=self._headers,
-            )
-            session = TracerSessionV1(id=r.json()["id"], **session_create.dict())
-        except Exception as e:
-            logger.warning(f"Failed to create session, using default session: {e}")
-            session = TracerSessionV1(id=1, **session_create.dict())
-        return session
-
-    def _load_session(self, session_name: Optional[str] = None) -> TracerSessionV1:
-        """Load a session from the tracer."""
-        try:
-            url = f"{self._endpoint}/sessions"
-            if session_name:
-                url += f"?name={session_name}"
-            r = requests.get(url, headers=self._headers)
-
-            tracer_session = TracerSessionV1(**r.json()[0])
-        except Exception as e:
-            session_type = "default" if not session_name else session_name
-            logger.warning(
-                f"Failed to load {session_type} session, using empty session: {e}"
-            )
-            tracer_session = TracerSessionV1(id=1)
-
-        self.session = tracer_session
-        return tracer_session
-
-    def load_session(self, session_name: str) -> Union[TracerSessionV1, TracerSession]:
-        """Load a session with the given name from the tracer."""
-        return self._load_session(session_name)
-
-    def load_default_session(self) -> Union[TracerSessionV1, TracerSession]:
-        """Load the default tracing session and set it as the Tracer's session."""
-        return self._load_session("default")
+from typing import Any
+
+
+def get_headers(*args: Any, **kwargs: Any) -> Any:
+    raise RuntimeError(
+        "get_headers for LangChainTracerV1 is no longer supported. "
+        "Please use LangChainTracer instead."
+    )
+
+
+def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any:
+    raise RuntimeError(
+        "LangChainTracerV1 is no longer supported. Please use LangChainTracer instead."
+    )
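
The error message above points migrators at `LangChainTracer`. A hedged sketch of the replacement usage — the import path and callbacks config are assumptions from the current `langchain_core` API, not part of this diff:

```python
from langchain_core.tracers import LangChainTracer  # assumed v2 import path

# The v2 tracer picks up LANGCHAIN_API_KEY / LANGCHAIN_ENDPOINT from the
# environment, much like get_headers()/_get_endpoint() did for V1.
tracer = LangChainTracer()

# Attach it to any runnable invocation as a callback handler, e.g.:
# chain.invoke(inputs, config={"callbacks": [tracer]})
```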
@@ -1,68 +1,13 @@
 """Utilities for loading configurations from langchain_core-hub."""

-import os
-import re
-import tempfile
-from pathlib import Path, PurePosixPath
-from typing import Any, Callable, Optional, Set, TypeVar, Union
-from urllib.parse import urljoin
-
-import requests
-
-from langchain_core._api.deprecation import deprecated
-
-DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master")
-LANGCHAINHUB_REPO = "https://raw.githubusercontent.com/hwchase17/langchain-hub/"
-URL_BASE = os.environ.get(
-    "LANGCHAIN_HUB_URL_BASE",
-    LANGCHAINHUB_REPO + "{ref}/",
-)
-HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)")
-
-T = TypeVar("T")
-
-
-@deprecated(
-    since="0.1.30",
-    removal="0.2",
-    message=(
-        "Using the hwchase17/langchain-hub "
-        "repo for prompts is deprecated. Please use "
-        "https://smith.langchain.com/hub instead."
-    ),
-)
+from typing import Any
+
+
 def try_load_from_hub(
-    path: Union[str, Path],
-    loader: Callable[[str], T],
-    valid_prefix: str,
-    valid_suffixes: Set[str],
     *args: Any,
     **kwargs: Any,
-) -> Optional[T]:
-    """Load configuration from hub. Returns None if path is not a hub path."""
-    if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)):
-        return None
-    ref, remote_path_str = match.groups()
-    ref = ref[1:] if ref else DEFAULT_REF
-    remote_path = Path(remote_path_str)
-    if remote_path.parts[0] != valid_prefix:
-        return None
-    if remote_path.suffix[1:] not in valid_suffixes:
-        raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")
-
-    # Using Path with URLs is not recommended, because on Windows
-    # the backslash is used as the path separator, which can cause issues
-    # when working with URLs that use forward slashes as the path separator.
-    # Instead, use PurePosixPath to ensure that forward slashes are used as the
-    # path separator, regardless of the operating system.
-    full_url = urljoin(URL_BASE.format(ref=ref), PurePosixPath(remote_path).__str__())
-    if not full_url.startswith(LANGCHAINHUB_REPO):
-        raise ValueError(f"Invalid hub path: {path}")
-
-    r = requests.get(full_url, timeout=5)
-    if r.status_code != 200:
-        raise ValueError(f"Could not find file at {full_url}")
-    with tempfile.TemporaryDirectory() as tmpdirname:
-        file = Path(tmpdirname) / remote_path.name
-        with open(file, "wb") as f:
-            f.write(r.content)
-        return loader(str(file), **kwargs)
+) -> Any:
+    raise RuntimeError(
+        "Loading from the deprecated github-based Hub is no longer supported. "
+        "Please use the new LangChain Hub at https://smith.langchain.com/hub instead."
+    )
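
As with the tracer, the hub loader keeps its import path but fails on use. A small sketch of what callers now see — the argument values are illustrative only; the stub ignores them all:

```python
from langchain_core.utils.loading import try_load_from_hub  # import still succeeds

try:
    try_load_from_hub("lc://prompts/example.json", lambda p: p, "prompts", {"json"})
except RuntimeError as e:
    print(e)  # directs callers to https://smith.langchain.com/hub
```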
@ -1,562 +0,0 @@
|
||||
"""Test Tracer classes."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, List, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.tracers.base import BaseTracer, TracerException
|
||||
from langchain_core.tracers.langchain_v1 import (
|
||||
ChainRun,
|
||||
LangChainTracerV1,
|
||||
LLMRun,
|
||||
ToolRun,
|
||||
TracerSessionV1,
|
||||
)
|
||||
from langchain_core.tracers.schemas import Run, TracerSessionV1Base
|
||||
|
||||
TEST_SESSION_ID = 2023
|
||||
|
||||
SERIALIZED = {"id": ["llm"]}
|
||||
SERIALIZED_CHAT = {"id": ["chat_model"]}
|
||||
|
||||
|
||||
def load_session(session_name: str) -> TracerSessionV1:
|
||||
"""Load a tracing session."""
|
||||
return TracerSessionV1(
|
||||
id=TEST_SESSION_ID, name=session_name, start_time=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
|
||||
def new_session(name: Optional[str] = None) -> TracerSessionV1:
|
||||
"""Create a new tracing session."""
|
||||
return TracerSessionV1(
|
||||
id=TEST_SESSION_ID,
|
||||
name=name or "default",
|
||||
start_time=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def _persist_session(session: TracerSessionV1Base) -> TracerSessionV1:
|
||||
"""Persist a tracing session."""
|
||||
return TracerSessionV1(**{**session.dict(), "id": TEST_SESSION_ID})
|
||||
|
||||
|
||||
def load_default_session() -> TracerSessionV1:
|
||||
"""Load a tracing session."""
|
||||
return TracerSessionV1(
|
||||
id=TEST_SESSION_ID, name="default", start_time=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lang_chain_tracer_v1(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV1:
|
||||
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
|
||||
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
|
||||
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
|
||||
tracer = LangChainTracerV1()
|
||||
return tracer
|
||||
|
||||
|
||||
class FakeTracer(BaseTracer):
|
||||
"""Fake tracer that records LangChain execution."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the tracer."""
|
||||
super().__init__()
|
||||
self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
||||
|
||||
def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Persist a run."""
|
||||
if isinstance(run, Run):
|
||||
with pytest.MonkeyPatch().context() as m:
|
||||
m.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
|
||||
m.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
|
||||
m.setenv("LANGCHAIN_API_KEY", "foo")
|
||||
tracer = LangChainTracerV1()
|
||||
tracer.load_default_session = load_default_session # type: ignore
|
||||
run = tracer._convert_to_v1_run(run)
|
||||
self.runs.append(run)
|
||||
|
||||
def _persist_session(self, session: TracerSessionV1Base) -> TracerSessionV1:
|
||||
"""Persist a tracing session."""
|
||||
return _persist_session(session)
|
||||
|
||||
def new_session(self, name: Optional[str] = None) -> TracerSessionV1:
|
||||
"""Create a new tracing session."""
|
||||
return new_session(name)
|
||||
|
||||
def load_session(self, session_name: str) -> TracerSessionV1:
|
||||
"""Load a tracing session."""
|
||||
return load_session(session_name)
|
||||
|
||||
def load_default_session(self) -> TracerSessionV1:
|
||||
"""Load a tracing session."""
|
||||
return load_default_session()
|
||||
|
||||
|
||||
def _compare_run_with_error(run: Any, expected_run: Any) -> None:
|
||||
received = run.dict()
|
||||
received_err = received.pop("error")
|
||||
expected = expected_run.dict()
|
||||
expected_err = expected.pop("error")
|
||||
assert received == expected
|
||||
assert expected_err in received_err
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run() -> None:
|
||||
"""Test tracer on an LLM run."""
|
||||
uuid = uuid4()
|
||||
compare_run = LLMRun(
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_chat_model_run() -> None:
|
||||
"""Test tracer on a Chat Model run."""
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
manager = CallbackManager(handlers=[tracer])
|
||||
run_managers = manager.on_chat_model_start(
|
||||
serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]]
|
||||
)
|
||||
compare_run = LLMRun(
|
||||
uuid=str(run_managers[0].run_id),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized=SERIALIZED_CHAT,
|
||||
prompts=["Human: "],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
for run_manager in run_managers:
|
||||
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run_errors_no_start() -> None:
|
||||
"""Test tracer on an LLM run without a start."""
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
with pytest.raises(TracerException):
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4())
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_multiple_llm_runs() -> None:
|
||||
"""Test the tracer with multiple runs."""
|
||||
uuid = uuid4()
|
||||
compare_run = LLMRun(
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
num_runs = 10
|
||||
for _ in range(num_runs):
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||
|
||||
assert tracer.runs == [compare_run] * num_runs
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_chain_run() -> None:
|
||||
"""Test tracer on a Chain run."""
|
||||
uuid = uuid4()
|
||||
compare_run = ChainRun(
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={"name": "chain"},
|
||||
inputs={},
|
||||
outputs={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
|
||||
tracer.on_chain_end(outputs={}, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_tool_run() -> None:
|
||||
"""Test tracer on a Tool run."""
|
||||
uuid = uuid4()
|
||||
compare_run = ToolRun(
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={"name": "tool"},
|
||||
tool_input="test",
|
||||
output="test",
|
||||
action="{'name': 'tool'}",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
|
||||
tracer.on_tool_end("test", run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_nested_run() -> None:
|
||||
"""Test tracer on a nested run."""
|
||||
tracer = FakeTracer()
|
||||
tracer.new_session()
|
||||
|
||||
chain_uuid = uuid4()
|
||||
tool_uuid = uuid4()
|
||||
llm_uuid1 = uuid4()
|
||||
llm_uuid2 = uuid4()
|
||||
for _ in range(10):
|
||||
tracer.on_chain_start(
|
||||
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
|
||||
)
|
||||
tracer.on_tool_start(
|
||||
serialized={"name": "tool"},
|
||||
input_str="test",
|
||||
run_id=tool_uuid,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid1,
|
||||
parent_run_id=tool_uuid,
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||
tracer.on_tool_end("test", run_id=tool_uuid)
|
||||
tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid2,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
|
||||
tracer.on_chain_end(outputs={}, run_id=chain_uuid)
|
||||
|
||||
compare_run = ChainRun(
|
||||
uuid=str(chain_uuid),
|
||||
error=None,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=4,
|
||||
serialized={"name": "chain"},
|
||||
inputs={},
|
||||
outputs={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[
|
||||
ToolRun(
|
||||
uuid=str(tool_uuid),
|
||||
parent_uuid=str(chain_uuid),
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
extra={},
|
||||
execution_order=2,
|
||||
child_execution_order=3,
|
||||
serialized={"name": "tool"},
|
||||
tool_input="test",
|
||||
output="test",
|
||||
action="{'name': 'tool'}",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
child_llm_runs=[
|
||||
LLMRun(
|
||||
uuid=str(llm_uuid1),
|
||||
parent_uuid=str(tool_uuid),
|
||||
error=None,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
extra={},
|
||||
execution_order=3,
|
||||
child_execution_order=3,
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
)
|
||||
],
|
||||
),
|
||||
],
|
||||
child_llm_runs=[
|
||||
LLMRun(
|
||||
uuid=str(llm_uuid2),
|
||||
parent_uuid=str(chain_uuid),
|
||||
error=None,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
extra={},
|
||||
execution_order=4,
|
||||
child_execution_order=4,
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
),
|
||||
],
|
||||
)
|
||||
assert tracer.runs[0] == compare_run
|
||||
assert tracer.runs == [compare_run] * 10
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run_on_error() -> None:
|
||||
"""Test tracer on an LLM run with an error."""
|
||||
exception = Exception("test")
|
||||
uuid = uuid4()
|
||||
|
||||
compare_run = LLMRun(
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
response=None,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_error(exception, run_id=uuid)
|
||||
_compare_run_with_error(tracer.runs[0], compare_run)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_chain_run_on_error() -> None:
|
||||
"""Test tracer on a Chain run with an error."""
|
||||
exception = Exception("test")
|
||||
uuid = uuid4()
|
||||
|
||||
compare_run = ChainRun(
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={"name": "chain"},
|
||||
inputs={},
|
||||
outputs=None,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
|
||||
tracer.on_chain_error(exception, run_id=uuid)
|
||||
_compare_run_with_error(tracer.runs[0], compare_run)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_tool_run_on_error() -> None:
|
||||
"""Test tracer on a Tool run with an error."""
|
||||
exception = Exception("test")
|
||||
uuid = uuid4()
|
||||
|
||||
compare_run = ToolRun(
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={"name": "tool"},
|
||||
tool_input="test",
|
||||
output=None,
|
||||
action="{'name': 'tool'}",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
|
||||
tracer.on_tool_error(exception, run_id=uuid)
|
||||
_compare_run_with_error(tracer.runs[0], compare_run)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tracer_session_v1() -> TracerSessionV1:
|
||||
return TracerSessionV1(id=2, name="Sample session")
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_convert_run(
|
||||
lang_chain_tracer_v1: LangChainTracerV1,
|
||||
sample_tracer_session_v1: TracerSessionV1,
|
||||
) -> None:
|
||||
"""Test converting a run to a V1 run."""
|
||||
llm_run = Run( # type: ignore[call-arg]
|
||||
id="57a08cc4-73d2-4236-8370-549099d07fad", # type: ignore[arg-type]
|
||||
name="llm_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
session_id=TEST_SESSION_ID,
|
||||
inputs={"prompts": []},
|
||||
outputs=LLMResult(generations=[[]]).dict(),
|
||||
serialized={},
|
||||
extra={},
|
||||
run_type="llm",
|
||||
)
|
||||
chain_run = Run(
|
||||
id="57a08cc4-73d2-4236-8371-549099d07fad", # type: ignore[arg-type]
|
||||
name="chain_run",
|
||||
execution_order=1,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs={},
|
||||
child_runs=[llm_run],
|
||||
extra={},
|
||||
run_type="chain",
|
||||
)
|
||||
|
||||
tool_run = Run(
|
||||
id="57a08cc4-73d2-4236-8372-549099d07fad", # type: ignore[arg-type]
|
||||
name="tool_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
inputs={"input": "test"},
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
outputs=None,
|
||||
serialized={},
|
||||
child_runs=[],
|
||||
extra={},
|
||||
run_type="tool",
|
||||
)
|
||||
|
||||
expected_llm_run = LLMRun( # type: ignore[call-arg]
|
||||
uuid="57a08cc4-73d2-4236-8370-549099d07fad",
|
||||
name="llm_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
session_id=2,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
serialized={},
|
||||
extra={},
|
||||
)
|
||||
|
||||
expected_chain_run = ChainRun( # type: ignore[call-arg]
|
||||
uuid="57a08cc4-73d2-4236-8371-549099d07fad",
|
||||
name="chain_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
session_id=2,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs={},
|
||||
child_llm_runs=[expected_llm_run],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
extra={},
|
||||
)
|
||||
expected_tool_run = ToolRun( # type: ignore[call-arg]
|
||||
uuid="57a08cc4-73d2-4236-8372-549099d07fad",
|
||||
name="tool_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
session_id=2,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
tool_input="test",
|
||||
action="{}",
|
||||
serialized={},
|
||||
child_llm_runs=[],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
extra={},
|
||||
)
|
||||
lang_chain_tracer_v1.session = sample_tracer_session_v1
|
||||
converted_llm_run = lang_chain_tracer_v1._convert_to_v1_run(llm_run)
|
||||
converted_chain_run = lang_chain_tracer_v1._convert_to_v1_run(chain_run)
|
||||
converted_tool_run = lang_chain_tracer_v1._convert_to_v1_run(tool_run)
|
||||
|
||||
assert isinstance(converted_llm_run, LLMRun)
|
||||
assert isinstance(converted_chain_run, ChainRun)
|
||||
assert isinstance(converted_tool_run, ToolRun)
|
||||
assert converted_llm_run == expected_llm_run
|
||||
assert converted_tool_run == expected_tool_run
|
||||
assert converted_chain_run == expected_chain_run
|
@@ -1,106 +0,0 @@ (entire test file removed)
"""Test the functionality of loading from langchain-hub."""

import json
import re
from pathlib import Path
from typing import Iterable
from unittest.mock import Mock
from urllib.parse import urljoin

import pytest
import responses

from langchain_core.utils.loading import DEFAULT_REF, URL_BASE, try_load_from_hub


@pytest.fixture(autouse=True)
def mocked_responses() -> Iterable[responses.RequestsMock]:
    """Fixture mocking requests.get."""
    with responses.RequestsMock() as rsps:
        yield rsps


def test_non_hub_path() -> None:
    """Test that a non-hub path returns None."""
    path = "chains/some_path"
    loader = Mock()
    valid_suffixes = {"suffix"}
    result = try_load_from_hub(path, loader, "chains", valid_suffixes)

    assert result is None
    loader.assert_not_called()


def test_invalid_prefix() -> None:
    """Test that a hub path with an invalid prefix returns None."""
    path = "lc://agents/some_path"
    loader = Mock()
    valid_suffixes = {"suffix"}
    result = try_load_from_hub(path, loader, "chains", valid_suffixes)

    assert result is None
    loader.assert_not_called()


def test_invalid_suffix() -> None:
    """Test that a hub path with an invalid suffix raises an error."""
    path = "lc://chains/path.invalid"
    loader = Mock()
    valid_suffixes = {"json"}

    with pytest.raises(
        ValueError, match=f"Unsupported file type, must be one of {valid_suffixes}."
    ):
        try_load_from_hub(path, loader, "chains", valid_suffixes)

    loader.assert_not_called()


@pytest.mark.parametrize("ref", [None, "v0.3"])
def test_success(mocked_responses: responses.RequestsMock, ref: str) -> None:
    """Test that a valid hub path is loaded correctly with and without a ref."""
    path = "chains/path/chain.json"
    lc_path_prefix = f"lc{('@' + ref) if ref else ''}://"
    valid_suffixes = {"json"}
    body = json.dumps({"foo": "bar"})
    ref = ref or DEFAULT_REF

    file_contents = None

    def loader(file_path: str) -> None:
        nonlocal file_contents
        assert file_contents is None
        file_contents = Path(file_path).read_text()

    mocked_responses.get(  # type: ignore
        urljoin(URL_BASE.format(ref=ref), path),
        body=body,
        status=200,
        content_type="application/json",
    )

    try_load_from_hub(f"{lc_path_prefix}{path}", loader, "chains", valid_suffixes)
    assert file_contents == body


def test_failed_request(mocked_responses: responses.RequestsMock) -> None:
    """Test that a failed request raises an error."""
    path = "chains/path/chain.json"
    loader = Mock()

    mocked_responses.get(  # type: ignore
        urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500
    )

    with pytest.raises(ValueError, match=re.compile("Could not find file at .*")):
        try_load_from_hub(f"lc://{path}", loader, "chains", {"json"})
    loader.assert_not_called()


def test_path_traversal() -> None:
    """Test that a path traversal attack is prevented."""
    path = "lc://chains/../../../../../../../../../it.json"
    loader = Mock()

    with pytest.raises(ValueError):
        try_load_from_hub(path, loader, "chains", {"json"})
@ -1,52 +1,3 @@
|
||||
# Susceptible to arbitrary code execution: https://github.com/langchain-ai/langchain/issues/4849
|
||||
import importlib.util
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from langchain.prompts.loading import load_prompt
|
||||
|
||||
import yaml
|
||||
from langchain.prompts.loading import load_prompt_from_config, try_load_from_hub
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
|
||||
|
||||
def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
|
||||
"""Unified method for loading a prompt from LangChainHub or local file system."""
|
||||
|
||||
if hub_result := try_load_from_hub(
|
||||
path, _load_prompt_from_file, "prompts", {"py", "json", "yaml"}
|
||||
):
|
||||
return hub_result
|
||||
else:
|
||||
return _load_prompt_from_file(path)
|
||||
|
||||
|
||||
def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
|
||||
"""Load prompt from file."""
|
||||
# Convert file to a Path object.
|
||||
if isinstance(file, str):
|
||||
file_path = Path(file)
|
||||
else:
|
||||
file_path = file
|
||||
# Load from either json or yaml.
|
||||
if file_path.suffix == ".json":
|
||||
with open(file_path) as f:
|
||||
config = json.load(f)
|
||||
elif file_path.suffix.endswith((".yaml", ".yml")):
|
||||
with open(file_path, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
elif file_path.suffix == ".py":
|
||||
spec = importlib.util.spec_from_loader(
|
||||
"prompt", loader=None, origin=str(file_path)
|
||||
)
|
||||
if spec is None:
|
||||
raise ValueError("could not load spec")
|
||||
helper = importlib.util.module_from_spec(spec)
|
||||
with open(file_path, "rb") as f:
|
||||
exec(f.read(), helper.__dict__)
|
||||
if not isinstance(helper.PROMPT, BasePromptTemplate):
|
||||
raise ValueError("Did not get object of type BasePromptTemplate.")
|
||||
return helper.PROMPT
|
||||
else:
|
||||
raise ValueError(f"Got unsupported file type {file_path.suffix}")
|
||||
# Load the prompt from the config now.
|
||||
return load_prompt_from_config(config)
|
||||
__all__ = ["load_prompt"]
|
||||
|
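
Because the module above now only re-exports `load_prompt`, loading prompts from the local file system is unchanged; only `lc://` hub paths stop resolving. A minimal sketch, assuming a hypothetical local `prompt.json` in the usual prompt-config format:

```python
from langchain.prompts.loading import load_prompt

prompt = load_prompt("prompt.json")  # local JSON/YAML files still load
# load_prompt("lc://prompts/.../prompt.json") formerly fetched from the GitHub
# hub via try_load_from_hub; that path now raises RuntimeError (see above).
```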