Fix: change the chatgpt plugin retriever metadata format (#5920)

the current implement put the doc itself as the metadata, but the
document chatgpt plugin retriever returned already has a `metadata`
field, it's better to use that instead.

the original code will throw the following exception when using
`RetrievalQAWithSourcesChain`, becuse it can not find the field
`metadata`:

```python
Exception has occurred: ValueError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
Document prompt requires documents to have metadata variables: ['source']. Received document with missing metadata: ['source'].
  File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/combine_documents/base.py", line 27, in format_document
    raise ValueError(
  File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/combine_documents/stuff.py", line 65, in <listcomp>
    doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
  File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/combine_documents/stuff.py", line 65, in _get_inputs
    doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
  File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/combine_documents/stuff.py", line 85, in combine_docs
    inputs = self._get_inputs(docs, **kwargs)
  File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/combine_documents/base.py", line 84, in _call
    output, extra_return_dict = self.combine_docs(
  File "/home/wangjie/anaconda3/envs/chatglm/lib/python3.10/site-packages/langchain/chains/base.py", line 140, in __call__
    raise e
```

Additionally, the `metadata` filed in the `chatgpt plugin retriever`
have these fileds by default:
```json
{
    "source":  "file",   //email, file or chat
    "source_id": "filename.docx", // the filename
    "url": "", 
    ...
}
```
so, we should set `source_id` to `source` in the langchain metadata.

```python
metadata = d.pop("metadata", d)
if(metadata.get("source_id")):
    metadata["source"] = metadata.pop("source_id")
```

#### Who can review?
@dev2049

<!-- For a quicker response, figure out the right person to tag with @

  @hwchase17 - project lead

  Tracing / Callbacks
  - @agola11

  Async
  - @agola11

  DataLoaders
  - @eyurtsev

  Models
  - @hwchase17
  - @agola11

  Agents / Tools / Toolkits
  - @vowelparrot

  VectorStores / Retrievers / Memory
  - @dev2049

 -->

---------

Co-authored-by: wangjie <wangjie@htffund.com>
This commit is contained in:
JaysonAlbert 2023-06-16 13:04:45 +08:00 committed by GitHub
parent e67b26eee9
commit 50d9c7d5a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -28,7 +28,10 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
docs = [] docs = []
for d in results: for d in results:
content = d.pop("text") content = d.pop("text")
docs.append(Document(page_content=content, metadata=d)) metadata = d.pop("metadata", d)
if metadata.get("source_id"):
metadata["source"] = metadata.pop("source_id")
docs.append(Document(page_content=content, metadata=metadata))
return docs return docs
async def aget_relevant_documents(self, query: str) -> List[Document]: async def aget_relevant_documents(self, query: str) -> List[Document]:
@ -48,7 +51,10 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
docs = [] docs = []
for d in results: for d in results:
content = d.pop("text") content = d.pop("text")
docs.append(Document(page_content=content, metadata=d)) metadata = d.pop("metadata", d)
if metadata.get("source_id"):
metadata["source"] = metadata.pop("source_id")
docs.append(Document(page_content=content, metadata=metadata))
return docs return docs
def _create_request(self, query: str) -> tuple[str, dict, dict]: def _create_request(self, query: str) -> tuple[str, dict, dict]: