community: Refactor PebbloRetrievalQA (#25583)

**Refactor PebbloRetrievalQA**
  - Created `APIWrapper` and moved API logic into it.
  - Created smaller functions/methods for better readability.
  - Properly read environment variables.
  - Removed unused code.
  - Updated models: renamed `Chains` to `ChainInfo` and gave the optional `Prompt` fields a `None` default.

**Issue:** NA
**Dependencies:** NA
**tests**:  NA
This commit is contained in:
Rajendra Kadam 2024-08-22 21:21:21 +05:30 committed by GitHub
parent 1f1679e960
commit 4ff2f4499e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 391 additions and 317 deletions

View File

@ -5,12 +5,9 @@ against a vector database.
import datetime
import inspect
import json
import logging
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional
import requests # type: ignore
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain_core.callbacks import (
@ -29,16 +26,14 @@ from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
from langchain_community.chains.pebblo_retrieval.models import (
App,
AuthContext,
Qa,
ChainInfo,
Model,
SemanticContext,
VectorDB,
)
from langchain_community.chains.pebblo_retrieval.utilities import (
APP_DISCOVER_URL,
CLASSIFIER_URL,
PEBBLO_CLOUD_URL,
PLUGIN_VERSION,
PROMPT_GOV_URL,
PROMPT_URL,
PebbloRetrievalAPIWrapper,
get_runtime,
)
@ -72,16 +67,18 @@ class PebbloRetrievalQA(Chain):
"""Description of app."""
api_key: Optional[str] = None #: :meta private:
"""Pebblo cloud API key for app."""
classifier_url: str = CLASSIFIER_URL #: :meta private:
classifier_url: Optional[str] = None #: :meta private:
"""Classifier endpoint."""
classifier_location: str = "local" #: :meta private:
"""Classifier location. It could be either of 'local' or 'pebblo-cloud'."""
_discover_sent: bool = False #: :meta private:
"""Flag to check if discover payload has been sent."""
_prompt_sent: bool = False #: :meta private:
"""Flag to check if prompt payload has been sent."""
enable_prompt_gov: bool = True #: :meta private:
"""Flag to check if prompt governance is enabled or not"""
pb_client: PebbloRetrievalAPIWrapper = Field(
default_factory=PebbloRetrievalAPIWrapper
)
"""Pebblo Retrieval API client"""
def _call(
self,
@ -100,12 +97,11 @@ class PebbloRetrievalQA(Chain):
answer, docs = res['result'], res['source_documents']
"""
prompt_time = datetime.datetime.now().isoformat()
PebbloRetrievalQA.set_prompt_sent(value=False)
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key, {})
semantic_context = inputs.get(self.semantic_context_key, {})
_, prompt_entities = self._check_prompt_validity(question)
auth_context = inputs.get(self.auth_context_key)
semantic_context = inputs.get(self.semantic_context_key)
_, prompt_entities = self.pb_client.check_prompt_validity(question)
accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
@ -120,43 +116,17 @@ class PebbloRetrievalQA(Chain):
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
qa = {
"name": self.app_name,
"context": [
{
"retrieved_from": doc.metadata.get(
"full_path", doc.metadata.get("source")
),
"doc": doc.page_content,
"vector_db": self.retriever.vectorstore.__class__.__name__,
**(
{"pb_checksum": doc.metadata.get("pb_checksum")}
if doc.metadata.get("pb_checksum")
else {}
),
}
for doc in docs
if isinstance(doc, Document)
],
"prompt": {
"data": question,
"entities": prompt_entities.get("entities", {}),
"entityCount": prompt_entities.get("entityCount", 0),
"prompt_gov_enabled": self.enable_prompt_gov,
},
"response": {
"data": answer,
},
"prompt_time": prompt_time,
"user": auth_context.user_id if auth_context else "unknown",
"user_identities": auth_context.user_auth
if auth_context and hasattr(auth_context, "user_auth")
else [],
"classifier_location": self.classifier_location,
}
qa_payload = Qa(**qa)
self._send_prompt(qa_payload)
self.pb_client.send_prompt(
self.app_name,
self.retriever,
question,
answer,
auth_context,
docs,
prompt_entities,
prompt_time,
self.enable_prompt_gov,
)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
@ -187,7 +157,7 @@ class PebbloRetrievalQA(Chain):
"run_manager" in inspect.signature(self._aget_docs).parameters
)
_, prompt_entities = self._check_prompt_validity(question)
_, prompt_entities = self.pb_client.check_prompt_validity(question)
if accepts_run_manager:
docs = await self._aget_docs(
@ -243,7 +213,7 @@ class PebbloRetrievalQA(Chain):
chain_type: str = "stuff",
chain_type_kwargs: Optional[dict] = None,
api_key: Optional[str] = None,
classifier_url: str = CLASSIFIER_URL,
classifier_url: Optional[str] = None,
classifier_location: str = "local",
**kwargs: Any,
) -> "PebbloRetrievalQA":
@ -263,14 +233,14 @@ class PebbloRetrievalQA(Chain):
llm=llm,
**kwargs,
)
PebbloRetrievalQA._send_discover(
app,
# initialize Pebblo API client
pb_client = PebbloRetrievalAPIWrapper(
api_key=api_key,
classifier_url=classifier_url,
classifier_location=classifier_location,
classifier_url=classifier_url,
)
# send app discovery request
pb_client.send_app_discover(app)
return cls(
combine_documents_chain=combine_documents_chain,
app_name=app_name,
@ -279,6 +249,7 @@ class PebbloRetrievalQA(Chain):
api_key=api_key,
classifier_url=classifier_url,
classifier_location=classifier_location,
pb_client=pb_client,
**kwargs,
)
@ -346,259 +317,36 @@ class PebbloRetrievalQA(Chain):
)
return app
@staticmethod
def _send_discover(
app: App,
api_key: Optional[str],
classifier_url: str,
classifier_location: str,
) -> None: # type: ignore
"""Send app discovery payload to pebblo-server. Internal method.

The payload is posted to the classifier when ``classifier_location`` is
"local" and, independently, to Pebblo cloud when ``api_key`` is truthy.
Exceptions raised inside either request block are logged as warnings and
are not re-raised.

Args:
app: App model; serialized with ``app.dict(exclude_unset=True)``.
api_key: Pebblo cloud API key; the cloud request is skipped when falsy.
classifier_url: Base URL that ``APP_DISCOVER_URL`` is appended to.
classifier_location: Only the value "local" triggers the classifier request.
"""
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
payload = app.dict(exclude_unset=True)
if classifier_location == "local":
app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=payload, timeout=20
)
logger.debug("discover-payload: %s", payload)
# NOTE(review): pebblo_resp.json() is evaluated as an argument even when
# debug logging is off; presumably a non-JSON body raises here and falls
# into the generic handler below, skipping the status check - confirm.
logger.debug(
"send_discover[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
# Both 200 and 502 mark the discover payload as sent.
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_discover_sent()
else:
logger.warning(
"Received unexpected HTTP response code:"
+ f"{pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
# The cloud request reuses the same headers dict (plus the API key) and the
# same payload as the local request.
if api_key:
try:
headers.update({"x-api-key": api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}"
pebblo_cloud_response = requests.post(
pebblo_cloud_url, headers=headers, json=payload, timeout=20
)
logger.debug(
"send_discover[cloud]: request url %s, body %s len %s\
response status %s body %s",
pebblo_cloud_response.request.url,
str(pebblo_cloud_response.request.body),
str(
len(
pebblo_cloud_response.request.body
if pebblo_cloud_response.request.body
else []
)
),
str(pebblo_cloud_response.status_code),
pebblo_cloud_response.json(),
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach Pebblo cloud server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: cloud %s", e)
@classmethod
def set_discover_sent(cls) -> None:
"""Set the class-level ``_discover_sent`` flag to True."""
cls._discover_sent = True
@classmethod
def set_prompt_sent(cls, value: bool = True) -> None:
"""Set the class-level ``_prompt_sent`` flag to ``value`` (default True)."""
cls._prompt_sent = value
def _send_prompt(self, qa_payload: Qa) -> None:
"""Send the prompt/response payload to the classifier and to Pebblo cloud.

The payload is posted to ``self.classifier_url`` when
``self.classifier_location`` is "local". When ``self.api_key`` is set it is
then posted to Pebblo cloud; in the local case the payload is first reduced
using the classifier response (see inline comments).

Args:
qa_payload: Qa model; serialized with ``dict(exclude_unset=True)``.

Raises:
NameError: if no API key is set and ``classifier_location`` is
"pebblo-cloud".
"""
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
# NOTE(review): despite its name, this variable holds the prompt URL.
app_discover_url = f"{self.classifier_url}{PROMPT_URL}"
pebblo_resp = None
payload = qa_payload.dict(exclude_unset=True)
if self.classifier_location == "local":
try:
pebblo_resp = requests.post(
app_discover_url,
headers=headers,
json=payload,
timeout=20,
)
logger.debug("prompt-payload: %s", payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
# Both 200 and 502 mark the prompt as sent.
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_prompt_sent()
else:
logger.warning(
"Received unexpected HTTP response code:"
+ f"{pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
# NOTE(review): the message names _send_discover although this is
# _send_prompt.
logger.warning("An Exception caught in _send_discover: local %s", e)
# If classifier location is local, then response, context and prompt
# should be fetched from pebblo_resp and replaced in payload.
if self.api_key:
if self.classifier_location == "local":
# NOTE(review): presumably a requests Response is falsy for error
# statuses, so those would also take the else branch - confirm.
if pebblo_resp:
# json.loads is outside any try block here, so a parse failure
# propagates to the caller.
resp = json.loads(pebblo_resp.text)
if resp:
# Merge the classifier's "retrieval_data" into the payload, then
# drop the "data" and "doc" entries before the cloud request.
payload["response"].update(
resp.get("retrieval_data", {}).get("response", {})
)
payload["response"].pop("data")
payload["prompt"].update(
resp.get("retrieval_data", {}).get("prompt", {})
)
payload["prompt"].pop("data")
context = payload["context"]
for context_data in context:
context_data.pop("doc")
payload["context"] = context
else:
# No usable classifier response: send empty response/prompt/context.
payload["response"] = {}
payload["prompt"] = {}
payload["context"] = []
headers.update({"x-api-key": self.api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
try:
pebblo_cloud_response = requests.post(
pebblo_cloud_url,
headers=headers,
json=payload,
timeout=20,
)
logger.debug(
"send_prompt[cloud]: request url %s, body %s len %s\
response status %s body %s",
pebblo_cloud_response.request.url,
str(pebblo_cloud_response.request.body),
str(
len(
pebblo_cloud_response.request.body
if pebblo_cloud_response.request.body
else []
)
),
str(pebblo_cloud_response.status_code),
pebblo_cloud_response.json(),
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach Pebblo cloud server.")
except Exception as e:
logger.warning("An Exception caught in _send_prompt: cloud %s", e)
elif self.classifier_location == "pebblo-cloud":
# Cloud-only classification cannot proceed without an API key.
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
raise NameError("API key is missing for sending prompt to Pebblo cloud.")
def _check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
def get_chain_details(
cls, llm: BaseLanguageModel, **kwargs: Any
) -> List[ChainInfo]:
"""
Check the validity of the given prompt using a remote classification service.
This method sends a prompt to a remote classifier service and return entities
present in prompt or not.
Get chain details.
Args:
question (str): The prompt question to be validated.
llm (BaseLanguageModel): Language model instance.
**kwargs: Additional keyword arguments.
Returns:
bool: True if the prompt is valid (does not contain deny list entities),
False otherwise.
dict: The entities present in the prompt
List[ChainInfo]: Chain details.
"""
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
prompt_payload = {"prompt": question}
is_valid_prompt: bool = True
prompt_gov_api_url = f"{self.classifier_url}{PROMPT_GOV_URL}"
pebblo_resp = None
prompt_entities: dict = {"entities": {}, "entityCount": 0}
if self.classifier_location == "local":
try:
pebblo_resp = requests.post(
prompt_gov_api_url,
headers=headers,
json=prompt_payload,
timeout=20,
)
logger.debug("prompt-payload: %s", prompt_payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
logger.debug(f"pebblo_resp.json() {pebblo_resp.json()}")
prompt_entities["entities"] = pebblo_resp.json().get("entities", {})
prompt_entities["entityCount"] = pebblo_resp.json().get(
"entityCount", 0
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
return is_valid_prompt, prompt_entities
@classmethod
def get_chain_details(cls, llm: BaseLanguageModel, **kwargs): # type: ignore
llm_dict = llm.__dict__
chain = [
{
"name": cls.__name__,
"model": {
"name": llm_dict.get("model_name", llm_dict.get("model")),
"vendor": llm.__class__.__name__,
},
"vector_dbs": [
{
"name": kwargs["retriever"].vectorstore.__class__.__name__,
"embedding_model": str(
chains = [
ChainInfo(
name=cls.__name__,
model=Model(
name=llm_dict.get("model_name", llm_dict.get("model")),
vendor=llm.__class__.__name__,
),
vector_dbs=[
VectorDB(
name=kwargs["retriever"].vectorstore.__class__.__name__,
embedding_model=str(
kwargs["retriever"].vectorstore._embeddings.model
)
if hasattr(kwargs["retriever"].vectorstore, "_embeddings")
@ -607,8 +355,8 @@ class PebbloRetrievalQA(Chain):
if hasattr(kwargs["retriever"].vectorstore, "_embedding")
else None
),
}
)
],
},
),
]
return chain
return chains

View File

@ -109,7 +109,7 @@ class VectorDB(BaseModel):
embedding_model: Optional[str] = None
class Chains(BaseModel):
class ChainInfo(BaseModel):
name: str
model: Optional[Model]
vector_dbs: Optional[List[VectorDB]]
@ -121,7 +121,7 @@ class App(BaseModel):
description: Optional[str]
runtime: Runtime
framework: Framework
chains: List[Chains]
chains: List[ChainInfo]
plugin_version: str
@ -134,9 +134,9 @@ class Context(BaseModel):
class Prompt(BaseModel):
data: Optional[Union[list, str]]
entityCount: Optional[int]
entities: Optional[dict]
prompt_gov_enabled: Optional[bool]
entityCount: Optional[int] = None
entities: Optional[dict] = None
prompt_gov_enabled: Optional[bool] = None
class Qa(BaseModel):

View File

@ -1,22 +1,43 @@
import json
import logging
import os
import platform
from typing import Tuple
from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple
from langchain_core.documents import Document
from langchain_core.env import get_runtime_environment
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils import get_from_dict_or_env
from langchain_core.vectorstores import VectorStoreRetriever
from requests import Response, request
from requests.exceptions import RequestException
from langchain_community.chains.pebblo_retrieval.models import Framework, Runtime
from langchain_community.chains.pebblo_retrieval.models import (
App,
AuthContext,
Context,
Framework,
Prompt,
Qa,
Runtime,
)
logger = logging.getLogger(__name__)
PLUGIN_VERSION = "0.1.1"
CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000")
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")
_DEFAULT_CLASSIFIER_URL = "http://localhost:8000"
_DEFAULT_PEBBLO_CLOUD_URL = "https://api.daxa.ai"
PROMPT_URL = "/v1/prompt"
PROMPT_GOV_URL = "/v1/prompt/governance"
APP_DISCOVER_URL = "/v1/app/discover"
class Routes(str, Enum):
    """Routes available for the Pebblo API as enumerator.

    Each member's value is the URL path appended to a base URL. Members are
    ``str`` instances, so they compare equal to their path.
    """

    retrieval_app_discover = "/v1/app/discover"
    prompt = "/v1/prompt"
    prompt_governance = "/v1/prompt/governance"

    def __str__(self) -> str:
        """Return the route path.

        Without this override, ``str()`` gives ``"Routes.<name>"``, and on
        newer Python versions ``format()`` and f-strings follow ``__str__``.
        Returning the value keeps ``f"{base_url}{Routes.prompt}"`` a valid URL
        on every supported interpreter.
        """
        return self.value
def get_runtime() -> Tuple[Framework, Runtime]:
@ -64,3 +85,308 @@ def get_ip() -> str:
except Exception:
public_ip = socket.gethostbyname("localhost")
return public_ip
class PebbloRetrievalAPIWrapper(BaseModel):
    """Wrapper for Pebblo Retrieval API.

    Builds and sends the app-discovery, prompt and prompt-governance requests
    to the Pebblo classifier and, when an API key is configured, to Pebblo
    cloud. Request failures are logged and never raised to the caller.
    """

    api_key: Optional[str]  # Use SecretStr
    """API key for Pebblo Cloud"""
    classifier_location: str = "local"
    """Location of the classifier, 'local' or 'pebblo-cloud'. Defaults to 'local'"""
    classifier_url: Optional[str]
    """URL of the Pebblo Classifier"""
    cloud_url: Optional[str]
    """URL of the Pebblo Cloud"""

    def __init__(self, **kwargs: Any):
        """Resolve each setting from kwargs, then environment, then a default.

        Lookups go through ``get_from_dict_or_env`` with the variables
        ``PEBBLO_API_KEY``, ``PEBBLO_CLASSIFIER_URL`` and ``PEBBLO_CLOUD_URL``.
        """
        kwargs["api_key"] = get_from_dict_or_env(
            kwargs, "api_key", "PEBBLO_API_KEY", ""
        )
        kwargs["classifier_url"] = get_from_dict_or_env(
            kwargs, "classifier_url", "PEBBLO_CLASSIFIER_URL", _DEFAULT_CLASSIFIER_URL
        )
        kwargs["cloud_url"] = get_from_dict_or_env(
            kwargs, "cloud_url", "PEBBLO_CLOUD_URL", _DEFAULT_PEBBLO_CLOUD_URL
        )
        super().__init__(**kwargs)

    def send_app_discover(self, app: App) -> None:
        """
        Send app discovery request to Pebblo server & cloud.

        Args:
            app (App): App instance to be discovered.
        """
        pebblo_resp = None
        payload = app.dict(exclude_unset=True)

        if self.classifier_location == "local":
            # Send app details to local classifier
            headers = self._make_headers()
            # Use .value explicitly: f-string formatting of a (str, Enum)
            # member is not the raw path on every Python version.
            app_discover_url = (
                f"{self.classifier_url}{Routes.retrieval_app_discover.value}"
            )
            pebblo_resp = self.make_request("POST", app_discover_url, headers, payload)

        if self.api_key:
            # Send app details to Pebblo cloud if api_key is present
            headers = self._make_headers(cloud_request=True)
            classifier_data = self._json_or_none(pebblo_resp)
            if classifier_data is not None:
                # Forward the classifier's version alongside the app details.
                payload.update(
                    {
                        "pebblo_server_version": classifier_data.get(
                            "pebblo_server_version"
                        )
                    }
                )
            payload.update({"pebblo_client_version": PLUGIN_VERSION})
            pebblo_cloud_url = (
                f"{self.cloud_url}{Routes.retrieval_app_discover.value}"
            )
            _ = self.make_request("POST", pebblo_cloud_url, headers, payload)

    def send_prompt(
        self,
        app_name: str,
        retriever: VectorStoreRetriever,
        question: str,
        answer: str,
        auth_context: Optional[AuthContext],
        docs: List[Document],
        prompt_entities: Dict[str, Any],
        prompt_time: str,
        prompt_gov_enabled: bool = False,
    ) -> None:
        """
        Send prompt to Pebblo server for classification.
        Then send prompt to Daxa cloud(If api_key is present).

        Args:
            app_name (str): Name of the app.
            retriever (VectorStoreRetriever): Retriever instance.
            question (str): Question asked in the prompt.
            answer (str): Answer generated by the model.
            auth_context (Optional[AuthContext]): Authentication context.
            docs (List[Document]): List of documents retrieved.
            prompt_entities (Dict[str, Any]): Entities present in the prompt.
            prompt_time (str): Time when the prompt was generated.
            prompt_gov_enabled (bool): Whether prompt governance is enabled.

        Raises:
            NameError: If no API key is set and the classifier location is
                'pebblo-cloud'.
        """
        pebblo_resp = None
        payload = self.build_prompt_qa_payload(
            app_name,
            retriever,
            question,
            answer,
            auth_context,
            docs,
            prompt_entities,
            prompt_time,
            prompt_gov_enabled,
        )

        if self.classifier_location == "local":
            # Send prompt to local classifier
            headers = self._make_headers()
            prompt_url = f"{self.classifier_url}{Routes.prompt.value}"
            pebblo_resp = self.make_request("POST", prompt_url, headers, payload)

        if self.api_key:
            # Send prompt to Pebblo cloud if api_key is present
            if self.classifier_location == "local":
                # If classifier location is local, then response, context and
                # prompt should be fetched from pebblo_resp and replaced in
                # payload. An unusable classifier response yields None, which
                # makes update_cloud_payload blank those sections out.
                self.update_cloud_payload(payload, self._json_or_none(pebblo_resp))
            headers = self._make_headers(cloud_request=True)
            pebblo_cloud_prompt_url = f"{self.cloud_url}{Routes.prompt.value}"
            _ = self.make_request("POST", pebblo_cloud_prompt_url, headers, payload)
        elif self.classifier_location == "pebblo-cloud":
            logger.warning("API key is missing for sending prompt to Pebblo cloud.")
            raise NameError("API key is missing for sending prompt to Pebblo cloud.")

    def check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
        """
        Check the validity of the given prompt using a remote classification service.

        This method sends a prompt to a remote classifier service and return entities
        present in prompt or not.

        Args:
            question (str): The prompt question to be validated.

        Returns:
            bool: Validity flag; this implementation always returns True.
            dict: The entities present in the prompt, as
                ``{"entities": ..., "entityCount": ...}``. Stays at the empty
                defaults when the classifier is not local or gave no usable
                response.
        """
        prompt_payload = {"prompt": question}
        prompt_entities: dict = {"entities": {}, "entityCount": 0}
        is_valid_prompt: bool = True

        if self.classifier_location == "local":
            headers = self._make_headers()
            prompt_gov_api_url = (
                f"{self.classifier_url}{Routes.prompt_governance.value}"
            )
            pebblo_resp = self.make_request(
                "POST", prompt_gov_api_url, headers, prompt_payload
            )
            # Parse the body once; an unparsable body leaves the defaults.
            classifier_data = self._json_or_none(pebblo_resp)
            if classifier_data is not None:
                logger.debug("pebblo_resp.json() %s", classifier_data)
                prompt_entities["entities"] = classifier_data.get("entities", {})
                prompt_entities["entityCount"] = classifier_data.get(
                    "entityCount", 0
                )
        return is_valid_prompt, prompt_entities

    def _make_headers(self, cloud_request: bool = False) -> dict:
        """
        Generate headers for the request.

        args:
            cloud_request (bool): flag indicating whether the request is for Pebblo
                cloud.
        returns:
            dict: Headers for the request.
        """
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        if cloud_request:
            # Add API key for Pebblo cloud request
            if self.api_key:
                headers.update({"x-api-key": self.api_key})
            else:
                logger.warning("API key is missing for Pebblo cloud request.")
        return headers

    @staticmethod
    def _json_or_none(response: Optional[Response]) -> Optional[dict]:
        """Return the JSON object carried by a usable response, else None.

        A response is usable when it is truthy (for ``requests`` that means a
        non-error status), its body parses as JSON, and the result is a dict.
        """
        if not response:
            return None
        try:
            data = response.json()
        except ValueError:
            # JSON decode errors derive from ValueError.
            logger.warning("Pebblo response body is not valid JSON.")
            return None
        return data if isinstance(data, dict) else None

    @staticmethod
    def make_request(
        method: str,
        url: str,
        headers: dict,
        payload: Optional[dict] = None,
        timeout: int = 20,
    ) -> Optional[Response]:
        """
        Make a request to the Pebblo server/cloud API.

        Args:
            method (str): HTTP method (GET, POST, PUT, DELETE, etc.).
            url (str): URL for the request.
            headers (dict): Headers for the request.
            payload (Optional[dict]): Payload for the request (for POST, PUT, etc.).
            timeout (int): Timeout for the request in seconds.

        Returns:
            Optional[Response]: The response for any status code once the
                request completes; None when an exception was caught.
        """
        try:
            response = request(
                method=method, url=url, headers=headers, json=payload, timeout=timeout
            )
            logger.debug(
                "Request: method %s, url %s, len %s response status %s",
                method,
                response.request.url,
                str(len(response.request.body if response.request.body else [])),
                str(response.status_code),
            )

            if response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
                logger.warning("Pebblo Server: Error %s", response.status_code)
            elif response.status_code >= HTTPStatus.BAD_REQUEST:
                logger.warning(
                    "Pebblo received an invalid payload: %s", response.text
                )
            elif response.status_code != HTTPStatus.OK:
                logger.warning(
                    "Pebblo returned an unexpected response code: %s",
                    response.status_code,
                )

            return response
        except RequestException:
            logger.warning("Unable to reach server %s", url)
        except Exception as e:
            # Deliberate best-effort: reporting must not break the chain.
            logger.warning("An Exception caught in make_request: %s", e)
        return None

    @staticmethod
    def update_cloud_payload(payload: dict, pebblo_resp: Optional[dict]) -> None:
        """
        Update the payload with response, prompt and context from Pebblo response.

        The payload is modified in place. With a classifier response, its
        ``retrieval_data`` is merged into the "response" and "prompt" sections
        and the "data"/"doc" entries are removed. Without one, all three
        sections are emptied.

        Args:
            payload (dict): Payload to be updated.
            pebblo_resp (Optional[dict]): Response from Pebblo server.
        """
        if pebblo_resp:
            # Tolerate a missing or null "retrieval_data" entry.
            retrieval_data = pebblo_resp.get("retrieval_data") or {}

            # setdefault stores the section in the payload, so merged data is
            # kept even when the key was absent.
            response = payload.setdefault("response", {})
            response.update(retrieval_data.get("response") or {})
            response.pop("data", None)

            prompt = payload.setdefault("prompt", {})
            prompt.update(retrieval_data.get("prompt") or {})
            prompt.pop("data", None)

            for context_data in payload.get("context", []):
                context_data.pop("doc", None)
        else:
            payload["response"] = {}
            payload["prompt"] = {}
            payload["context"] = []

    def build_prompt_qa_payload(
        self,
        app_name: str,
        retriever: VectorStoreRetriever,
        question: str,
        answer: str,
        auth_context: Optional[AuthContext],
        docs: List[Document],
        prompt_entities: Dict[str, Any],
        prompt_time: str,
        prompt_gov_enabled: bool = False,
    ) -> dict:
        """
        Build the QA payload for the prompt.

        Args:
            app_name (str): Name of the app.
            retriever (VectorStoreRetriever): Retriever instance.
            question (str): Question asked in the prompt.
            answer (str): Answer generated by the model.
            auth_context (Optional[AuthContext]): Authentication context.
            docs (List[Document]): List of documents retrieved.
            prompt_entities (Dict[str, Any]): Entities present in the prompt.
            prompt_time (str): Time when the prompt was generated.
            prompt_gov_enabled (bool): Whether prompt governance is enabled.

        Returns:
            dict: The QA payload for the prompt, serialized with
                ``exclude_unset=True``.
        """
        vector_db_name = retriever.vectorstore.__class__.__name__
        # Entries that are not Document instances are left out.
        context = [
            Context(
                retrieved_from=doc.metadata.get(
                    "full_path", doc.metadata.get("source")
                ),
                doc=doc.page_content,
                vector_db=vector_db_name,
                pb_checksum=doc.metadata.get("pb_checksum"),
            )
            for doc in docs
            if isinstance(doc, Document)
        ]
        qa = Qa(
            name=app_name,
            context=context,
            prompt=Prompt(
                data=question,
                entities=prompt_entities.get("entities", {}),
                entityCount=prompt_entities.get("entityCount", 0),
                prompt_gov_enabled=prompt_gov_enabled,
            ),
            response=Prompt(data=answer),
            prompt_time=prompt_time,
            user=auth_context.user_id if auth_context else "unknown",
            user_identities=auth_context.user_auth
            if auth_context and hasattr(auth_context, "user_auth")
            else [],
            classifier_location=self.classifier_location,
        )
        return qa.dict(exclude_unset=True)