APIChain add restrictions to domains (CVE-2023-32786) (#12747)

* Restrict the chain to specific domains by default
* This is a breaking change, but it will fail loudly upon object
instantiation -- so there should be no silent errors for users
* Resolves CVE-2023-32786
pull/12364/head
Eugene Yurtsev 10 months ago committed by GitHub
parent 4421ba46d7
commit b1caae62fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

File diff suppressed because one or more lines are too long

@ -1,7 +1,8 @@
"""Chain that makes API calls and summarizes the responses to answer a question."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Sequence, Tuple
from urllib.parse import urlparse
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
@ -16,6 +17,38 @@ from langchain.schema.language_model import BaseLanguageModel
from langchain.utilities.requests import TextRequestsWrapper
def _extract_scheme_and_domain(url: str) -> Tuple[str, str]:
"""Extract the scheme + domain from a given URL.
Args:
url (str): The input URL.
Returns:
return a 2-tuple of scheme and domain
"""
parsed_uri = urlparse(url)
return parsed_uri.scheme, parsed_uri.netloc
def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool:
"""Check if a URL is in the allowed domains.
Args:
url (str): The input URL.
limit_to_domains (Sequence[str]): The allowed domains.
Returns:
bool: True if the URL is in the allowed domains, False otherwise.
"""
scheme, domain = _extract_scheme_and_domain(url)
for allowed_domain in limit_to_domains:
allowed_scheme, allowed_domain = _extract_scheme_and_domain(allowed_domain)
if scheme == allowed_scheme and domain == allowed_domain:
return True
return False
class APIChain(Chain):
"""Chain that makes API calls and summarizes the responses to answer a question.
@ -40,6 +73,19 @@ class APIChain(Chain):
api_docs: str
question_key: str = "question" #: :meta private:
output_key: str = "output" #: :meta private:
limit_to_domains: Optional[Sequence[str]]
"""Use to limit the domains that can be accessed by the API chain.
* For example, to limit to just the domain `https://www.example.com`, set
`limit_to_domains=["https://www.example.com"]`.
* The default value is an empty tuple, which means that no domains are
allowed by default. By design this will raise an error on instantiation.
* Use a None if you want to allow all domains by default -- this is not
recommended for security reasons, as it would allow malicious users to
make requests to arbitrary URLS including internal APIs accessible from
the server.
"""
@property
def input_keys(self) -> List[str]:
@ -68,6 +114,21 @@ class APIChain(Chain):
)
return values
@root_validator(pre=True)
def validate_limit_to_domains(cls, values: Dict) -> Dict:
"""Check that allowed domains are valid."""
if "limit_to_domains" not in values:
raise ValueError(
"You must specify a list of domains to limit access using "
"`limit_to_domains`"
)
if not values["limit_to_domains"] and values["limit_to_domains"] is not None:
raise ValueError(
"Please provide a list of domains to limit access using "
"`limit_to_domains`."
)
return values
@root_validator(pre=True)
def validate_api_answer_prompt(cls, values: Dict) -> Dict:
"""Check that api answer prompt expects the right variables."""
@ -93,6 +154,12 @@ class APIChain(Chain):
)
_run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose)
api_url = api_url.strip()
if self.limit_to_domains and not _check_in_allowed_domain(
api_url, self.limit_to_domains
):
raise ValueError(
f"{api_url} is not in the allowed domains: {self.limit_to_domains}"
)
api_response = self.requests_wrapper.get(api_url)
_run_manager.on_text(
api_response, color="yellow", end="\n", verbose=self.verbose
@ -122,6 +189,12 @@ class APIChain(Chain):
api_url, color="green", end="\n", verbose=self.verbose
)
api_url = api_url.strip()
if self.limit_to_domains and not _check_in_allowed_domain(
api_url, self.limit_to_domains
):
raise ValueError(
f"{api_url} is not in the allowed domains: {self.limit_to_domains}"
)
api_response = await self.requests_wrapper.aget(api_url)
await _run_manager.on_text(
api_response, color="yellow", end="\n", verbose=self.verbose
@ -143,6 +216,7 @@ class APIChain(Chain):
headers: Optional[dict] = None,
api_url_prompt: BasePromptTemplate = API_URL_PROMPT,
api_response_prompt: BasePromptTemplate = API_RESPONSE_PROMPT,
limit_to_domains: Optional[Sequence[str]] = tuple(),
**kwargs: Any,
) -> APIChain:
"""Load chain from just an LLM and the api docs."""
@ -154,6 +228,7 @@ class APIChain(Chain):
api_answer_chain=get_answer_chain,
requests_wrapper=requests_wrapper,
api_docs=api_docs,
limit_to_domains=limit_to_domains,
**kwargs,
)

@ -22,8 +22,7 @@ class FakeRequestsChain(TextRequestsWrapper):
return self.output
@pytest.fixture
def test_api_data() -> dict:
def get_test_api_data() -> dict:
"""Fake api data to use for testing."""
api_docs = """
This API endpoint will search the notes for a user.
@ -48,39 +47,59 @@ def test_api_data() -> dict:
}
@pytest.fixture
def fake_llm_api_chain(test_api_data: dict) -> APIChain:
def get_api_chain(**kwargs: Any) -> APIChain:
"""Fake LLM API chain for testing."""
TEST_API_DOCS = test_api_data["api_docs"]
TEST_QUESTION = test_api_data["question"]
TEST_URL = test_api_data["api_url"]
TEST_API_RESPONSE = test_api_data["api_response"]
TEST_API_SUMMARY = test_api_data["api_summary"]
data = get_test_api_data()
test_api_docs = data["api_docs"]
test_question = data["question"]
test_url = data["api_url"]
test_api_response = data["api_response"]
test_api_summary = data["api_summary"]
api_url_query_prompt = API_URL_PROMPT.format(
api_docs=TEST_API_DOCS, question=TEST_QUESTION
api_docs=test_api_docs, question=test_question
)
api_response_prompt = API_RESPONSE_PROMPT.format(
api_docs=TEST_API_DOCS,
question=TEST_QUESTION,
api_url=TEST_URL,
api_response=TEST_API_RESPONSE,
api_docs=test_api_docs,
question=test_question,
api_url=test_url,
api_response=test_api_response,
)
queries = {api_url_query_prompt: TEST_URL, api_response_prompt: TEST_API_SUMMARY}
queries = {api_url_query_prompt: test_url, api_response_prompt: test_api_summary}
fake_llm = FakeLLM(queries=queries)
api_request_chain = LLMChain(llm=fake_llm, prompt=API_URL_PROMPT)
api_answer_chain = LLMChain(llm=fake_llm, prompt=API_RESPONSE_PROMPT)
requests_wrapper = FakeRequestsChain(output=TEST_API_RESPONSE)
requests_wrapper = FakeRequestsChain(output=test_api_response)
return APIChain(
api_request_chain=api_request_chain,
api_answer_chain=api_answer_chain,
requests_wrapper=requests_wrapper,
api_docs=TEST_API_DOCS,
api_docs=test_api_docs,
**kwargs,
)
def test_api_question(fake_llm_api_chain: APIChain, test_api_data: dict) -> None:
def test_api_question() -> None:
"""Test simple question that needs API access."""
question = test_api_data["question"]
output = fake_llm_api_chain.run(question)
assert output == test_api_data["api_summary"]
with pytest.raises(ValueError):
get_api_chain()
with pytest.raises(ValueError):
get_api_chain(limit_to_domains=tuple())
# All domains allowed (not advised)
api_chain = get_api_chain(limit_to_domains=None)
data = get_test_api_data()
assert api_chain.run(data["question"]) == data["api_summary"]
# Use a domain that's allowed
api_chain = get_api_chain(
limit_to_domains=["https://thisapidoesntexist.com/api/notes?q=langchain"]
)
# Attempts to make a request against a domain that's not allowed
assert api_chain.run(data["question"]) == data["api_summary"]
# Use domains that are not valid
api_chain = get_api_chain(limit_to_domains=["h", "*"])
with pytest.raises(ValueError):
# Attempts to make a request against a domain that's not allowed
assert api_chain.run(data["question"]) == data["api_summary"]

Loading…
Cancel
Save