[Community][minor]: Added prompt governance in pebblo_retrieval (#24874)

Title: [pebblo_retrieval] Identifying entities in prompts given in
PebbloRetrievalQA leading to prompt governance
Description: Implemented identification of entities in the prompt using
Pebblo prompt governance API.
Issue: NA
Dependencies: NA
Add tests and docs: NA
pull/24881/head
Nishan Jain 2 months ago committed by GitHub
parent a6add89bd4
commit b00c0fc558
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -8,7 +8,7 @@ import inspect
import json
import logging
from http import HTTPStatus
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
import requests # type: ignore
from langchain.chains.base import Chain
@ -37,6 +37,7 @@ from langchain_community.chains.pebblo_retrieval.utilities import (
CLASSIFIER_URL,
PEBBLO_CLOUD_URL,
PLUGIN_VERSION,
PROMPT_GOV_URL,
PROMPT_URL,
get_runtime,
)
@ -79,6 +80,8 @@ class PebbloRetrievalQA(Chain):
"""Flag to check if discover payload has been sent."""
_prompt_sent: bool = False #: :meta private:
"""Flag to check if prompt payload has been sent."""
enable_prompt_gov: bool = True #: :meta private:
"""Flag to check if prompt governance is enabled or not"""
def _call(
self,
@ -102,6 +105,8 @@ class PebbloRetrievalQA(Chain):
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key, {})
semantic_context = inputs.get(self.semantic_context_key, {})
_, prompt_entities = self._check_prompt_validity(question)
accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
)
@ -133,7 +138,12 @@ class PebbloRetrievalQA(Chain):
for doc in docs
if isinstance(doc, Document)
],
"prompt": {"data": question},
"prompt": {
"data": question,
"entities": prompt_entities.get("entities", {}),
"entityCount": prompt_entities.get("entityCount", 0),
"prompt_gov_enabled": self.enable_prompt_gov,
},
"response": {
"data": answer,
},
@ -144,6 +154,7 @@ class PebbloRetrievalQA(Chain):
else [],
"classifier_location": self.classifier_location,
}
qa_payload = Qa(**qa)
self._send_prompt(qa_payload)
@ -175,6 +186,9 @@ class PebbloRetrievalQA(Chain):
accepts_run_manager = (
"run_manager" in inspect.signature(self._aget_docs).parameters
)
_, prompt_entities = self._check_prompt_validity(question)
if accepts_run_manager:
docs = await self._aget_docs(
question, auth_context, semantic_context, run_manager=_run_manager
@ -513,6 +527,66 @@ class PebbloRetrievalQA(Chain):
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
raise NameError("API key is missing for sending prompt to Pebblo cloud.")
def _check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
"""
Check the validity of the given prompt using a remote classification service.
This method sends a prompt to a remote classifier service and return entities
present in prompt or not.
Args:
question (str): The prompt question to be validated.
Returns:
bool: True if the prompt is valid (does not contain deny list entities),
False otherwise.
dict: The entities present in the prompt
"""
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
prompt_payload = {"prompt": question}
is_valid_prompt: bool = True
prompt_gov_api_url = f"{self.classifier_url}{PROMPT_GOV_URL}"
pebblo_resp = None
prompt_entities: dict = {"entities": {}, "entityCount": 0}
if self.classifier_location == "local":
try:
pebblo_resp = requests.post(
prompt_gov_api_url,
headers=headers,
json=prompt_payload,
timeout=20,
)
logger.debug("prompt-payload: %s", prompt_payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
logger.debug(f"pebblo_resp.json() {pebblo_resp.json()}")
prompt_entities["entities"] = pebblo_resp.json().get("entities", {})
prompt_entities["entityCount"] = pebblo_resp.json().get(
"entityCount", 0
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
return is_valid_prompt, prompt_entities
@classmethod
def get_chain_details(cls, llm: BaseLanguageModel, **kwargs): # type: ignore
llm_dict = llm.__dict__

@ -133,7 +133,10 @@ class Context(BaseModel):
class Prompt(BaseModel):
data: str
data: Optional[Union[list, str]]
entityCount: Optional[int]
entities: Optional[dict]
prompt_gov_enabled: Optional[bool]
class Qa(BaseModel):

@ -15,6 +15,7 @@ CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000")
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")
PROMPT_URL = "/v1/prompt"
PROMPT_GOV_URL = "/v1/prompt/governance"
APP_DISCOVER_URL = "/v1/app/discover"

Loading…
Cancel
Save