diff --git a/libs/langchain/langchain/chains/qa_with_sources/base.py b/libs/langchain/langchain/chains/qa_with_sources/base.py index 46eb141932..38c10627a3 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/base.py +++ b/libs/langchain/langchain/chains/qa_with_sources/base.py @@ -5,7 +5,7 @@ from __future__ import annotations import inspect import re from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from pydantic_v1 import Extra, root_validator @@ -119,6 +119,15 @@ class BaseQAWithSourcesChain(Chain, ABC): values["combine_documents_chain"] = values.pop("combine_document_chain") return values + def _split_sources(self, answer: str) -> Tuple[str, str]: + """Split sources from answer.""" + if re.search(r"SOURCES:\s", answer): + answer, sources = re.split(r"SOURCES:\s|QUESTION:\s", answer)[:2] + sources = re.split(r"\n", sources)[0] + else: + sources = "" + return answer, sources + @abstractmethod def _get_docs( self, @@ -145,10 +154,7 @@ class BaseQAWithSourcesChain(Chain, ABC): answer = self.combine_documents_chain.run( input_documents=docs, callbacks=_run_manager.get_child(), **inputs ) - if re.search(r"SOURCES:\s", answer): - answer, sources = re.split(r"SOURCES:\s", answer) - else: - sources = "" + answer, sources = self._split_sources(answer) result: Dict[str, Any] = { self.answer_key: answer, self.sources_answer_key: sources, @@ -182,10 +188,7 @@ class BaseQAWithSourcesChain(Chain, ABC): answer = await self.combine_documents_chain.arun( input_documents=docs, callbacks=_run_manager.get_child(), **inputs ) - if re.search(r"SOURCES:\s", answer): - answer, sources = re.split(r"SOURCES:\s", answer) - else: - sources = "" + answer, sources = self._split_sources(answer) result: Dict[str, Any] = { self.answer_key: answer, self.sources_answer_key: sources, diff --git a/libs/langchain/tests/unit_tests/chains/test_qa_with_sources.py b/libs/langchain/tests/unit_tests/chains/test_qa_with_sources.py new file mode 100644 index 0000000000..e69d9b5cd1 --- /dev/null +++ b/libs/langchain/tests/unit_tests/chains/test_qa_with_sources.py @@ -0,0 +1,71 @@ +import pytest + +from langchain.chains.qa_with_sources.base import QAWithSourcesChain +from tests.unit_tests.llms.fake_llm import FakeLLM + + +@pytest.mark.parametrize( + "text,answer,sources", + [ + ( + "This Agreement is governed by English law.\nSOURCES: 28-pl", + "This Agreement is governed by English law.\n", + "28-pl", + ), + ( + "This Agreement is governed by English law.\n" + "SOURCES: 28-pl\n\n" + "QUESTION: Which state/country's law governs the interpretation of the " + "contract?\n" + "FINAL ANSWER: This Agreement is governed by English law.\n" + "SOURCES: 28-pl", + "This Agreement is governed by English law.\n", + "28-pl", + ), + ( + "The president did not mention Michael Jackson in the provided content.\n" + "SOURCES: \n\n" + "Note: Since the content provided does not contain any information about " + "Michael Jackson, there are no sources to cite for this specific question.", + "The president did not mention Michael Jackson in the provided content.\n", + "", + ), + # The following text was generated by gpt-3.5-turbo + ( + "To diagnose the problem, please answer the following questions and send " + "them in one message to IT:\nA1. Are you connected to the office network? " + "VPN will not work from the office network.\nA2. Are you sure about your " + "login/password?\nA3. Are you using any other VPN (e.g. from a client)?\n" + "A4. When was the last time you used the company VPN?\n" + "SOURCES: 1\n\n" + "ALTERNATIVE OPTION: Another option is to run the VPN in CLI, but keep in " + "mind that DNS settings may not work and there may be a need for manual " + "modification of the local resolver or /etc/hosts and/or ~/.ssh/config " + "files to be able to connect to machines in the company. With the " + "appropriate packages installed, the only thing needed to establish " + "a connection is to run the command:\nsudo openvpn --config config.ovpn" + "\n\nWe will be asked for a username and password - provide the login " + "details, the same ones that have been used so far for VPN connection, " + "connecting to the company's WiFi, or printers (in the Warsaw office)." + "\n\nFinally, just use the VPN connection.\n" + "SOURCES: 2\n\n" + "ALTERNATIVE OPTION (for Windows): Download the" + "OpenVPN client application version 2.6 or newer from the official " + "website: https://openvpn.net/community-downloads/\n" + "SOURCES: 3", + "To diagnose the problem, please answer the following questions and send " + "them in one message to IT:\nA1. Are you connected to the office network? " + "VPN will not work from the office network.\nA2. Are you sure about your " + "login/password?\nA3. Are you using any other VPN (e.g. from a client)?\n" + "A4. When was the last time you used the company VPN?\n", + "1", + ), + ], +) +def test_spliting_answer_into_answer_and_sources( + text: str, answer: str, sources: str +) -> None: + qa_chain = QAWithSourcesChain.from_llm(FakeLLM()) + generated_answer, generated_sources = qa_chain._split_sources(text) + assert generated_answer == answer + assert generated_sources == sources