diff --git a/langchain/chains/qa_with_sources/base.py b/langchain/chains/qa_with_sources/base.py index 628a7b3603..499e4e2531 100644 --- a/langchain/chains/qa_with_sources/base.py +++ b/langchain/chains/qa_with_sources/base.py @@ -2,6 +2,7 @@ from __future__ import annotations +import re from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional @@ -116,8 +117,8 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC): def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: docs = self._get_docs(inputs) answer, _ = self.combine_documents_chain.combine_docs(docs, **inputs) - if "SOURCES: " in answer: - answer, sources = answer.split("SOURCES: ") + if re.search(r"SOURCES:\s", answer): + answer, sources = re.split(r"SOURCES:\s", answer) else: sources = "" result: Dict[str, Any] = {