forked from Archives/langchain
Add top_k
and filter
fields to ChatGPTPluginRetriever
(#2852)
This allows to adjust the number of results to retrieve and filter documents based on metadata. --------- Co-authored-by: Altay Sansal <altay.sansal@tgs.com>
This commit is contained in:
parent
4ffc58e07b
commit
9d8ab28837
@ -1,4 +1,6 @@
|
||||
from typing import List, Optional
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
@ -10,7 +12,9 @@ from langchain.schema import BaseRetriever, Document
|
||||
class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
||||
url: str
|
||||
bearer_token: str
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
top_k: int = 3
|
||||
filter: dict | None = None
|
||||
aiosession: aiohttp.ClientSession | None = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -18,14 +22,8 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
response = requests.post(
|
||||
f"{self.url}/query",
|
||||
json={"queries": [{"query": query}]},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.bearer_token}",
|
||||
},
|
||||
)
|
||||
url, json, headers = self._create_request(query)
|
||||
response = requests.post(url, json=json, headers=headers)
|
||||
results = response.json()["results"][0]["results"]
|
||||
docs = []
|
||||
for d in results:
|
||||
@ -34,12 +32,7 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
url = f"{self.url}/query"
|
||||
json = {"queries": [{"query": query}]}
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.bearer_token}",
|
||||
}
|
||||
url, json, headers = self._create_request(query)
|
||||
|
||||
if not self.aiosession:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@ -57,3 +50,20 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
||||
content = d.pop("text")
|
||||
docs.append(Document(page_content=content, metadata=d))
|
||||
return docs
|
||||
|
||||
def _create_request(self, query: str) -> tuple[str, dict, dict]:
|
||||
url = f"{self.url}/query"
|
||||
json = {
|
||||
"queries": [
|
||||
{
|
||||
"query": query,
|
||||
"filter": self.filter,
|
||||
"top_k": self.top_k,
|
||||
}
|
||||
]
|
||||
}
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.bearer_token}",
|
||||
}
|
||||
return url, json, headers
|
||||
|
Loading…
Reference in New Issue
Block a user