mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
6032a051e9
Update the V2 tracer to - use UUIDs instead of int's - load a tenant ID and use that when saving sessions
96 lines
2.8 KiB
Python
96 lines
2.8 KiB
Python
"""Test the functionality of loading from langchain-hub."""
|
|
|
|
import json
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
from unittest.mock import Mock
|
|
from urllib.parse import urljoin
|
|
|
|
import pytest
|
|
import responses
|
|
|
|
from langchain.utilities.loading import DEFAULT_REF, URL_BASE, try_load_from_hub
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
def mocked_responses() -> Iterable[responses.RequestsMock]:
|
|
"""Fixture mocking requests.get."""
|
|
with responses.RequestsMock() as rsps:
|
|
yield rsps
|
|
|
|
|
|
def test_non_hub_path() -> None:
|
|
"""Test that a non-hub path returns None."""
|
|
path = "chains/some_path"
|
|
loader = Mock()
|
|
valid_suffixes = {"suffix"}
|
|
result = try_load_from_hub(path, loader, "chains", valid_suffixes)
|
|
|
|
assert result is None
|
|
loader.assert_not_called()
|
|
|
|
|
|
def test_invalid_prefix() -> None:
|
|
"""Test that a hub path with an invalid prefix returns None."""
|
|
path = "lc://agents/some_path"
|
|
loader = Mock()
|
|
valid_suffixes = {"suffix"}
|
|
result = try_load_from_hub(path, loader, "chains", valid_suffixes)
|
|
|
|
assert result is None
|
|
loader.assert_not_called()
|
|
|
|
|
|
def test_invalid_suffix() -> None:
|
|
"""Test that a hub path with an invalid suffix raises an error."""
|
|
path = "lc://chains/path.invalid"
|
|
loader = Mock()
|
|
valid_suffixes = {"json"}
|
|
|
|
with pytest.raises(ValueError, match="Unsupported file type."):
|
|
try_load_from_hub(path, loader, "chains", valid_suffixes)
|
|
|
|
loader.assert_not_called()
|
|
|
|
|
|
@pytest.mark.parametrize("ref", [None, "v0.3"])
|
|
def test_success(mocked_responses: responses.RequestsMock, ref: str) -> None:
|
|
"""Test that a valid hub path is loaded correctly with and without a ref."""
|
|
path = "chains/path/chain.json"
|
|
lc_path_prefix = f"lc{('@' + ref) if ref else ''}://"
|
|
valid_suffixes = {"json"}
|
|
body = json.dumps({"foo": "bar"})
|
|
ref = ref or DEFAULT_REF
|
|
|
|
file_contents = None
|
|
|
|
def loader(file_path: str) -> None:
|
|
nonlocal file_contents
|
|
assert file_contents is None
|
|
file_contents = Path(file_path).read_text()
|
|
|
|
mocked_responses.get( # type: ignore
|
|
urljoin(URL_BASE.format(ref=ref), path),
|
|
body=body,
|
|
status=200,
|
|
content_type="application/json",
|
|
)
|
|
|
|
try_load_from_hub(f"{lc_path_prefix}{path}", loader, "chains", valid_suffixes)
|
|
assert file_contents == body
|
|
|
|
|
|
def test_failed_request(mocked_responses: responses.RequestsMock) -> None:
|
|
"""Test that a failed request raises an error."""
|
|
path = "chains/path/chain.json"
|
|
loader = Mock()
|
|
|
|
mocked_responses.get( # type: ignore
|
|
urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500
|
|
)
|
|
|
|
with pytest.raises(ValueError, match=re.compile("Could not find file at .*")):
|
|
try_load_from_hub(f"lc://{path}", loader, "chains", {"json"})
|
|
loader.assert_not_called()
|