Add improved sources splitting in BaseQAWithSourcesChain (#8716)

## Type:
Improvement

---

## Description:
Running QAWithSourcesChain sometimes raises ValueError as mentioned in
issue #7184:
```
ValueError: too many values to unpack (expected 2)
Traceback:

    response = qa({"question": pregunta}, return_only_outputs=True)
File "C:\Anaconda3\envs\iagen_3_10\lib\site-packages\langchain\chains\base.py", line 166, in __call__
    raise e
File "C:\Anaconda3\envs\iagen_3_10\lib\site-packages\langchain\chains\base.py", line 160, in __call__
    self._call(inputs, run_manager=run_manager)
File "C:\Anaconda3\envs\iagen_3_10\lib\site-packages\langchain\chains\qa_with_sources\base.py", line 132, in _call
    answer, sources = re.split(r"SOURCES:\s", answer)
```
This happens because the LLM sometimes keeps generating a follow-up question,
answer and sources, so the completion has a form similar to the one below:
```
<final_answer>
SOURCES: <sources>
QUESTION: <new_or_repeated_question>
FINAL ANSWER: <new_or_repeated_final_answer>
SOURCES: <new_or_repeated_sources>
```
This causes the following line
```
 re.split(r"SOURCES:\s", answer)
```
to return more than 2 elements, which results in a ValueError. A simple fix
is to also split on "QUESTION:\s" and take the first two elements:
```
answer, sources = re.split(r"SOURCES:\s|QUESTION:\s", answer)[:2]
```

Sometimes the LLM might also generate other text, such as alternative
answers in the form:
```
<final_answer_1>
SOURCES: <sources>

<final_answer_2>
SOURCES: <sources>

<final_answer_3>
SOURCES: <sources>
```
In such cases it is best to split the previously obtained sources on a
newline and keep only the first line:
```
sources = re.split(r"\n", sources)[0]
```



---

## Issue:
Resolves #7184

---

## Maintainer:
@baskaryan
This commit is contained in:
Jakub Kuciński 2023-08-16 22:30:15 +02:00 committed by GitHub
parent a3c79b1909
commit 8bebc9206f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 83 additions and 9 deletions

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import inspect import inspect
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Tuple
from pydantic_v1 import Extra, root_validator from pydantic_v1 import Extra, root_validator
@ -119,6 +119,15 @@ class BaseQAWithSourcesChain(Chain, ABC):
values["combine_documents_chain"] = values.pop("combine_document_chain") values["combine_documents_chain"] = values.pop("combine_document_chain")
return values return values
def _split_sources(self, answer: str) -> Tuple[str, str]:
    """Split an LLM completion into the answer text and its cited sources.

    The completion is expected to consist of the answer followed by a
    ``SOURCES: <sources>`` section. Models sometimes keep generating after
    that section (a repeated ``QUESTION:`` / ``FINAL ANSWER:`` / ``SOURCES:``
    round, alternative answers, notes), so only the first ``SOURCES:``
    section is used and only its first line is kept.

    Args:
        answer: Raw text of the completion.

    Returns:
        A ``(answer, sources)`` tuple. ``answer`` is everything before the
        first ``SOURCES:`` marker; ``sources`` is the text following that
        marker up to the first newline or follow-up marker. When there is no
        ``SOURCES:`` marker, the input is returned unchanged and ``sources``
        is the empty string.
    """
    # Split only at the first "SOURCES:" marker. Splitting on "QUESTION:" at
    # this stage would return question text as the sources whenever the
    # model echoes "QUESTION: ..." before its first "SOURCES:" section.
    parts = re.split(r"SOURCES:\s", answer, maxsplit=1)
    if len(parts) < 2:
        return answer, ""
    answer, trailing = parts
    # Keep just the sources belonging to the first answer: stop at the end
    # of the line or at any follow-up marker that appears on the same line.
    sources = re.split(r"\n|SOURCES:\s|QUESTION:\s", trailing, maxsplit=1)[0]
    return answer, sources
@abstractmethod @abstractmethod
def _get_docs( def _get_docs(
self, self,
@ -145,10 +154,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
answer = self.combine_documents_chain.run( answer = self.combine_documents_chain.run(
input_documents=docs, callbacks=_run_manager.get_child(), **inputs input_documents=docs, callbacks=_run_manager.get_child(), **inputs
) )
if re.search(r"SOURCES:\s", answer): answer, sources = self._split_sources(answer)
answer, sources = re.split(r"SOURCES:\s", answer)
else:
sources = ""
result: Dict[str, Any] = { result: Dict[str, Any] = {
self.answer_key: answer, self.answer_key: answer,
self.sources_answer_key: sources, self.sources_answer_key: sources,
@ -182,10 +188,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
answer = await self.combine_documents_chain.arun( answer = await self.combine_documents_chain.arun(
input_documents=docs, callbacks=_run_manager.get_child(), **inputs input_documents=docs, callbacks=_run_manager.get_child(), **inputs
) )
if re.search(r"SOURCES:\s", answer): answer, sources = self._split_sources(answer)
answer, sources = re.split(r"SOURCES:\s", answer)
else:
sources = ""
result: Dict[str, Any] = { result: Dict[str, Any] = {
self.answer_key: answer, self.answer_key: answer,
self.sources_answer_key: sources, self.sources_answer_key: sources,

View File

@ -0,0 +1,71 @@
import pytest
from langchain.chains.qa_with_sources.base import QAWithSourcesChain
from tests.unit_tests.llms.fake_llm import FakeLLM
@pytest.mark.parametrize(
    "text,answer,sources",
    [
        # Well-formed completion: one answer followed by one SOURCES line.
        (
            "This Agreement is governed by English law.\nSOURCES: 28-pl",
            "This Agreement is governed by English law.\n",
            "28-pl",
        ),
        # The model continues with a second QUESTION / FINAL ANSWER / SOURCES
        # round; only the first answer and its sources are expected back.
        (
            "This Agreement is governed by English law.\n"
            "SOURCES: 28-pl\n\n"
            "QUESTION: Which state/country's law governs the interpretation of the "
            "contract?\n"
            "FINAL ANSWER: This Agreement is governed by English law.\n"
            "SOURCES: 28-pl",
            "This Agreement is governed by English law.\n",
            "28-pl",
        ),
        # Empty SOURCES line followed by a free-text note: the note must not
        # be reported as a source.
        (
            "The president did not mention Michael Jackson in the provided content.\n"
            "SOURCES: \n\n"
            "Note: Since the content provided does not contain any information about "
            "Michael Jackson, there are no sources to cite for this specific question.",
            "The president did not mention Michael Jackson in the provided content.\n",
            "",
        ),
        # The following text was generated by gpt-3.5-turbo: several
        # alternative answers, each followed by its own SOURCES line.
        (
            "To diagnose the problem, please answer the following questions and send "
            "them in one message to IT:\nA1. Are you connected to the office network? "
            "VPN will not work from the office network.\nA2. Are you sure about your "
            "login/password?\nA3. Are you using any other VPN (e.g. from a client)?\n"
            "A4. When was the last time you used the company VPN?\n"
            "SOURCES: 1\n\n"
            "ALTERNATIVE OPTION: Another option is to run the VPN in CLI, but keep in "
            "mind that DNS settings may not work and there may be a need for manual "
            "modification of the local resolver or /etc/hosts and/or ~/.ssh/config "
            "files to be able to connect to machines in the company. With the "
            "appropriate packages installed, the only thing needed to establish "
            "a connection is to run the command:\nsudo openvpn --config config.ovpn"
            "\n\nWe will be asked for a username and password - provide the login "
            "details, the same ones that have been used so far for VPN connection, "
            "connecting to the company's WiFi, or printers (in the Warsaw office)."
            "\n\nFinally, just use the VPN connection.\n"
            "SOURCES: 2\n\n"
            "ALTERNATIVE OPTION (for Windows): Download the"
            "OpenVPN client application version 2.6 or newer from the official "
            "website: https://openvpn.net/community-downloads/\n"
            "SOURCES: 3",
            "To diagnose the problem, please answer the following questions and send "
            "them in one message to IT:\nA1. Are you connected to the office network? "
            "VPN will not work from the office network.\nA2. Are you sure about your "
            "login/password?\nA3. Are you using any other VPN (e.g. from a client)?\n"
            "A4. When was the last time you used the company VPN?\n",
            "1",
        ),
    ],
)
def test_spliting_answer_into_answer_and_sources(
    text: str, answer: str, sources: str
) -> None:
    """Check that ``_split_sources`` separates a completion as expected.

    Each parametrized case supplies a raw completion (``text``) together with
    the ``answer`` and ``sources`` strings the chain should extract from it.
    """
    chain = QAWithSourcesChain.from_llm(FakeLLM())
    split_answer, split_sources = chain._split_sources(text)
    assert (split_answer, split_sources) == (answer, sources)