Add improved sources splitting in BaseQAWithSourcesChain (#8716)

## Type:
Improvement

---

## Description:
Running QAWithSourcesChain sometimes raises ValueError as mentioned in
issue #7184:
```
ValueError: too many values to unpack (expected 2)
Traceback:

    response = qa({"question": pregunta}, return_only_outputs=True)
File "C:\Anaconda3\envs\iagen_3_10\lib\site-packages\langchain\chains\base.py", line 166, in __call__
    raise e
File "C:\Anaconda3\envs\iagen_3_10\lib\site-packages\langchain\chains\base.py", line 160, in __call__
    self._call(inputs, run_manager=run_manager)
File "C:\Anaconda3\envs\iagen_3_10\lib\site-packages\langchain\chains\qa_with_sources\base.py", line 132, in _call
    answer, sources = re.split(r"SOURCES:\s", answer)
```
This is due to LLM model generating subsequent question, answer and
sources, that is complement in a similar form as below:
```
<final_answer>
SOURCES: <sources>
QUESTION: <new_or_repeated_question>
FINAL ANSWER: <new_or_repeated_final_answer>
SOURCES: <new_or_repeated_sources>
```
It leads the following line
```
 re.split(r"SOURCES:\s", answer)
```
to return more than 2 elements and result in ValueError. The simple fix
is to split also with "QUESTION:\s" and take the first two elements:
```
answer, sources = re.split(r"SOURCES:\s|QUESTION:\s", answer)[:2]
```

Sometimes LLM might also generate some other texts, like alternative
answers in a form:
```
<final_answer_1>
SOURCES: <sources>

<final_answer_2>
SOURCES: <sources>

<final_answer_3>
SOURCES: <sources>
```
In such cases it is the best to split previously obtained sources with
new line:
```
sources = re.split(r"\n", sources.lstrip())[0]
```



---

## Issue:
Resolves #7184

---

## Maintainer:
@baskaryan
This commit is contained in:
Jakub Kuciński 2023-08-16 22:30:15 +02:00 committed by GitHub
parent a3c79b1909
commit 8bebc9206f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 83 additions and 9 deletions

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import inspect
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
from pydantic_v1 import Extra, root_validator
@ -119,6 +119,15 @@ class BaseQAWithSourcesChain(Chain, ABC):
values["combine_documents_chain"] = values.pop("combine_document_chain")
return values
def _split_sources(self, answer: str) -> Tuple[str, str]:
"""Split sources from answer."""
if re.search(r"SOURCES:\s", answer):
answer, sources = re.split(r"SOURCES:\s|QUESTION:\s", answer)[:2]
sources = re.split(r"\n", sources)[0]
else:
sources = ""
return answer, sources
@abstractmethod
def _get_docs(
self,
@ -145,10 +154,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
answer = self.combine_documents_chain.run(
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
)
if re.search(r"SOURCES:\s", answer):
answer, sources = re.split(r"SOURCES:\s", answer)
else:
sources = ""
answer, sources = self._split_sources(answer)
result: Dict[str, Any] = {
self.answer_key: answer,
self.sources_answer_key: sources,
@ -182,10 +188,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
answer = await self.combine_documents_chain.arun(
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
)
if re.search(r"SOURCES:\s", answer):
answer, sources = re.split(r"SOURCES:\s", answer)
else:
sources = ""
answer, sources = self._split_sources(answer)
result: Dict[str, Any] = {
self.answer_key: answer,
self.sources_answer_key: sources,

View File

@ -0,0 +1,71 @@
import pytest
from langchain.chains.qa_with_sources.base import QAWithSourcesChain
from tests.unit_tests.llms.fake_llm import FakeLLM
@pytest.mark.parametrize(
"text,answer,sources",
[
(
"This Agreement is governed by English law.\nSOURCES: 28-pl",
"This Agreement is governed by English law.\n",
"28-pl",
),
(
"This Agreement is governed by English law.\n"
"SOURCES: 28-pl\n\n"
"QUESTION: Which state/country's law governs the interpretation of the "
"contract?\n"
"FINAL ANSWER: This Agreement is governed by English law.\n"
"SOURCES: 28-pl",
"This Agreement is governed by English law.\n",
"28-pl",
),
(
"The president did not mention Michael Jackson in the provided content.\n"
"SOURCES: \n\n"
"Note: Since the content provided does not contain any information about "
"Michael Jackson, there are no sources to cite for this specific question.",
"The president did not mention Michael Jackson in the provided content.\n",
"",
),
# The following text was generated by gpt-3.5-turbo
(
"To diagnose the problem, please answer the following questions and send "
"them in one message to IT:\nA1. Are you connected to the office network? "
"VPN will not work from the office network.\nA2. Are you sure about your "
"login/password?\nA3. Are you using any other VPN (e.g. from a client)?\n"
"A4. When was the last time you used the company VPN?\n"
"SOURCES: 1\n\n"
"ALTERNATIVE OPTION: Another option is to run the VPN in CLI, but keep in "
"mind that DNS settings may not work and there may be a need for manual "
"modification of the local resolver or /etc/hosts and/or ~/.ssh/config "
"files to be able to connect to machines in the company. With the "
"appropriate packages installed, the only thing needed to establish "
"a connection is to run the command:\nsudo openvpn --config config.ovpn"
"\n\nWe will be asked for a username and password - provide the login "
"details, the same ones that have been used so far for VPN connection, "
"connecting to the company's WiFi, or printers (in the Warsaw office)."
"\n\nFinally, just use the VPN connection.\n"
"SOURCES: 2\n\n"
"ALTERNATIVE OPTION (for Windows): Download the"
"OpenVPN client application version 2.6 or newer from the official "
"website: https://openvpn.net/community-downloads/\n"
"SOURCES: 3",
"To diagnose the problem, please answer the following questions and send "
"them in one message to IT:\nA1. Are you connected to the office network? "
"VPN will not work from the office network.\nA2. Are you sure about your "
"login/password?\nA3. Are you using any other VPN (e.g. from a client)?\n"
"A4. When was the last time you used the company VPN?\n",
"1",
),
],
)
def test_spliting_answer_into_answer_and_sources(
text: str, answer: str, sources: str
) -> None:
qa_chain = QAWithSourcesChain.from_llm(FakeLLM())
generated_answer, generated_sources = qa_chain._split_sources(text)
assert generated_answer == answer
assert generated_sources == sources