@ -8,7 +8,7 @@ import inspect
import json
import json
import logging
import logging
from http import HTTPStatus
from http import HTTPStatus
from typing import Any , Dict , List , Optional
from typing import Any , Dict , List , Optional , Tuple
import requests # type: ignore
import requests # type: ignore
from langchain . chains . base import Chain
from langchain . chains . base import Chain
@ -37,6 +37,7 @@ from langchain_community.chains.pebblo_retrieval.utilities import (
CLASSIFIER_URL ,
CLASSIFIER_URL ,
PEBBLO_CLOUD_URL ,
PEBBLO_CLOUD_URL ,
PLUGIN_VERSION ,
PLUGIN_VERSION ,
PROMPT_GOV_URL ,
PROMPT_URL ,
PROMPT_URL ,
get_runtime ,
get_runtime ,
)
)
@ -79,6 +80,8 @@ class PebbloRetrievalQA(Chain):
""" Flag to check if discover payload has been sent. """
""" Flag to check if discover payload has been sent. """
_prompt_sent : bool = False #: :meta private:
_prompt_sent : bool = False #: :meta private:
""" Flag to check if prompt payload has been sent. """
""" Flag to check if prompt payload has been sent. """
enable_prompt_gov : bool = True #: :meta private:
""" Flag to check if prompt governance is enabled or not """
def _call (
def _call (
self ,
self ,
@ -102,6 +105,8 @@ class PebbloRetrievalQA(Chain):
question = inputs [ self . input_key ]
question = inputs [ self . input_key ]
auth_context = inputs . get ( self . auth_context_key , { } )
auth_context = inputs . get ( self . auth_context_key , { } )
semantic_context = inputs . get ( self . semantic_context_key , { } )
semantic_context = inputs . get ( self . semantic_context_key , { } )
_ , prompt_entities = self . _check_prompt_validity ( question )
accepts_run_manager = (
accepts_run_manager = (
" run_manager " in inspect . signature ( self . _get_docs ) . parameters
" run_manager " in inspect . signature ( self . _get_docs ) . parameters
)
)
@ -133,7 +138,12 @@ class PebbloRetrievalQA(Chain):
for doc in docs
for doc in docs
if isinstance ( doc , Document )
if isinstance ( doc , Document )
] ,
] ,
" prompt " : { " data " : question } ,
" prompt " : {
" data " : question ,
" entities " : prompt_entities . get ( " entities " , { } ) ,
" entityCount " : prompt_entities . get ( " entityCount " , 0 ) ,
" prompt_gov_enabled " : self . enable_prompt_gov ,
} ,
" response " : {
" response " : {
" data " : answer ,
" data " : answer ,
} ,
} ,
@ -144,6 +154,7 @@ class PebbloRetrievalQA(Chain):
else [ ] ,
else [ ] ,
" classifier_location " : self . classifier_location ,
" classifier_location " : self . classifier_location ,
}
}
qa_payload = Qa ( * * qa )
qa_payload = Qa ( * * qa )
self . _send_prompt ( qa_payload )
self . _send_prompt ( qa_payload )
@ -175,6 +186,9 @@ class PebbloRetrievalQA(Chain):
accepts_run_manager = (
accepts_run_manager = (
" run_manager " in inspect . signature ( self . _aget_docs ) . parameters
" run_manager " in inspect . signature ( self . _aget_docs ) . parameters
)
)
_ , prompt_entities = self . _check_prompt_validity ( question )
if accepts_run_manager :
if accepts_run_manager :
docs = await self . _aget_docs (
docs = await self . _aget_docs (
question , auth_context , semantic_context , run_manager = _run_manager
question , auth_context , semantic_context , run_manager = _run_manager
@ -513,6 +527,66 @@ class PebbloRetrievalQA(Chain):
logger . warning ( " API key is missing for sending prompt to Pebblo cloud. " )
logger . warning ( " API key is missing for sending prompt to Pebblo cloud. " )
raise NameError ( " API key is missing for sending prompt to Pebblo cloud. " )
raise NameError ( " API key is missing for sending prompt to Pebblo cloud. " )
def _check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
    """
    Ask the local Pebblo classifier which entities are present in a prompt.

    When ``self.classifier_location`` is ``"local"`` the prompt is POSTed to
    ``f"{self.classifier_url}{PROMPT_GOV_URL}"`` and the ``entities`` and
    ``entityCount`` fields of the JSON response are returned. For any other
    classifier location no request is made and empty defaults are returned.

    The call is best-effort: network errors, malformed responses and any
    other exception are logged as warnings and never propagated, so the
    caller can always proceed with the defaults.

    Args:
        question (str): The prompt question to be checked.

    Returns:
        Tuple[bool, Dict[str, Any]]:
            - bool: Validity flag. This implementation never rejects a
              prompt, so the value is always ``True``.
            - dict: ``{"entities": ..., "entityCount": ...}`` describing the
              entities found in the prompt; ``{}`` and ``0`` when nothing
              could be retrieved.
    """
    is_valid_prompt: bool = True
    prompt_entities: Dict[str, Any] = {"entities": {}, "entityCount": 0}

    # Prompt governance is only served by a locally running classifier.
    if self.classifier_location != "local":
        return is_valid_prompt, prompt_entities

    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    prompt_payload = {"prompt": question}
    prompt_gov_api_url = f"{self.classifier_url}{PROMPT_GOV_URL}"

    try:
        pebblo_resp = requests.post(
            prompt_gov_api_url,
            headers=headers,
            json=prompt_payload,
            timeout=20,
        )
        logger.debug("prompt-payload: %s", prompt_payload)

        # Parse the body exactly once. A non-JSON body is reported on its
        # own: requests' JSONDecodeError also derives from RequestException,
        # which would otherwise be logged as "unable to reach" the server.
        try:
            resp_json = pebblo_resp.json()
        except ValueError:
            logger.warning(
                "Pebblo server returned a non-JSON response (status %s).",
                pebblo_resp.status_code,
            )
            return is_valid_prompt, prompt_entities

        request_body = pebblo_resp.request.body
        logger.debug(
            "check_prompt_validity[local]: request url %s, body %s len %s "
            "response status %s body %s",
            pebblo_resp.request.url,
            str(request_body),
            len(request_body) if request_body else 0,
            pebblo_resp.status_code,
            resp_json,
        )

        # Only a JSON object can carry the expected fields; anything else
        # (list, string, null) leaves the defaults untouched.
        if isinstance(resp_json, dict):
            prompt_entities["entities"] = resp_json.get("entities", {})
            prompt_entities["entityCount"] = resp_json.get("entityCount", 0)
        else:
            logger.warning(
                "Unexpected prompt governance response type: %s",
                type(resp_json).__name__,
            )
    except requests.exceptions.RequestException:
        logger.warning("Unable to reach pebblo server.")
    except Exception as e:
        # Deliberate catch-all: entity detection must never break the chain.
        logger.warning(
            "An Exception caught in _check_prompt_validity: local %s", e
        )
    return is_valid_prompt, prompt_entities
@classmethod
@classmethod
def get_chain_details ( cls , llm : BaseLanguageModel , * * kwargs ) : # type: ignore
def get_chain_details ( cls , llm : BaseLanguageModel , * * kwargs ) : # type: ignore
llm_dict = llm . __dict__
llm_dict = llm . __dict__