diff --git a/langchain/retrievers/chatgpt_plugin_retriever.py b/langchain/retrievers/chatgpt_plugin_retriever.py index 4655d5ea..e0f3f13c 100644 --- a/langchain/retrievers/chatgpt_plugin_retriever.py +++ b/langchain/retrievers/chatgpt_plugin_retriever.py @@ -28,7 +28,10 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel): docs = [] for d in results: content = d.pop("text") - docs.append(Document(page_content=content, metadata=d)) + metadata = d.pop("metadata", d) + if metadata.get("source_id"): + metadata["source"] = metadata.pop("source_id") + docs.append(Document(page_content=content, metadata=metadata)) return docs async def aget_relevant_documents(self, query: str) -> List[Document]: @@ -48,7 +51,10 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel): docs = [] for d in results: content = d.pop("text") - docs.append(Document(page_content=content, metadata=d)) + metadata = d.pop("metadata", d) + if metadata.get("source_id"): + metadata["source"] = metadata.pop("source_id") + docs.append(Document(page_content=content, metadata=metadata)) return docs def _create_request(self, query: str) -> tuple[str, dict, dict]: