"""Test LLM Math functionality.""" import json import pytest from langchain import LLMChain from langchain.chains.api.base import APIChain from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.requests import RequestsWrapper from tests.unit_tests.llms.fake_llm import FakeLLM class FakeRequestsChain(RequestsWrapper): """Fake requests chain just for testing purposes.""" output: str def run(self, url: str) -> str: """Just return the specified output.""" return self.output @pytest.fixture def test_api_data() -> dict: """Fake api data to use for testing.""" api_docs = """ This API endpoint will search the notes for a user. Endpoint: https://thisapidoesntexist.com GET /api/notes Query parameters: q | string | The search term for notes """ return { "api_docs": api_docs, "question": "Search for notes containing langchain", "api_url": "https://thisapidoesntexist.com/api/notes?q=langchain", "api_response": json.dumps( { "success": True, "results": [{"id": 1, "content": "Langchain is awesome!"}], } ), "api_summary": "There is 1 note about langchain.", } @pytest.fixture def fake_llm_api_chain(test_api_data: dict) -> APIChain: """Fake LLM API chain for testing.""" TEST_API_DOCS = test_api_data["api_docs"] TEST_QUESTION = test_api_data["question"] TEST_URL = test_api_data["api_url"] TEST_API_RESPONSE = test_api_data["api_response"] TEST_API_SUMMARY = test_api_data["api_summary"] api_url_query_prompt = API_URL_PROMPT.format( api_docs=TEST_API_DOCS, question=TEST_QUESTION ) api_response_prompt = API_RESPONSE_PROMPT.format( api_docs=TEST_API_DOCS, question=TEST_QUESTION, api_url=TEST_URL, api_response=TEST_API_RESPONSE, ) queries = {api_url_query_prompt: TEST_URL, api_response_prompt: TEST_API_SUMMARY} fake_llm = FakeLLM(queries=queries) api_request_chain = LLMChain(llm=fake_llm, prompt=API_URL_PROMPT) api_answer_chain = LLMChain(llm=fake_llm, prompt=API_RESPONSE_PROMPT) requests_wrapper = FakeRequestsChain(output=TEST_API_RESPONSE) return APIChain( api_request_chain=api_request_chain, api_answer_chain=api_answer_chain, requests_wrapper=requests_wrapper, api_docs=TEST_API_DOCS, ) def test_api_question(fake_llm_api_chain: APIChain, test_api_data: dict) -> None: """Test simple question that needs API access.""" question = test_api_data["question"] output = fake_llm_api_chain.run(question) assert output == test_api_data["api_summary"]