mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
Add top_k
and filter
fields to ChatGPTPluginRetriever
(#2852)
This allows to adjust the number of results to retrieve and filter documents based on metadata. --------- Co-authored-by: Altay Sansal <altay.sansal@tgs.com>
This commit is contained in:
parent
4ffc58e07b
commit
9d8ab28837
@ -1,4 +1,6 @@
|
|||||||
from typing import List, Optional
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
@ -10,7 +12,9 @@ from langchain.schema import BaseRetriever, Document
|
|||||||
class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
||||||
url: str
|
url: str
|
||||||
bearer_token: str
|
bearer_token: str
|
||||||
aiosession: Optional[aiohttp.ClientSession] = None
|
top_k: int = 3
|
||||||
|
filter: dict | None = None
|
||||||
|
aiosession: aiohttp.ClientSession | None = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -18,14 +22,8 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
|||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||||
response = requests.post(
|
url, json, headers = self._create_request(query)
|
||||||
f"{self.url}/query",
|
response = requests.post(url, json=json, headers=headers)
|
||||||
json={"queries": [{"query": query}]},
|
|
||||||
headers={
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {self.bearer_token}",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
results = response.json()["results"][0]["results"]
|
results = response.json()["results"][0]["results"]
|
||||||
docs = []
|
docs = []
|
||||||
for d in results:
|
for d in results:
|
||||||
@ -34,12 +32,7 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||||
url = f"{self.url}/query"
|
url, json, headers = self._create_request(query)
|
||||||
json = {"queries": [{"query": query}]}
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {self.bearer_token}",
|
|
||||||
}
|
|
||||||
|
|
||||||
if not self.aiosession:
|
if not self.aiosession:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
@ -57,3 +50,20 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
|||||||
content = d.pop("text")
|
content = d.pop("text")
|
||||||
docs.append(Document(page_content=content, metadata=d))
|
docs.append(Document(page_content=content, metadata=d))
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
def _create_request(self, query: str) -> tuple[str, dict, dict]:
|
||||||
|
url = f"{self.url}/query"
|
||||||
|
json = {
|
||||||
|
"queries": [
|
||||||
|
{
|
||||||
|
"query": query,
|
||||||
|
"filter": self.filter,
|
||||||
|
"top_k": self.top_k,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.bearer_token}",
|
||||||
|
}
|
||||||
|
return url, json, headers
|
||||||
|
Loading…
Reference in New Issue
Block a user