@ -1,7 +1,8 @@
""" Chain that makes API calls and summarizes the responses to answer a question. """
from __future__ import annotations
from typing import Any , Dict , List , Optional
from typing import Any , Dict , List , Optional , Sequence , Tuple
from urllib . parse import urlparse
from langchain . callbacks . manager import (
AsyncCallbackManagerForChainRun ,
@ -16,6 +17,38 @@ from langchain.schema.language_model import BaseLanguageModel
from langchain . utilities . requests import TextRequestsWrapper
def _extract_scheme_and_domain(url: str) -> Tuple[str, str]:
    """Extract the scheme and domain (network location) from a URL.

    Args:
        url: The input URL.

    Returns:
        A 2-tuple ``(scheme, netloc)`` taken from ``urlparse(url)``. Either
        element is an empty string when the URL does not contain that part.
    """
    parsed_uri = urlparse(url)
    return parsed_uri.scheme, parsed_uri.netloc


def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool:
    """Check whether a URL's scheme and domain match an allowed entry.

    Args:
        url: The URL about to be requested.
        limit_to_domains: Allowed entries, each written as a URL such as
            ``"https://www.example.com"``. Only the scheme and domain of
            each entry are compared; any path or query is ignored.

    Returns:
        True if ``url`` has both a scheme and a domain and they equal the
        scheme and domain of at least one allowed entry, False otherwise.
    """
    scheme, domain = _extract_scheme_and_domain(url)

    # A URL without an explicit scheme or host (e.g. "example.com/path")
    # parses to empty strings. Without this guard it would compare equal to
    # any allowed entry that is also missing its scheme, letting arbitrary
    # host-less URLs through the check.
    if not scheme or not domain:
        return False

    # Host names are case-insensitive, so compare them in lower case.
    domain = domain.lower()
    for allowed_url in limit_to_domains:
        allowed_scheme, allowed_domain = _extract_scheme_and_domain(allowed_url)
        if scheme == allowed_scheme and domain == allowed_domain.lower():
            return True
    return False
class APIChain ( Chain ) :
""" Chain that makes API calls and summarizes the responses to answer a question.
@ -40,6 +73,19 @@ class APIChain(Chain):
api_docs : str
question_key : str = " question " #: :meta private:
output_key : str = " output " #: :meta private:
limit_to_domains : Optional [ Sequence [ str ] ]
""" Use to limit the domains that can be accessed by the API chain.
* For example, to limit to just the domain `https://www.example.com`, set
`limit_to_domains=["https://www.example.com"]`.
* The default value is an empty tuple, which means that no domains are
allowed by default. By design this will raise an error on instantiation.
* Use `None` if you want to allow all domains by default -- this is not
recommended for security reasons, as it would allow malicious users to
make requests to arbitrary URLs, including internal APIs accessible from
the server.
"""
@property
def input_keys ( self ) - > List [ str ] :
@ -68,6 +114,21 @@ class APIChain(Chain):
)
return values
@root_validator(pre=True)
def validate_limit_to_domains(cls, values: Dict) -> Dict:
    """Validate the ``limit_to_domains`` entry of the incoming values.

    Args:
        values: The mapping of values handed to this validator.

    Returns:
        ``values`` unchanged when the check passes.

    Raises:
        ValueError: If the ``limit_to_domains`` key is absent, or if its
            value is falsy but not ``None`` (e.g. an empty list or tuple).
    """
    if "limit_to_domains" not in values:
        raise ValueError(
            " You must specify a list of domains to limit access using "
            " `limit_to_domains` "
        )

    domains = values["limit_to_domains"]
    # ``None`` is let through; only an empty, non-None value is rejected.
    if domains is not None and not domains:
        raise ValueError(
            " Please provide a list of domains to limit access using "
            " `limit_to_domains`. "
        )
    return values
@root_validator ( pre = True )
def validate_api_answer_prompt ( cls , values : Dict ) - > Dict :
""" Check that api answer prompt expects the right variables. """
@ -93,6 +154,12 @@ class APIChain(Chain):
)
_run_manager . on_text ( api_url , color = " green " , end = " \n " , verbose = self . verbose )
api_url = api_url . strip ( )
if self . limit_to_domains and not _check_in_allowed_domain (
api_url , self . limit_to_domains
) :
raise ValueError (
f " { api_url } is not in the allowed domains: { self . limit_to_domains } "
)
api_response = self . requests_wrapper . get ( api_url )
_run_manager . on_text (
api_response , color = " yellow " , end = " \n " , verbose = self . verbose
@ -122,6 +189,12 @@ class APIChain(Chain):
api_url , color = " green " , end = " \n " , verbose = self . verbose
)
api_url = api_url . strip ( )
if self . limit_to_domains and not _check_in_allowed_domain (
api_url , self . limit_to_domains
) :
raise ValueError (
f " { api_url } is not in the allowed domains: { self . limit_to_domains } "
)
api_response = await self . requests_wrapper . aget ( api_url )
await _run_manager . on_text (
api_response , color = " yellow " , end = " \n " , verbose = self . verbose
@ -143,6 +216,7 @@ class APIChain(Chain):
headers : Optional [ dict ] = None ,
api_url_prompt : BasePromptTemplate = API_URL_PROMPT ,
api_response_prompt : BasePromptTemplate = API_RESPONSE_PROMPT ,
limit_to_domains : Optional [ Sequence [ str ] ] = tuple ( ) ,
* * kwargs : Any ,
) - > APIChain :
""" Load chain from just an LLM and the api docs. """
@ -154,6 +228,7 @@ class APIChain(Chain):
api_answer_chain = get_answer_chain ,
requests_wrapper = requests_wrapper ,
api_docs = api_docs ,
limit_to_domains = limit_to_domains ,
* * kwargs ,
)