mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
[Community][minor]: Added prompt governance in pebblo_retrieval (#24874)
Title: [pebblo_retrieval] Identify entities in prompts passed to PebbloRetrievalQA to support prompt governance.
Description: Implemented identification of entities in the prompt using the Pebblo prompt governance API.
Issue: N/A
Dependencies: N/A
Add tests and docs: N/A
This commit is contained in:
parent
a6add89bd4
commit
b00c0fc558
@ -8,7 +8,7 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import requests # type: ignore
|
import requests # type: ignore
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
@ -37,6 +37,7 @@ from langchain_community.chains.pebblo_retrieval.utilities import (
|
|||||||
CLASSIFIER_URL,
|
CLASSIFIER_URL,
|
||||||
PEBBLO_CLOUD_URL,
|
PEBBLO_CLOUD_URL,
|
||||||
PLUGIN_VERSION,
|
PLUGIN_VERSION,
|
||||||
|
PROMPT_GOV_URL,
|
||||||
PROMPT_URL,
|
PROMPT_URL,
|
||||||
get_runtime,
|
get_runtime,
|
||||||
)
|
)
|
||||||
@ -79,6 +80,8 @@ class PebbloRetrievalQA(Chain):
|
|||||||
"""Flag to check if discover payload has been sent."""
|
"""Flag to check if discover payload has been sent."""
|
||||||
_prompt_sent: bool = False #: :meta private:
|
_prompt_sent: bool = False #: :meta private:
|
||||||
"""Flag to check if prompt payload has been sent."""
|
"""Flag to check if prompt payload has been sent."""
|
||||||
|
enable_prompt_gov: bool = True #: :meta private:
|
||||||
|
"""Flag to check if prompt governance is enabled or not"""
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
@ -102,6 +105,8 @@ class PebbloRetrievalQA(Chain):
|
|||||||
question = inputs[self.input_key]
|
question = inputs[self.input_key]
|
||||||
auth_context = inputs.get(self.auth_context_key, {})
|
auth_context = inputs.get(self.auth_context_key, {})
|
||||||
semantic_context = inputs.get(self.semantic_context_key, {})
|
semantic_context = inputs.get(self.semantic_context_key, {})
|
||||||
|
_, prompt_entities = self._check_prompt_validity(question)
|
||||||
|
|
||||||
accepts_run_manager = (
|
accepts_run_manager = (
|
||||||
"run_manager" in inspect.signature(self._get_docs).parameters
|
"run_manager" in inspect.signature(self._get_docs).parameters
|
||||||
)
|
)
|
||||||
@ -133,7 +138,12 @@ class PebbloRetrievalQA(Chain):
|
|||||||
for doc in docs
|
for doc in docs
|
||||||
if isinstance(doc, Document)
|
if isinstance(doc, Document)
|
||||||
],
|
],
|
||||||
"prompt": {"data": question},
|
"prompt": {
|
||||||
|
"data": question,
|
||||||
|
"entities": prompt_entities.get("entities", {}),
|
||||||
|
"entityCount": prompt_entities.get("entityCount", 0),
|
||||||
|
"prompt_gov_enabled": self.enable_prompt_gov,
|
||||||
|
},
|
||||||
"response": {
|
"response": {
|
||||||
"data": answer,
|
"data": answer,
|
||||||
},
|
},
|
||||||
@ -144,6 +154,7 @@ class PebbloRetrievalQA(Chain):
|
|||||||
else [],
|
else [],
|
||||||
"classifier_location": self.classifier_location,
|
"classifier_location": self.classifier_location,
|
||||||
}
|
}
|
||||||
|
|
||||||
qa_payload = Qa(**qa)
|
qa_payload = Qa(**qa)
|
||||||
self._send_prompt(qa_payload)
|
self._send_prompt(qa_payload)
|
||||||
|
|
||||||
@ -175,6 +186,9 @@ class PebbloRetrievalQA(Chain):
|
|||||||
accepts_run_manager = (
|
accepts_run_manager = (
|
||||||
"run_manager" in inspect.signature(self._aget_docs).parameters
|
"run_manager" in inspect.signature(self._aget_docs).parameters
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_, prompt_entities = self._check_prompt_validity(question)
|
||||||
|
|
||||||
if accepts_run_manager:
|
if accepts_run_manager:
|
||||||
docs = await self._aget_docs(
|
docs = await self._aget_docs(
|
||||||
question, auth_context, semantic_context, run_manager=_run_manager
|
question, auth_context, semantic_context, run_manager=_run_manager
|
||||||
@ -513,6 +527,66 @@ class PebbloRetrievalQA(Chain):
|
|||||||
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
|
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
|
||||||
raise NameError("API key is missing for sending prompt to Pebblo cloud.")
|
raise NameError("API key is missing for sending prompt to Pebblo cloud.")
|
||||||
|
|
||||||
|
def _check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
    """
    Look up the entities present in a prompt via the prompt governance API.

    When ``self.classifier_location`` is ``"local"``, the prompt is POSTed to
    ``{self.classifier_url}{PROMPT_GOV_URL}`` and the ``entities`` and
    ``entityCount`` fields of the JSON response are returned. For any other
    classifier location no request is made and the empty defaults are returned.

    The call is best-effort: request failures and unexpected response shapes
    are logged as warnings and the empty defaults are returned instead of
    raising.

    Args:
        question (str): The prompt text to inspect.

    Returns:
        Tuple[bool, Dict[str, Any]]:
            - bool: Validity flag. This method never sets it to anything
              other than ``True``; no deny-list check is performed here.
            - dict: ``{"entities": ..., "entityCount": ...}`` as reported by
              the classifier, or ``{"entities": {}, "entityCount": 0}`` when
              nothing could be retrieved.
    """
    is_valid_prompt: bool = True
    prompt_entities: Dict[str, Any] = {"entities": {}, "entityCount": 0}

    # Only a locally hosted classifier is queried; skip all request setup
    # for any other location.
    if self.classifier_location != "local":
        return is_valid_prompt, prompt_entities

    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    prompt_payload = {"prompt": question}
    prompt_gov_api_url = f"{self.classifier_url}{PROMPT_GOV_URL}"

    try:
        pebblo_resp = requests.post(
            prompt_gov_api_url,
            headers=headers,
            json=prompt_payload,
            timeout=20,
        )
        logger.debug("prompt-payload: %s", prompt_payload)

        # Parse the response body once and reuse it for logging and for
        # extracting the result fields.
        resp_json = pebblo_resp.json()
        request_body = pebblo_resp.request.body
        logger.debug(
            "check_prompt_validity[local]: request url %s, body %s len %s "
            "response status %s body %s",
            pebblo_resp.request.url,
            str(request_body),
            str(len(request_body if request_body else [])),
            str(pebblo_resp.status_code),
            resp_json,
        )
        logger.debug("pebblo_resp.json() %s", resp_json)

        # Build the result in one step so a malformed response can never
        # leave it half-updated.
        prompt_entities = {
            "entities": resp_json.get("entities", {}),
            "entityCount": resp_json.get("entityCount", 0),
        }
    except requests.exceptions.RequestException:
        logger.warning("Unable to reach pebblo server.")
    except Exception as e:
        logger.warning(
            "An Exception caught in _check_prompt_validity: local %s", e
        )

    return is_valid_prompt, prompt_entities
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_chain_details(cls, llm: BaseLanguageModel, **kwargs): # type: ignore
|
def get_chain_details(cls, llm: BaseLanguageModel, **kwargs): # type: ignore
|
||||||
llm_dict = llm.__dict__
|
llm_dict = llm.__dict__
|
||||||
|
@ -133,7 +133,10 @@ class Context(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Prompt(BaseModel):
|
class Prompt(BaseModel):
|
||||||
data: str
|
data: Optional[Union[list, str]]
|
||||||
|
entityCount: Optional[int]
|
||||||
|
entities: Optional[dict]
|
||||||
|
prompt_gov_enabled: Optional[bool]
|
||||||
|
|
||||||
|
|
||||||
class Qa(BaseModel):
|
class Qa(BaseModel):
|
||||||
|
@ -15,6 +15,7 @@ CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000")
|
|||||||
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")
|
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")
|
||||||
|
|
||||||
PROMPT_URL = "/v1/prompt"
|
PROMPT_URL = "/v1/prompt"
|
||||||
|
PROMPT_GOV_URL = "/v1/prompt/governance"
|
||||||
APP_DISCOVER_URL = "/v1/app/discover"
|
APP_DISCOVER_URL = "/v1/app/discover"
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user