From 9d8ab28837341667a2234deacfc7e114aa2a28f7 Mon Sep 17 00:00:00 2001 From: Altay Sansal Date: Sat, 15 Apr 2023 23:07:53 -0500 Subject: [PATCH] Add `top_k` and `filter` fields to `ChatGPTPluginRetriever` (#2852) This allows adjusting the number of results to retrieve and filtering documents based on metadata. --------- Co-authored-by: Altay Sansal --- .../retrievers/chatgpt_plugin_retriever.py | 42 ++++++++++++------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/langchain/retrievers/chatgpt_plugin_retriever.py b/langchain/retrievers/chatgpt_plugin_retriever.py index d2d41327..436ff9c2 100644 --- a/langchain/retrievers/chatgpt_plugin_retriever.py +++ b/langchain/retrievers/chatgpt_plugin_retriever.py @@ -1,4 +1,6 @@ -from typing import List, Optional +from __future__ import annotations + +from typing import List import aiohttp import requests @@ -10,7 +12,9 @@ from langchain.schema import BaseRetriever, Document class ChatGPTPluginRetriever(BaseRetriever, BaseModel): url: str bearer_token: str - aiosession: Optional[aiohttp.ClientSession] = None + top_k: int = 3 + filter: dict | None = None + aiosession: aiohttp.ClientSession | None = None class Config: """Configuration for this pydantic object.""" @@ -18,14 +22,8 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel): arbitrary_types_allowed = True def get_relevant_documents(self, query: str) -> List[Document]: - response = requests.post( - f"{self.url}/query", - json={"queries": [{"query": query}]}, - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {self.bearer_token}", - }, - ) + url, json, headers = self._create_request(query) + response = requests.post(url, json=json, headers=headers) results = response.json()["results"][0]["results"] docs = [] for d in results: @@ -34,12 +32,7 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel): return docs async def aget_relevant_documents(self, query: str) -> List[Document]: - url = f"{self.url}/query" - json = {"queries": 
[{"query": query}]} - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.bearer_token}", - } + url, json, headers = self._create_request(query) if not self.aiosession: async with aiohttp.ClientSession() as session: @@ -57,3 +50,20 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel): content = d.pop("text") docs.append(Document(page_content=content, metadata=d)) return docs + + def _create_request(self, query: str) -> tuple[str, dict, dict]: + url = f"{self.url}/query" + json = { + "queries": [ + { + "query": query, + "filter": self.filter, + "top_k": self.top_k, + } + ] + } + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.bearer_token}", + } + return url, json, headers