From 50d9c7d5a4f44c6120b96ad9e19e615748471154 Mon Sep 17 00:00:00 2001 From: JaysonAlbert Date: Fri, 16 Jun 2023 13:04:45 +0800 Subject: [PATCH] Fix: change the chatgpt plugin retriever metadata format (#5920) the current implement put the doc itself as the metadata, but the document chatgpt plugin retriever returned already has a `metadata` field, it's better to use that instead. the original code will throw the following exception when using `RetrievalQAWithSourcesChain`, becuse it can not find the field `metadata`: ```python Exception has occurred: ValueError (note: full exception trace is shown but execution is paused at: _run_module_as_main) Document prompt requires documents to have metadata variables: ['source']. Received document with missing metadata: ['source']. File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/combine_documents/base.py", line 27, in format_document raise ValueError( File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/combine_documents/stuff.py", line 65, in doc_strings = [format_document(doc, self.document_prompt) for doc in docs] File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/combine_documents/stuff.py", line 65, in _get_inputs doc_strings = [format_document(doc, self.document_prompt) for doc in docs] File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/combine_documents/stuff.py", line 85, in combine_docs inputs = self._get_inputs(docs, **kwargs) File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/combine_documents/base.py", line 84, in _call output, extra_return_dict = self.combine_docs( File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/base.py", line 140, in __call__ raise e ``` Additionally, the `metadata` filed in the `chatgpt plugin retriever` have these fileds by default: ```json { "source": "file", //email, file or chat "source_id": "filename.docx", // the filename "url": "", ... } ``` so, we should set `source_id` to `source` in the langchain metadata. ```python metadata = d.pop("metadata", d) if(metadata.get("source_id")): metadata["source"] = metadata.pop("source_id") ``` #### Who can review? @dev2049 --------- Co-authored-by: wangjie --- langchain/retrievers/chatgpt_plugin_retriever.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/langchain/retrievers/chatgpt_plugin_retriever.py b/langchain/retrievers/chatgpt_plugin_retriever.py index 4655d5ea..e0f3f13c 100644 --- a/langchain/retrievers/chatgpt_plugin_retriever.py +++ b/langchain/retrievers/chatgpt_plugin_retriever.py @@ -28,7 +28,10 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel): docs = [] for d in results: content = d.pop("text") - docs.append(Document(page_content=content, metadata=d)) + metadata = d.pop("metadata", d) + if metadata.get("source_id"): + metadata["source"] = metadata.pop("source_id") + docs.append(Document(page_content=content, metadata=metadata)) return docs async def aget_relevant_documents(self, query: str) -> List[Document]: @@ -48,7 +51,10 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel): docs = [] for d in results: content = d.pop("text") - docs.append(Document(page_content=content, metadata=d)) + metadata = d.pop("metadata", d) + if metadata.get("source_id"): + metadata["source"] = metadata.pop("source_id") + docs.append(Document(page_content=content, metadata=metadata)) return docs def _create_request(self, query: str) -> tuple[str, dict, dict]: