langchain/tests/unit_tests/utilities/test_loading.py

94 lines
2.7 KiB
Python
Raw Normal View History

"""Test the functionality of loading from langchain-hub."""
import json
import re
from pathlib import Path
from typing import Iterable
from unittest.mock import Mock
from urllib.parse import urljoin
import pytest
import responses
from langchain.utilities.loading import DEFAULT_REF, URL_BASE, try_load_from_hub
@pytest.fixture(autouse=True)
def mocked_responses() -> Iterable[responses.RequestsMock]:
"""Fixture mocking requests.get."""
with responses.RequestsMock() as rsps:
yield rsps
def test_non_hub_path() -> None:
"""Test that a non-hub path returns None."""
path = "chains/some_path"
loader = Mock()
valid_suffixes = {"suffix"}
result = try_load_from_hub(path, loader, "chains", valid_suffixes)
assert result is None
loader.assert_not_called()
def test_invalid_prefix() -> None:
"""Test that a hub path with an invalid prefix returns None."""
path = "lc://agents/some_path"
loader = Mock()
valid_suffixes = {"suffix"}
result = try_load_from_hub(path, loader, "chains", valid_suffixes)
assert result is None
loader.assert_not_called()
def test_invalid_suffix() -> None:
"""Test that a hub path with an invalid suffix raises an error."""
path = "lc://chains/path.invalid"
loader = Mock()
valid_suffixes = {"json"}
with pytest.raises(ValueError, match="Unsupported file type."):
try_load_from_hub(path, loader, "chains", valid_suffixes)
loader.assert_not_called()
@pytest.mark.parametrize("ref", [None, "v0.3"])
def test_success(mocked_responses: responses.RequestsMock, ref: str) -> None:
"""Test that a valid hub path is loaded correctly with and without a ref."""
path = "chains/path/chain.json"
lc_path_prefix = f"lc{('@' + ref) if ref else ''}://"
valid_suffixes = {"json"}
body = json.dumps({"foo": "bar"})
ref = ref or DEFAULT_REF
file_contents = None
def loader(file_path: str) -> None:
nonlocal file_contents
assert file_contents is None
file_contents = Path(file_path).read_text()
mocked_responses.get(
urljoin(URL_BASE.format(ref=ref), path),
body=body,
status=200,
content_type="application/json",
)
try_load_from_hub(f"{lc_path_prefix}{path}", loader, "chains", valid_suffixes)
assert file_contents == body
def test_failed_request(mocked_responses: responses.RequestsMock) -> None:
"""Test that a failed request raises an error."""
path = "chains/path/chain.json"
loader = Mock()
mocked_responses.get(urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500)
with pytest.raises(ValueError, match=re.compile("Could not find file at .*")):
try_load_from_hub(f"lc://{path}", loader, "chains", {"json"})
loader.assert_not_called()