diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/base.py b/libs/community/langchain_community/chains/pebblo_retrieval/base.py
index 02d3553c4b..5dcc15775d 100644
--- a/libs/community/langchain_community/chains/pebblo_retrieval/base.py
+++ b/libs/community/langchain_community/chains/pebblo_retrieval/base.py
@@ -8,7 +8,7 @@ import inspect
 import json
 import logging
 from http import HTTPStatus
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 import requests  # type: ignore
 from langchain.chains.base import Chain
@@ -37,6 +37,7 @@ from langchain_community.chains.pebblo_retrieval.utilities import (
     CLASSIFIER_URL,
     PEBBLO_CLOUD_URL,
     PLUGIN_VERSION,
+    PROMPT_GOV_URL,
     PROMPT_URL,
     get_runtime,
 )
@@ -79,6 +80,8 @@ class PebbloRetrievalQA(Chain):
     """Flag to check if discover payload has been sent."""
     _prompt_sent: bool = False  #: :meta private:
     """Flag to check if prompt payload has been sent."""
+    enable_prompt_gov: bool = True  #: :meta private:
+    """Flag to check if prompt governance is enabled or not"""
 
     def _call(
         self,
@@ -102,6 +105,8 @@ class PebbloRetrievalQA(Chain):
         question = inputs[self.input_key]
         auth_context = inputs.get(self.auth_context_key, {})
         semantic_context = inputs.get(self.semantic_context_key, {})
+        _, prompt_entities = self._check_prompt_validity(question)
+
         accepts_run_manager = (
             "run_manager" in inspect.signature(self._get_docs).parameters
         )
@@ -133,7 +138,12 @@ class PebbloRetrievalQA(Chain):
                 for doc in docs
                 if isinstance(doc, Document)
             ],
-            "prompt": {"data": question},
+            "prompt": {
+                "data": question,
+                "entities": prompt_entities.get("entities", {}),
+                "entityCount": prompt_entities.get("entityCount", 0),
+                "prompt_gov_enabled": self.enable_prompt_gov,
+            },
             "response": {
                 "data": answer,
             },
@@ -144,6 +154,7 @@ class PebbloRetrievalQA(Chain):
             else [],
             "classifier_location": self.classifier_location,
         }
+
         qa_payload = Qa(**qa)
         self._send_prompt(qa_payload)
 
@@ -175,6 +186,9 @@ class PebbloRetrievalQA(Chain):
         accepts_run_manager = (
             "run_manager" in inspect.signature(self._aget_docs).parameters
         )
+
+        _, prompt_entities = self._check_prompt_validity(question)
+
         if accepts_run_manager:
             docs = await self._aget_docs(
                 question, auth_context, semantic_context, run_manager=_run_manager
@@ -513,6 +527,75 @@ class PebbloRetrievalQA(Chain):
             logger.warning("API key is missing for sending prompt to Pebblo cloud.")
             raise NameError("API key is missing for sending prompt to Pebblo cloud.")
 
+    def _check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
+        """
+        Check the given prompt for entities using the Pebblo classifier service.
+
+        When `classifier_location` is "local", the prompt is posted to the
+        classifier's prompt governance endpoint and the entities reported in the
+        response are collected. For any other location no request is made. The
+        check is best-effort: an unreachable server, an error status or a malformed
+        response is logged and the empty defaults are returned instead of raising.
+
+        Args:
+            question (str): The prompt question to be validated.
+
+        Returns:
+            bool: Validity flag of the prompt. It is currently always True,
+                because no deny list is evaluated by this method.
+            dict: The entities present in the prompt, under the keys "entities"
+                (default {}) and "entityCount" (default 0).
+        """
+        is_valid_prompt: bool = True
+        prompt_entities: Dict[str, Any] = {"entities": {}, "entityCount": 0}
+        if self.classifier_location != "local":
+            return is_valid_prompt, prompt_entities
+
+        headers = {
+            "Accept": "application/json",
+            "Content-Type": "application/json",
+        }
+        prompt_payload = {"prompt": question}
+        prompt_gov_api_url = f"{self.classifier_url}{PROMPT_GOV_URL}"
+        try:
+            pebblo_resp = requests.post(
+                prompt_gov_api_url,
+                headers=headers,
+                json=prompt_payload,
+                timeout=20,
+            )
+        except requests.exceptions.RequestException:
+            logger.warning("Unable to reach pebblo server.")
+            return is_valid_prompt, prompt_entities
+
+        try:
+            logger.debug("prompt-payload: %s", prompt_payload)
+            logger.debug(
+                "check_prompt_validity[local]: request url %s, body %s len %s "
+                "response status %s body %s",
+                pebblo_resp.request.url,
+                str(pebblo_resp.request.body),
+                str(len(pebblo_resp.request.body or [])),
+                str(pebblo_resp.status_code),
+                # Log the raw text so a non-JSON body cannot break the logging.
+                pebblo_resp.text,
+            )
+            if not pebblo_resp.ok:
+                logger.warning(
+                    "Prompt governance request failed with status %s.",
+                    pebblo_resp.status_code,
+                )
+                return is_valid_prompt, prompt_entities
+            # Decode the body once instead of re-parsing it for every field.
+            resp_json = pebblo_resp.json()
+            prompt_entities["entities"] = resp_json.get("entities", {})
+            prompt_entities["entityCount"] = resp_json.get("entityCount", 0)
+        except Exception as e:
+            logger.warning(
+                "An Exception caught in _check_prompt_validity: local %s", e
+            )
+        return is_valid_prompt, prompt_entities
+
     @classmethod
     def get_chain_details(cls, llm: BaseLanguageModel, **kwargs):  # type: ignore
         llm_dict = llm.__dict__
diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/models.py b/libs/community/langchain_community/chains/pebblo_retrieval/models.py
index 13a54537b9..e4fd7c6496 100644
--- a/libs/community/langchain_community/chains/pebblo_retrieval/models.py
+++ b/libs/community/langchain_community/chains/pebblo_retrieval/models.py
@@ -133,7 +133,10 @@ class Context(BaseModel):
 
 
 class Prompt(BaseModel):
-    data: str
+    data: Optional[Union[list, str]]
+    entityCount: Optional[int]
+    entities: Optional[dict]
+    prompt_gov_enabled: Optional[bool]
 
 
 class Qa(BaseModel):
diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/utilities.py b/libs/community/langchain_community/chains/pebblo_retrieval/utilities.py
index 3056c8fae7..86218ad07b 100644
--- a/libs/community/langchain_community/chains/pebblo_retrieval/utilities.py
+++ b/libs/community/langchain_community/chains/pebblo_retrieval/utilities.py
@@ -15,6 +15,7 @@ CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000")
 PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")
 
 PROMPT_URL = "/v1/prompt"
+PROMPT_GOV_URL = "/v1/prompt/governance"
 APP_DISCOVER_URL = "/v1/app/discover"
 
 