@ -1,4 +1,6 @@
from typing import List , Optional
from __future__ import annotations
from typing import List
import aiohttp
import aiohttp
import requests
import requests
@ -10,7 +12,9 @@ from langchain.schema import BaseRetriever, Document
class ChatGPTPluginRetriever ( BaseRetriever , BaseModel ) :
class ChatGPTPluginRetriever ( BaseRetriever , BaseModel ) :
url : str
url : str
bearer_token : str
bearer_token : str
aiosession : Optional [ aiohttp . ClientSession ] = None
top_k : int = 3
filter : dict | None = None
aiosession : aiohttp . ClientSession | None = None
class Config :
class Config :
""" Configuration for this pydantic object. """
""" Configuration for this pydantic object. """
@ -18,14 +22,8 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
arbitrary_types_allowed = True
arbitrary_types_allowed = True
def get_relevant_documents ( self , query : str ) - > List [ Document ] :
def get_relevant_documents ( self , query : str ) - > List [ Document ] :
response = requests . post (
url , json , headers = self . _create_request ( query )
f " { self . url } /query " ,
response = requests . post ( url , json = json , headers = headers )
json = { " queries " : [ { " query " : query } ] } ,
headers = {
" Content-Type " : " application/json " ,
" Authorization " : f " Bearer { self . bearer_token } " ,
} ,
)
results = response . json ( ) [ " results " ] [ 0 ] [ " results " ]
results = response . json ( ) [ " results " ] [ 0 ] [ " results " ]
docs = [ ]
docs = [ ]
for d in results :
for d in results :
@ -34,12 +32,7 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
return docs
return docs
async def aget_relevant_documents ( self , query : str ) - > List [ Document ] :
async def aget_relevant_documents ( self , query : str ) - > List [ Document ] :
url = f " { self . url } /query "
url , json , headers = self . _create_request ( query )
json = { " queries " : [ { " query " : query } ] }
headers = {
" Content-Type " : " application/json " ,
" Authorization " : f " Bearer { self . bearer_token } " ,
}
if not self . aiosession :
if not self . aiosession :
async with aiohttp . ClientSession ( ) as session :
async with aiohttp . ClientSession ( ) as session :
@ -57,3 +50,20 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
content = d . pop ( " text " )
content = d . pop ( " text " )
docs . append ( Document ( page_content = content , metadata = d ) )
docs . append ( Document ( page_content = content , metadata = d ) )
return docs
return docs
def _create_request ( self , query : str ) - > tuple [ str , dict , dict ] :
url = f " { self . url } /query "
json = {
" queries " : [
{
" query " : query ,
" filter " : self . filter ,
" top_k " : self . top_k ,
}
]
}
headers = {
" Content-Type " : " application/json " ,
" Authorization " : f " Bearer { self . bearer_token } " ,
}
return url , json , headers