Add top_k and filter fields to ChatGPTPluginRetriever (#2852)

This allows to adjust the number of results to retrieve and filter
documents based on metadata.

---------

Co-authored-by: Altay Sansal <altay.sansal@tgs.com>
This commit is contained in:
Altay Sansal 2023-04-15 23:07:53 -05:00 committed by GitHub
parent 4ffc58e07b
commit 9d8ab28837
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,6 @@
from typing import List, Optional from __future__ import annotations
from typing import List
import aiohttp import aiohttp
import requests import requests
@ -10,7 +12,9 @@ from langchain.schema import BaseRetriever, Document
class ChatGPTPluginRetriever(BaseRetriever, BaseModel): class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
url: str url: str
bearer_token: str bearer_token: str
aiosession: Optional[aiohttp.ClientSession] = None top_k: int = 3
filter: dict | None = None
aiosession: aiohttp.ClientSession | None = None
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -18,14 +22,8 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
arbitrary_types_allowed = True arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]: def get_relevant_documents(self, query: str) -> List[Document]:
response = requests.post( url, json, headers = self._create_request(query)
f"{self.url}/query", response = requests.post(url, json=json, headers=headers)
json={"queries": [{"query": query}]},
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {self.bearer_token}",
},
)
results = response.json()["results"][0]["results"] results = response.json()["results"][0]["results"]
docs = [] docs = []
for d in results: for d in results:
@ -34,12 +32,7 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
return docs return docs
async def aget_relevant_documents(self, query: str) -> List[Document]: async def aget_relevant_documents(self, query: str) -> List[Document]:
url = f"{self.url}/query" url, json, headers = self._create_request(query)
json = {"queries": [{"query": query}]}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.bearer_token}",
}
if not self.aiosession: if not self.aiosession:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
@ -57,3 +50,20 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
content = d.pop("text") content = d.pop("text")
docs.append(Document(page_content=content, metadata=d)) docs.append(Document(page_content=content, metadata=d))
return docs return docs
def _create_request(self, query: str) -> tuple[str, dict, dict]:
url = f"{self.url}/query"
json = {
"queries": [
{
"query": query,
"filter": self.filter,
"top_k": self.top_k,
}
]
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.bearer_token}",
}
return url, json, headers