[Community][minor]: Added prompt governance in pebblo_retrieval (#24874)

Title: [pebblo_retrieval] Identifying entities in prompts given in
PebbloRetrievalQA leading to prompt governance
Description: Implemented identification of entities in the prompt using
Pebblo prompt governance API.
Issue: NA
Dependencies: NA
Add tests and docs: NA
This commit is contained in:
Nishan Jain 2024-07-31 18:44:51 +05:30 committed by GitHub
parent a6add89bd4
commit b00c0fc558
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 81 additions and 3 deletions

View File

@ -8,7 +8,7 @@ import inspect
import json import json
import logging import logging
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Tuple
import requests # type: ignore import requests # type: ignore
from langchain.chains.base import Chain from langchain.chains.base import Chain
@ -37,6 +37,7 @@ from langchain_community.chains.pebblo_retrieval.utilities import (
CLASSIFIER_URL, CLASSIFIER_URL,
PEBBLO_CLOUD_URL, PEBBLO_CLOUD_URL,
PLUGIN_VERSION, PLUGIN_VERSION,
PROMPT_GOV_URL,
PROMPT_URL, PROMPT_URL,
get_runtime, get_runtime,
) )
@ -79,6 +80,8 @@ class PebbloRetrievalQA(Chain):
"""Flag to check if discover payload has been sent.""" """Flag to check if discover payload has been sent."""
_prompt_sent: bool = False #: :meta private: _prompt_sent: bool = False #: :meta private:
"""Flag to check if prompt payload has been sent.""" """Flag to check if prompt payload has been sent."""
enable_prompt_gov: bool = True #: :meta private:
"""Flag to check if prompt governance is enabled or not"""
def _call( def _call(
self, self,
@ -102,6 +105,8 @@ class PebbloRetrievalQA(Chain):
question = inputs[self.input_key] question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key, {}) auth_context = inputs.get(self.auth_context_key, {})
semantic_context = inputs.get(self.semantic_context_key, {}) semantic_context = inputs.get(self.semantic_context_key, {})
_, prompt_entities = self._check_prompt_validity(question)
accepts_run_manager = ( accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters "run_manager" in inspect.signature(self._get_docs).parameters
) )
@ -133,7 +138,12 @@ class PebbloRetrievalQA(Chain):
for doc in docs for doc in docs
if isinstance(doc, Document) if isinstance(doc, Document)
], ],
"prompt": {"data": question}, "prompt": {
"data": question,
"entities": prompt_entities.get("entities", {}),
"entityCount": prompt_entities.get("entityCount", 0),
"prompt_gov_enabled": self.enable_prompt_gov,
},
"response": { "response": {
"data": answer, "data": answer,
}, },
@ -144,6 +154,7 @@ class PebbloRetrievalQA(Chain):
else [], else [],
"classifier_location": self.classifier_location, "classifier_location": self.classifier_location,
} }
qa_payload = Qa(**qa) qa_payload = Qa(**qa)
self._send_prompt(qa_payload) self._send_prompt(qa_payload)
@ -175,6 +186,9 @@ class PebbloRetrievalQA(Chain):
accepts_run_manager = ( accepts_run_manager = (
"run_manager" in inspect.signature(self._aget_docs).parameters "run_manager" in inspect.signature(self._aget_docs).parameters
) )
_, prompt_entities = self._check_prompt_validity(question)
if accepts_run_manager: if accepts_run_manager:
docs = await self._aget_docs( docs = await self._aget_docs(
question, auth_context, semantic_context, run_manager=_run_manager question, auth_context, semantic_context, run_manager=_run_manager
@ -513,6 +527,66 @@ class PebbloRetrievalQA(Chain):
logger.warning("API key is missing for sending prompt to Pebblo cloud.") logger.warning("API key is missing for sending prompt to Pebblo cloud.")
raise NameError("API key is missing for sending prompt to Pebblo cloud.") raise NameError("API key is missing for sending prompt to Pebblo cloud.")
def _check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
"""
Check the validity of the given prompt using a remote classification service.
This method sends a prompt to a remote classifier service and return entities
present in prompt or not.
Args:
question (str): The prompt question to be validated.
Returns:
bool: True if the prompt is valid (does not contain deny list entities),
False otherwise.
dict: The entities present in the prompt
"""
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
prompt_payload = {"prompt": question}
is_valid_prompt: bool = True
prompt_gov_api_url = f"{self.classifier_url}{PROMPT_GOV_URL}"
pebblo_resp = None
prompt_entities: dict = {"entities": {}, "entityCount": 0}
if self.classifier_location == "local":
try:
pebblo_resp = requests.post(
prompt_gov_api_url,
headers=headers,
json=prompt_payload,
timeout=20,
)
logger.debug("prompt-payload: %s", prompt_payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
logger.debug(f"pebblo_resp.json() {pebblo_resp.json()}")
prompt_entities["entities"] = pebblo_resp.json().get("entities", {})
prompt_entities["entityCount"] = pebblo_resp.json().get(
"entityCount", 0
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
return is_valid_prompt, prompt_entities
@classmethod @classmethod
def get_chain_details(cls, llm: BaseLanguageModel, **kwargs): # type: ignore def get_chain_details(cls, llm: BaseLanguageModel, **kwargs): # type: ignore
llm_dict = llm.__dict__ llm_dict = llm.__dict__

View File

@ -133,7 +133,10 @@ class Context(BaseModel):
class Prompt(BaseModel): class Prompt(BaseModel):
data: str data: Optional[Union[list, str]]
entityCount: Optional[int]
entities: Optional[dict]
prompt_gov_enabled: Optional[bool]
class Qa(BaseModel): class Qa(BaseModel):

View File

@ -15,6 +15,7 @@ CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000")
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai") PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")
PROMPT_URL = "/v1/prompt" PROMPT_URL = "/v1/prompt"
PROMPT_GOV_URL = "/v1/prompt/governance"
APP_DISCOVER_URL = "/v1/app/discover" APP_DISCOVER_URL = "/v1/app/discover"