mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Fix: change the chatgpt plugin retriever metadata format (#5920)
the current implement put the doc itself as the metadata, but the document chatgpt plugin retriever returned already has a `metadata` field, it's better to use that instead. the original code will throw the following exception when using `RetrievalQAWithSourcesChain`, becuse it can not find the field `metadata`: ```python Exception has occurred: ValueError (note: full exception trace is shown but execution is paused at: _run_module_as_main) Document prompt requires documents to have metadata variables: ['source']. Received document with missing metadata: ['source']. File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/combine_documents/base.py", line 27, in format_document raise ValueError( File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/combine_documents/stuff.py", line 65, in <listcomp> doc_strings = [format_document(doc, self.document_prompt) for doc in docs] File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/combine_documents/stuff.py", line 65, in _get_inputs doc_strings = [format_document(doc, self.document_prompt) for doc in docs] File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/combine_documents/stuff.py", line 85, in combine_docs inputs = self._get_inputs(docs, **kwargs) File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/combine_documents/base.py", line 84, in _call output, extra_return_dict = self.combine_docs( File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/base.py", line 140, in __call__ raise e ``` Additionally, the `metadata` filed in the `chatgpt plugin retriever` have these fileds by default: ```json { "source": "file", //email, file or chat "source_id": "filename.docx", // the filename "url": "", ... } ``` so, we should set `source_id` to `source` in the langchain metadata. ```python metadata = d.pop("metadata", d) if(metadata.get("source_id")): metadata["source"] = metadata.pop("source_id") ``` #### Who can review? @dev2049 <!-- For a quicker response, figure out the right person to tag with @ @hwchase17 - project lead Tracing / Callbacks - @agola11 Async - @agola11 DataLoaders - @eyurtsev Models - @hwchase17 - @agola11 Agents / Tools / Toolkits - @vowelparrot VectorStores / Retrievers / Memory - @dev2049 --> --------- Co-authored-by: wangjie <wangjie@htffund.com>
This commit is contained in:
parent
e67b26eee9
commit
50d9c7d5a4
@ -28,7 +28,10 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
|||||||
docs = []
|
docs = []
|
||||||
for d in results:
|
for d in results:
|
||||||
content = d.pop("text")
|
content = d.pop("text")
|
||||||
docs.append(Document(page_content=content, metadata=d))
|
metadata = d.pop("metadata", d)
|
||||||
|
if metadata.get("source_id"):
|
||||||
|
metadata["source"] = metadata.pop("source_id")
|
||||||
|
docs.append(Document(page_content=content, metadata=metadata))
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||||
@ -48,7 +51,10 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
|||||||
docs = []
|
docs = []
|
||||||
for d in results:
|
for d in results:
|
||||||
content = d.pop("text")
|
content = d.pop("text")
|
||||||
docs.append(Document(page_content=content, metadata=d))
|
metadata = d.pop("metadata", d)
|
||||||
|
if metadata.get("source_id"):
|
||||||
|
metadata["source"] = metadata.pop("source_id")
|
||||||
|
docs.append(Document(page_content=content, metadata=metadata))
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def _create_request(self, query: str) -> tuple[str, dict, dict]:
|
def _create_request(self, query: str) -> tuple[str, dict, dict]:
|
||||||
|
Loading…
Reference in New Issue
Block a user