mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
Add improved sources splitting in BaseQAWithSourcesChain (#8716)
## Type: Improvement --- ## Description: Running QAWithSourcesChain sometimes raises ValueError as mentioned in issue #7184: ``` ValueError: too many values to unpack (expected 2) Traceback: response = qa({"question": pregunta}, return_only_outputs=True) File "C:\Anaconda3\envs\iagen_3_10\lib\site-packages\langchain\chains\base.py", line 166, in __call__ raise e File "C:\Anaconda3\envs\iagen_3_10\lib\site-packages\langchain\chains\base.py", line 160, in __call__ self._call(inputs, run_manager=run_manager) File "C:\Anaconda3\envs\iagen_3_10\lib\site-packages\langchain\chains\qa_with_sources\base.py", line 132, in _call answer, sources = re.split(r"SOURCES:\s", answer) ``` This is due to LLM model generating subsequent question, answer and sources, that is complement in a similar form as below: ``` <final_answer> SOURCES: <sources> QUESTION: <new_or_repeated_question> FINAL ANSWER: <new_or_repeated_final_answer> SOURCES: <new_or_repeated_sources> ``` It leads the following line ``` re.split(r"SOURCES:\s", answer) ``` to return more than 2 elements and result in ValueError. The simple fix is to split also with "QUESTION:\s" and take the first two elements: ``` answer, sources = re.split(r"SOURCES:\s|QUESTION:\s", answer)[:2] ``` Sometimes LLM might also generate some other texts, like alternative answers in a form: ``` <final_answer_1> SOURCES: <sources> <final_answer_2> SOURCES: <sources> <final_answer_3> SOURCES: <sources> ``` In such cases it is the best to split previously obtained sources with new line: ``` sources = re.split(r"\n", sources.lstrip())[0] ``` --- ## Issue: Resolves #7184 --- ## Maintainer: @baskaryan
This commit is contained in:
parent
a3c79b1909
commit
8bebc9206f
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic_v1 import Extra, root_validator
|
||||
|
||||
@ -119,6 +119,15 @@ class BaseQAWithSourcesChain(Chain, ABC):
|
||||
values["combine_documents_chain"] = values.pop("combine_document_chain")
|
||||
return values
|
||||
|
||||
def _split_sources(self, answer: str) -> Tuple[str, str]:
|
||||
"""Split sources from answer."""
|
||||
if re.search(r"SOURCES:\s", answer):
|
||||
answer, sources = re.split(r"SOURCES:\s|QUESTION:\s", answer)[:2]
|
||||
sources = re.split(r"\n", sources)[0]
|
||||
else:
|
||||
sources = ""
|
||||
return answer, sources
|
||||
|
||||
@abstractmethod
|
||||
def _get_docs(
|
||||
self,
|
||||
@ -145,10 +154,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
|
||||
answer = self.combine_documents_chain.run(
|
||||
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
|
||||
)
|
||||
if re.search(r"SOURCES:\s", answer):
|
||||
answer, sources = re.split(r"SOURCES:\s", answer)
|
||||
else:
|
||||
sources = ""
|
||||
answer, sources = self._split_sources(answer)
|
||||
result: Dict[str, Any] = {
|
||||
self.answer_key: answer,
|
||||
self.sources_answer_key: sources,
|
||||
@ -182,10 +188,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
|
||||
answer = await self.combine_documents_chain.arun(
|
||||
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
|
||||
)
|
||||
if re.search(r"SOURCES:\s", answer):
|
||||
answer, sources = re.split(r"SOURCES:\s", answer)
|
||||
else:
|
||||
sources = ""
|
||||
answer, sources = self._split_sources(answer)
|
||||
result: Dict[str, Any] = {
|
||||
self.answer_key: answer,
|
||||
self.sources_answer_key: sources,
|
||||
|
@ -0,0 +1,71 @@
|
||||
import pytest
|
||||
|
||||
from langchain.chains.qa_with_sources.base import QAWithSourcesChain
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text,answer,sources",
|
||||
[
|
||||
(
|
||||
"This Agreement is governed by English law.\nSOURCES: 28-pl",
|
||||
"This Agreement is governed by English law.\n",
|
||||
"28-pl",
|
||||
),
|
||||
(
|
||||
"This Agreement is governed by English law.\n"
|
||||
"SOURCES: 28-pl\n\n"
|
||||
"QUESTION: Which state/country's law governs the interpretation of the "
|
||||
"contract?\n"
|
||||
"FINAL ANSWER: This Agreement is governed by English law.\n"
|
||||
"SOURCES: 28-pl",
|
||||
"This Agreement is governed by English law.\n",
|
||||
"28-pl",
|
||||
),
|
||||
(
|
||||
"The president did not mention Michael Jackson in the provided content.\n"
|
||||
"SOURCES: \n\n"
|
||||
"Note: Since the content provided does not contain any information about "
|
||||
"Michael Jackson, there are no sources to cite for this specific question.",
|
||||
"The president did not mention Michael Jackson in the provided content.\n",
|
||||
"",
|
||||
),
|
||||
# The following text was generated by gpt-3.5-turbo
|
||||
(
|
||||
"To diagnose the problem, please answer the following questions and send "
|
||||
"them in one message to IT:\nA1. Are you connected to the office network? "
|
||||
"VPN will not work from the office network.\nA2. Are you sure about your "
|
||||
"login/password?\nA3. Are you using any other VPN (e.g. from a client)?\n"
|
||||
"A4. When was the last time you used the company VPN?\n"
|
||||
"SOURCES: 1\n\n"
|
||||
"ALTERNATIVE OPTION: Another option is to run the VPN in CLI, but keep in "
|
||||
"mind that DNS settings may not work and there may be a need for manual "
|
||||
"modification of the local resolver or /etc/hosts and/or ~/.ssh/config "
|
||||
"files to be able to connect to machines in the company. With the "
|
||||
"appropriate packages installed, the only thing needed to establish "
|
||||
"a connection is to run the command:\nsudo openvpn --config config.ovpn"
|
||||
"\n\nWe will be asked for a username and password - provide the login "
|
||||
"details, the same ones that have been used so far for VPN connection, "
|
||||
"connecting to the company's WiFi, or printers (in the Warsaw office)."
|
||||
"\n\nFinally, just use the VPN connection.\n"
|
||||
"SOURCES: 2\n\n"
|
||||
"ALTERNATIVE OPTION (for Windows): Download the"
|
||||
"OpenVPN client application version 2.6 or newer from the official "
|
||||
"website: https://openvpn.net/community-downloads/\n"
|
||||
"SOURCES: 3",
|
||||
"To diagnose the problem, please answer the following questions and send "
|
||||
"them in one message to IT:\nA1. Are you connected to the office network? "
|
||||
"VPN will not work from the office network.\nA2. Are you sure about your "
|
||||
"login/password?\nA3. Are you using any other VPN (e.g. from a client)?\n"
|
||||
"A4. When was the last time you used the company VPN?\n",
|
||||
"1",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_spliting_answer_into_answer_and_sources(
|
||||
text: str, answer: str, sources: str
|
||||
) -> None:
|
||||
qa_chain = QAWithSourcesChain.from_llm(FakeLLM())
|
||||
generated_answer, generated_sources = qa_chain._split_sources(text)
|
||||
assert generated_answer == answer
|
||||
assert generated_sources == sources
|
Loading…
Reference in New Issue
Block a user