mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
community: Refactor PebbloRetrievalQA (#25583)
**Refactor PebbloRetrievalQA** - Created `APIWrapper` and moved API logic into it. - Created smaller functions/methods for better readability. - Properly read environment variables. - Removed unused code. - Updated models **Issue:** NA **Dependencies:** NA **tests**: NA
This commit is contained in:
parent
1f1679e960
commit
4ff2f4499e
@ -5,12 +5,9 @@ against a vector database.
|
||||
|
||||
import datetime
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests # type: ignore
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain_core.callbacks import (
|
||||
@ -29,16 +26,14 @@ from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
|
||||
from langchain_community.chains.pebblo_retrieval.models import (
|
||||
App,
|
||||
AuthContext,
|
||||
Qa,
|
||||
ChainInfo,
|
||||
Model,
|
||||
SemanticContext,
|
||||
VectorDB,
|
||||
)
|
||||
from langchain_community.chains.pebblo_retrieval.utilities import (
|
||||
APP_DISCOVER_URL,
|
||||
CLASSIFIER_URL,
|
||||
PEBBLO_CLOUD_URL,
|
||||
PLUGIN_VERSION,
|
||||
PROMPT_GOV_URL,
|
||||
PROMPT_URL,
|
||||
PebbloRetrievalAPIWrapper,
|
||||
get_runtime,
|
||||
)
|
||||
|
||||
@ -72,16 +67,18 @@ class PebbloRetrievalQA(Chain):
|
||||
"""Description of app."""
|
||||
api_key: Optional[str] = None #: :meta private:
|
||||
"""Pebblo cloud API key for app."""
|
||||
classifier_url: str = CLASSIFIER_URL #: :meta private:
|
||||
classifier_url: Optional[str] = None #: :meta private:
|
||||
"""Classifier endpoint."""
|
||||
classifier_location: str = "local" #: :meta private:
|
||||
"""Classifier location. It could be either of 'local' or 'pebblo-cloud'."""
|
||||
_discover_sent: bool = False #: :meta private:
|
||||
"""Flag to check if discover payload has been sent."""
|
||||
_prompt_sent: bool = False #: :meta private:
|
||||
"""Flag to check if prompt payload has been sent."""
|
||||
enable_prompt_gov: bool = True #: :meta private:
|
||||
"""Flag to check if prompt governance is enabled or not"""
|
||||
pb_client: PebbloRetrievalAPIWrapper = Field(
|
||||
default_factory=PebbloRetrievalAPIWrapper
|
||||
)
|
||||
"""Pebblo Retrieval API client"""
|
||||
|
||||
def _call(
|
||||
self,
|
||||
@ -100,12 +97,11 @@ class PebbloRetrievalQA(Chain):
|
||||
answer, docs = res['result'], res['source_documents']
|
||||
"""
|
||||
prompt_time = datetime.datetime.now().isoformat()
|
||||
PebbloRetrievalQA.set_prompt_sent(value=False)
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
question = inputs[self.input_key]
|
||||
auth_context = inputs.get(self.auth_context_key, {})
|
||||
semantic_context = inputs.get(self.semantic_context_key, {})
|
||||
_, prompt_entities = self._check_prompt_validity(question)
|
||||
auth_context = inputs.get(self.auth_context_key)
|
||||
semantic_context = inputs.get(self.semantic_context_key)
|
||||
_, prompt_entities = self.pb_client.check_prompt_validity(question)
|
||||
|
||||
accepts_run_manager = (
|
||||
"run_manager" in inspect.signature(self._get_docs).parameters
|
||||
@ -120,43 +116,17 @@ class PebbloRetrievalQA(Chain):
|
||||
input_documents=docs, question=question, callbacks=_run_manager.get_child()
|
||||
)
|
||||
|
||||
qa = {
|
||||
"name": self.app_name,
|
||||
"context": [
|
||||
{
|
||||
"retrieved_from": doc.metadata.get(
|
||||
"full_path", doc.metadata.get("source")
|
||||
),
|
||||
"doc": doc.page_content,
|
||||
"vector_db": self.retriever.vectorstore.__class__.__name__,
|
||||
**(
|
||||
{"pb_checksum": doc.metadata.get("pb_checksum")}
|
||||
if doc.metadata.get("pb_checksum")
|
||||
else {}
|
||||
),
|
||||
}
|
||||
for doc in docs
|
||||
if isinstance(doc, Document)
|
||||
],
|
||||
"prompt": {
|
||||
"data": question,
|
||||
"entities": prompt_entities.get("entities", {}),
|
||||
"entityCount": prompt_entities.get("entityCount", 0),
|
||||
"prompt_gov_enabled": self.enable_prompt_gov,
|
||||
},
|
||||
"response": {
|
||||
"data": answer,
|
||||
},
|
||||
"prompt_time": prompt_time,
|
||||
"user": auth_context.user_id if auth_context else "unknown",
|
||||
"user_identities": auth_context.user_auth
|
||||
if auth_context and hasattr(auth_context, "user_auth")
|
||||
else [],
|
||||
"classifier_location": self.classifier_location,
|
||||
}
|
||||
|
||||
qa_payload = Qa(**qa)
|
||||
self._send_prompt(qa_payload)
|
||||
self.pb_client.send_prompt(
|
||||
self.app_name,
|
||||
self.retriever,
|
||||
question,
|
||||
answer,
|
||||
auth_context,
|
||||
docs,
|
||||
prompt_entities,
|
||||
prompt_time,
|
||||
self.enable_prompt_gov,
|
||||
)
|
||||
|
||||
if self.return_source_documents:
|
||||
return {self.output_key: answer, "source_documents": docs}
|
||||
@ -187,7 +157,7 @@ class PebbloRetrievalQA(Chain):
|
||||
"run_manager" in inspect.signature(self._aget_docs).parameters
|
||||
)
|
||||
|
||||
_, prompt_entities = self._check_prompt_validity(question)
|
||||
_, prompt_entities = self.pb_client.check_prompt_validity(question)
|
||||
|
||||
if accepts_run_manager:
|
||||
docs = await self._aget_docs(
|
||||
@ -243,7 +213,7 @@ class PebbloRetrievalQA(Chain):
|
||||
chain_type: str = "stuff",
|
||||
chain_type_kwargs: Optional[dict] = None,
|
||||
api_key: Optional[str] = None,
|
||||
classifier_url: str = CLASSIFIER_URL,
|
||||
classifier_url: Optional[str] = None,
|
||||
classifier_location: str = "local",
|
||||
**kwargs: Any,
|
||||
) -> "PebbloRetrievalQA":
|
||||
@ -263,14 +233,14 @@ class PebbloRetrievalQA(Chain):
|
||||
llm=llm,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
PebbloRetrievalQA._send_discover(
|
||||
app,
|
||||
# initialize Pebblo API client
|
||||
pb_client = PebbloRetrievalAPIWrapper(
|
||||
api_key=api_key,
|
||||
classifier_url=classifier_url,
|
||||
classifier_location=classifier_location,
|
||||
classifier_url=classifier_url,
|
||||
)
|
||||
|
||||
# send app discovery request
|
||||
pb_client.send_app_discover(app)
|
||||
return cls(
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
app_name=app_name,
|
||||
@ -279,6 +249,7 @@ class PebbloRetrievalQA(Chain):
|
||||
api_key=api_key,
|
||||
classifier_url=classifier_url,
|
||||
classifier_location=classifier_location,
|
||||
pb_client=pb_client,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -346,259 +317,36 @@ class PebbloRetrievalQA(Chain):
|
||||
)
|
||||
return app
|
||||
|
||||
@staticmethod
|
||||
def _send_discover(
|
||||
app: App,
|
||||
api_key: Optional[str],
|
||||
classifier_url: str,
|
||||
classifier_location: str,
|
||||
) -> None: # type: ignore
|
||||
"""Send app discovery payload to pebblo-server. Internal method."""
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = app.dict(exclude_unset=True)
|
||||
if classifier_location == "local":
|
||||
app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
|
||||
try:
|
||||
pebblo_resp = requests.post(
|
||||
app_discover_url, headers=headers, json=payload, timeout=20
|
||||
)
|
||||
logger.debug("discover-payload: %s", payload)
|
||||
logger.debug(
|
||||
"send_discover[local]: request url %s, body %s len %s\
|
||||
response status %s body %s",
|
||||
pebblo_resp.request.url,
|
||||
str(pebblo_resp.request.body),
|
||||
str(
|
||||
len(
|
||||
pebblo_resp.request.body if pebblo_resp.request.body else []
|
||||
)
|
||||
),
|
||||
str(pebblo_resp.status_code),
|
||||
pebblo_resp.json(),
|
||||
)
|
||||
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
|
||||
PebbloRetrievalQA.set_discover_sent()
|
||||
else:
|
||||
logger.warning(
|
||||
"Received unexpected HTTP response code:"
|
||||
+ f"{pebblo_resp.status_code}"
|
||||
)
|
||||
except requests.exceptions.RequestException:
|
||||
logger.warning("Unable to reach pebblo server.")
|
||||
except Exception as e:
|
||||
logger.warning("An Exception caught in _send_discover: local %s", e)
|
||||
|
||||
if api_key:
|
||||
try:
|
||||
headers.update({"x-api-key": api_key})
|
||||
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}"
|
||||
pebblo_cloud_response = requests.post(
|
||||
pebblo_cloud_url, headers=headers, json=payload, timeout=20
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"send_discover[cloud]: request url %s, body %s len %s\
|
||||
response status %s body %s",
|
||||
pebblo_cloud_response.request.url,
|
||||
str(pebblo_cloud_response.request.body),
|
||||
str(
|
||||
len(
|
||||
pebblo_cloud_response.request.body
|
||||
if pebblo_cloud_response.request.body
|
||||
else []
|
||||
)
|
||||
),
|
||||
str(pebblo_cloud_response.status_code),
|
||||
pebblo_cloud_response.json(),
|
||||
)
|
||||
except requests.exceptions.RequestException:
|
||||
logger.warning("Unable to reach Pebblo cloud server.")
|
||||
except Exception as e:
|
||||
logger.warning("An Exception caught in _send_discover: cloud %s", e)
|
||||
|
||||
@classmethod
|
||||
def set_discover_sent(cls) -> None:
|
||||
cls._discover_sent = True
|
||||
|
||||
@classmethod
|
||||
def set_prompt_sent(cls, value: bool = True) -> None:
|
||||
cls._prompt_sent = value
|
||||
|
||||
def _send_prompt(self, qa_payload: Qa) -> None:
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
app_discover_url = f"{self.classifier_url}{PROMPT_URL}"
|
||||
pebblo_resp = None
|
||||
payload = qa_payload.dict(exclude_unset=True)
|
||||
if self.classifier_location == "local":
|
||||
try:
|
||||
pebblo_resp = requests.post(
|
||||
app_discover_url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=20,
|
||||
)
|
||||
logger.debug("prompt-payload: %s", payload)
|
||||
logger.debug(
|
||||
"send_prompt[local]: request url %s, body %s len %s\
|
||||
response status %s body %s",
|
||||
pebblo_resp.request.url,
|
||||
str(pebblo_resp.request.body),
|
||||
str(
|
||||
len(
|
||||
pebblo_resp.request.body if pebblo_resp.request.body else []
|
||||
)
|
||||
),
|
||||
str(pebblo_resp.status_code),
|
||||
pebblo_resp.json(),
|
||||
)
|
||||
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
|
||||
PebbloRetrievalQA.set_prompt_sent()
|
||||
else:
|
||||
logger.warning(
|
||||
"Received unexpected HTTP response code:"
|
||||
+ f"{pebblo_resp.status_code}"
|
||||
)
|
||||
except requests.exceptions.RequestException:
|
||||
logger.warning("Unable to reach pebblo server.")
|
||||
except Exception as e:
|
||||
logger.warning("An Exception caught in _send_discover: local %s", e)
|
||||
|
||||
# If classifier location is local, then response, context and prompt
|
||||
# should be fetched from pebblo_resp and replaced in payload.
|
||||
if self.api_key:
|
||||
if self.classifier_location == "local":
|
||||
if pebblo_resp:
|
||||
resp = json.loads(pebblo_resp.text)
|
||||
if resp:
|
||||
payload["response"].update(
|
||||
resp.get("retrieval_data", {}).get("response", {})
|
||||
)
|
||||
payload["response"].pop("data")
|
||||
payload["prompt"].update(
|
||||
resp.get("retrieval_data", {}).get("prompt", {})
|
||||
)
|
||||
payload["prompt"].pop("data")
|
||||
context = payload["context"]
|
||||
for context_data in context:
|
||||
context_data.pop("doc")
|
||||
payload["context"] = context
|
||||
else:
|
||||
payload["response"] = {}
|
||||
payload["prompt"] = {}
|
||||
payload["context"] = []
|
||||
headers.update({"x-api-key": self.api_key})
|
||||
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
|
||||
try:
|
||||
pebblo_cloud_response = requests.post(
|
||||
pebblo_cloud_url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=20,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"send_prompt[cloud]: request url %s, body %s len %s\
|
||||
response status %s body %s",
|
||||
pebblo_cloud_response.request.url,
|
||||
str(pebblo_cloud_response.request.body),
|
||||
str(
|
||||
len(
|
||||
pebblo_cloud_response.request.body
|
||||
if pebblo_cloud_response.request.body
|
||||
else []
|
||||
)
|
||||
),
|
||||
str(pebblo_cloud_response.status_code),
|
||||
pebblo_cloud_response.json(),
|
||||
)
|
||||
except requests.exceptions.RequestException:
|
||||
logger.warning("Unable to reach Pebblo cloud server.")
|
||||
except Exception as e:
|
||||
logger.warning("An Exception caught in _send_prompt: cloud %s", e)
|
||||
elif self.classifier_location == "pebblo-cloud":
|
||||
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
|
||||
raise NameError("API key is missing for sending prompt to Pebblo cloud.")
|
||||
|
||||
def _check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
|
||||
def get_chain_details(
|
||||
cls, llm: BaseLanguageModel, **kwargs: Any
|
||||
) -> List[ChainInfo]:
|
||||
"""
|
||||
Check the validity of the given prompt using a remote classification service.
|
||||
|
||||
This method sends a prompt to a remote classifier service and return entities
|
||||
present in prompt or not.
|
||||
Get chain details.
|
||||
|
||||
Args:
|
||||
question (str): The prompt question to be validated.
|
||||
llm (BaseLanguageModel): Language model instance.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
bool: True if the prompt is valid (does not contain deny list entities),
|
||||
False otherwise.
|
||||
dict: The entities present in the prompt
|
||||
List[ChainInfo]: Chain details.
|
||||
"""
|
||||
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
prompt_payload = {"prompt": question}
|
||||
is_valid_prompt: bool = True
|
||||
prompt_gov_api_url = f"{self.classifier_url}{PROMPT_GOV_URL}"
|
||||
pebblo_resp = None
|
||||
prompt_entities: dict = {"entities": {}, "entityCount": 0}
|
||||
if self.classifier_location == "local":
|
||||
try:
|
||||
pebblo_resp = requests.post(
|
||||
prompt_gov_api_url,
|
||||
headers=headers,
|
||||
json=prompt_payload,
|
||||
timeout=20,
|
||||
)
|
||||
|
||||
logger.debug("prompt-payload: %s", prompt_payload)
|
||||
logger.debug(
|
||||
"send_prompt[local]: request url %s, body %s len %s\
|
||||
response status %s body %s",
|
||||
pebblo_resp.request.url,
|
||||
str(pebblo_resp.request.body),
|
||||
str(
|
||||
len(
|
||||
pebblo_resp.request.body if pebblo_resp.request.body else []
|
||||
)
|
||||
),
|
||||
str(pebblo_resp.status_code),
|
||||
pebblo_resp.json(),
|
||||
)
|
||||
logger.debug(f"pebblo_resp.json() {pebblo_resp.json()}")
|
||||
prompt_entities["entities"] = pebblo_resp.json().get("entities", {})
|
||||
prompt_entities["entityCount"] = pebblo_resp.json().get(
|
||||
"entityCount", 0
|
||||
)
|
||||
|
||||
except requests.exceptions.RequestException:
|
||||
logger.warning("Unable to reach pebblo server.")
|
||||
except Exception as e:
|
||||
logger.warning("An Exception caught in _send_discover: local %s", e)
|
||||
return is_valid_prompt, prompt_entities
|
||||
|
||||
@classmethod
|
||||
def get_chain_details(cls, llm: BaseLanguageModel, **kwargs): # type: ignore
|
||||
llm_dict = llm.__dict__
|
||||
chain = [
|
||||
{
|
||||
"name": cls.__name__,
|
||||
"model": {
|
||||
"name": llm_dict.get("model_name", llm_dict.get("model")),
|
||||
"vendor": llm.__class__.__name__,
|
||||
},
|
||||
"vector_dbs": [
|
||||
{
|
||||
"name": kwargs["retriever"].vectorstore.__class__.__name__,
|
||||
"embedding_model": str(
|
||||
chains = [
|
||||
ChainInfo(
|
||||
name=cls.__name__,
|
||||
model=Model(
|
||||
name=llm_dict.get("model_name", llm_dict.get("model")),
|
||||
vendor=llm.__class__.__name__,
|
||||
),
|
||||
vector_dbs=[
|
||||
VectorDB(
|
||||
name=kwargs["retriever"].vectorstore.__class__.__name__,
|
||||
embedding_model=str(
|
||||
kwargs["retriever"].vectorstore._embeddings.model
|
||||
)
|
||||
if hasattr(kwargs["retriever"].vectorstore, "_embeddings")
|
||||
@ -607,8 +355,8 @@ class PebbloRetrievalQA(Chain):
|
||||
if hasattr(kwargs["retriever"].vectorstore, "_embedding")
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
],
|
||||
},
|
||||
),
|
||||
]
|
||||
return chain
|
||||
return chains
|
||||
|
@ -109,7 +109,7 @@ class VectorDB(BaseModel):
|
||||
embedding_model: Optional[str] = None
|
||||
|
||||
|
||||
class Chains(BaseModel):
|
||||
class ChainInfo(BaseModel):
|
||||
name: str
|
||||
model: Optional[Model]
|
||||
vector_dbs: Optional[List[VectorDB]]
|
||||
@ -121,7 +121,7 @@ class App(BaseModel):
|
||||
description: Optional[str]
|
||||
runtime: Runtime
|
||||
framework: Framework
|
||||
chains: List[Chains]
|
||||
chains: List[ChainInfo]
|
||||
plugin_version: str
|
||||
|
||||
|
||||
@ -134,9 +134,9 @@ class Context(BaseModel):
|
||||
|
||||
class Prompt(BaseModel):
|
||||
data: Optional[Union[list, str]]
|
||||
entityCount: Optional[int]
|
||||
entities: Optional[dict]
|
||||
prompt_gov_enabled: Optional[bool]
|
||||
entityCount: Optional[int] = None
|
||||
entities: Optional[dict] = None
|
||||
prompt_gov_enabled: Optional[bool] = None
|
||||
|
||||
|
||||
class Qa(BaseModel):
|
||||
|
@ -1,22 +1,43 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
from typing import Tuple
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.env import get_runtime_environment
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from langchain_core.vectorstores import VectorStoreRetriever
|
||||
from requests import Response, request
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
from langchain_community.chains.pebblo_retrieval.models import Framework, Runtime
|
||||
from langchain_community.chains.pebblo_retrieval.models import (
|
||||
App,
|
||||
AuthContext,
|
||||
Context,
|
||||
Framework,
|
||||
Prompt,
|
||||
Qa,
|
||||
Runtime,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PLUGIN_VERSION = "0.1.1"
|
||||
|
||||
CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000")
|
||||
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")
|
||||
_DEFAULT_CLASSIFIER_URL = "http://localhost:8000"
|
||||
_DEFAULT_PEBBLO_CLOUD_URL = "https://api.daxa.ai"
|
||||
|
||||
PROMPT_URL = "/v1/prompt"
|
||||
PROMPT_GOV_URL = "/v1/prompt/governance"
|
||||
APP_DISCOVER_URL = "/v1/app/discover"
|
||||
|
||||
class Routes(str, Enum):
|
||||
"""Routes available for the Pebblo API as enumerator."""
|
||||
|
||||
retrieval_app_discover = "/v1/app/discover"
|
||||
prompt = "/v1/prompt"
|
||||
prompt_governance = "/v1/prompt/governance"
|
||||
|
||||
|
||||
def get_runtime() -> Tuple[Framework, Runtime]:
|
||||
@ -64,3 +85,308 @@ def get_ip() -> str:
|
||||
except Exception:
|
||||
public_ip = socket.gethostbyname("localhost")
|
||||
return public_ip
|
||||
|
||||
|
||||
class PebbloRetrievalAPIWrapper(BaseModel):
|
||||
"""Wrapper for Pebblo Retrieval API."""
|
||||
|
||||
api_key: Optional[str] # Use SecretStr
|
||||
"""API key for Pebblo Cloud"""
|
||||
classifier_location: str = "local"
|
||||
"""Location of the classifier, local or cloud. Defaults to 'local'"""
|
||||
classifier_url: Optional[str]
|
||||
"""URL of the Pebblo Classifier"""
|
||||
cloud_url: Optional[str]
|
||||
"""URL of the Pebblo Cloud"""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Validate that api key in environment."""
|
||||
kwargs["api_key"] = get_from_dict_or_env(
|
||||
kwargs, "api_key", "PEBBLO_API_KEY", ""
|
||||
)
|
||||
kwargs["classifier_url"] = get_from_dict_or_env(
|
||||
kwargs, "classifier_url", "PEBBLO_CLASSIFIER_URL", _DEFAULT_CLASSIFIER_URL
|
||||
)
|
||||
kwargs["cloud_url"] = get_from_dict_or_env(
|
||||
kwargs, "cloud_url", "PEBBLO_CLOUD_URL", _DEFAULT_PEBBLO_CLOUD_URL
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def send_app_discover(self, app: App) -> None:
|
||||
"""
|
||||
Send app discovery request to Pebblo server & cloud.
|
||||
|
||||
Args:
|
||||
app (App): App instance to be discovered.
|
||||
"""
|
||||
pebblo_resp = None
|
||||
payload = app.dict(exclude_unset=True)
|
||||
|
||||
if self.classifier_location == "local":
|
||||
# Send app details to local classifier
|
||||
headers = self._make_headers()
|
||||
app_discover_url = f"{self.classifier_url}{Routes.retrieval_app_discover}"
|
||||
pebblo_resp = self.make_request("POST", app_discover_url, headers, payload)
|
||||
|
||||
if self.api_key:
|
||||
# Send app details to Pebblo cloud if api_key is present
|
||||
headers = self._make_headers(cloud_request=True)
|
||||
if pebblo_resp:
|
||||
pebblo_server_version = json.loads(pebblo_resp.text).get(
|
||||
"pebblo_server_version"
|
||||
)
|
||||
payload.update({"pebblo_server_version": pebblo_server_version})
|
||||
|
||||
payload.update({"pebblo_client_version": PLUGIN_VERSION})
|
||||
pebblo_cloud_url = f"{self.cloud_url}{Routes.retrieval_app_discover}"
|
||||
_ = self.make_request("POST", pebblo_cloud_url, headers, payload)
|
||||
|
||||
def send_prompt(
|
||||
self,
|
||||
app_name: str,
|
||||
retriever: VectorStoreRetriever,
|
||||
question: str,
|
||||
answer: str,
|
||||
auth_context: Optional[AuthContext],
|
||||
docs: List[Document],
|
||||
prompt_entities: Dict[str, Any],
|
||||
prompt_time: str,
|
||||
prompt_gov_enabled: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Send prompt to Pebblo server for classification.
|
||||
Then send prompt to Daxa cloud(If api_key is present).
|
||||
|
||||
Args:
|
||||
app_name (str): Name of the app.
|
||||
retriever (VectorStoreRetriever): Retriever instance.
|
||||
question (str): Question asked in the prompt.
|
||||
answer (str): Answer generated by the model.
|
||||
auth_context (Optional[AuthContext]): Authentication context.
|
||||
docs (List[Document]): List of documents retrieved.
|
||||
prompt_entities (Dict[str, Any]): Entities present in the prompt.
|
||||
prompt_time (str): Time when the prompt was generated.
|
||||
prompt_gov_enabled (bool): Whether prompt governance is enabled.
|
||||
"""
|
||||
pebblo_resp = None
|
||||
payload = self.build_prompt_qa_payload(
|
||||
app_name,
|
||||
retriever,
|
||||
question,
|
||||
answer,
|
||||
auth_context,
|
||||
docs,
|
||||
prompt_entities,
|
||||
prompt_time,
|
||||
prompt_gov_enabled,
|
||||
)
|
||||
|
||||
if self.classifier_location == "local":
|
||||
# Send prompt to local classifier
|
||||
headers = self._make_headers()
|
||||
prompt_url = f"{self.classifier_url}{Routes.prompt}"
|
||||
pebblo_resp = self.make_request("POST", prompt_url, headers, payload)
|
||||
|
||||
if self.api_key:
|
||||
# Send prompt to Pebblo cloud if api_key is present
|
||||
if self.classifier_location == "local":
|
||||
# If classifier location is local, then response, context and prompt
|
||||
# should be fetched from pebblo_resp and replaced in payload.
|
||||
pebblo_resp = pebblo_resp.json() if pebblo_resp else None
|
||||
self.update_cloud_payload(payload, pebblo_resp)
|
||||
|
||||
headers = self._make_headers(cloud_request=True)
|
||||
pebblo_cloud_prompt_url = f"{self.cloud_url}{Routes.prompt}"
|
||||
_ = self.make_request("POST", pebblo_cloud_prompt_url, headers, payload)
|
||||
elif self.classifier_location == "pebblo-cloud":
|
||||
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
|
||||
raise NameError("API key is missing for sending prompt to Pebblo cloud.")
|
||||
|
||||
def check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
|
||||
"""
|
||||
Check the validity of the given prompt using a remote classification service.
|
||||
|
||||
This method sends a prompt to a remote classifier service and return entities
|
||||
present in prompt or not.
|
||||
|
||||
Args:
|
||||
question (str): The prompt question to be validated.
|
||||
|
||||
Returns:
|
||||
bool: True if the prompt is valid (does not contain deny list entities),
|
||||
False otherwise.
|
||||
dict: The entities present in the prompt
|
||||
"""
|
||||
prompt_payload = {"prompt": question}
|
||||
prompt_entities: dict = {"entities": {}, "entityCount": 0}
|
||||
is_valid_prompt: bool = True
|
||||
if self.classifier_location == "local":
|
||||
headers = self._make_headers()
|
||||
prompt_gov_api_url = f"{self.classifier_url}{Routes.prompt_governance}"
|
||||
pebblo_resp = self.make_request(
|
||||
"POST", prompt_gov_api_url, headers, prompt_payload
|
||||
)
|
||||
if pebblo_resp:
|
||||
logger.debug(f"pebblo_resp.json() {pebblo_resp.json()}")
|
||||
prompt_entities["entities"] = pebblo_resp.json().get("entities", {})
|
||||
prompt_entities["entityCount"] = pebblo_resp.json().get(
|
||||
"entityCount", 0
|
||||
)
|
||||
return is_valid_prompt, prompt_entities
|
||||
|
||||
def _make_headers(self, cloud_request: bool = False) -> dict:
|
||||
"""
|
||||
Generate headers for the request.
|
||||
|
||||
args:
|
||||
cloud_request (bool): flag indicating whether the request is for Pebblo
|
||||
cloud.
|
||||
returns:
|
||||
dict: Headers for the request.
|
||||
|
||||
"""
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if cloud_request:
|
||||
# Add API key for Pebblo cloud request
|
||||
if self.api_key:
|
||||
headers.update({"x-api-key": self.api_key})
|
||||
else:
|
||||
logger.warning("API key is missing for Pebblo cloud request.")
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
def make_request(
|
||||
method: str,
|
||||
url: str,
|
||||
headers: dict,
|
||||
payload: Optional[dict] = None,
|
||||
timeout: int = 20,
|
||||
) -> Optional[Response]:
|
||||
"""
|
||||
Make a request to the Pebblo server/cloud API.
|
||||
|
||||
Args:
|
||||
method (str): HTTP method (GET, POST, PUT, DELETE, etc.).
|
||||
url (str): URL for the request.
|
||||
headers (dict): Headers for the request.
|
||||
payload (Optional[dict]): Payload for the request (for POST, PUT, etc.).
|
||||
timeout (int): Timeout for the request in seconds.
|
||||
|
||||
Returns:
|
||||
Optional[Response]: Response object if the request is successful.
|
||||
"""
|
||||
try:
|
||||
response = request(
|
||||
method=method, url=url, headers=headers, json=payload, timeout=timeout
|
||||
)
|
||||
logger.debug(
|
||||
"Request: method %s, url %s, len %s response status %s",
|
||||
method,
|
||||
response.request.url,
|
||||
str(len(response.request.body if response.request.body else [])),
|
||||
str(response.status_code),
|
||||
)
|
||||
|
||||
if response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
|
||||
logger.warning(f"Pebblo Server: Error {response.status_code}")
|
||||
elif response.status_code >= HTTPStatus.BAD_REQUEST:
|
||||
logger.warning(f"Pebblo received an invalid payload: {response.text}")
|
||||
elif response.status_code != HTTPStatus.OK:
|
||||
logger.warning(
|
||||
f"Pebblo returned an unexpected response code: "
|
||||
f"{response.status_code}"
|
||||
)
|
||||
|
||||
return response
|
||||
except RequestException:
|
||||
logger.warning("Unable to reach server %s", url)
|
||||
except Exception as e:
|
||||
logger.warning("An Exception caught in make_request: %s", e)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def update_cloud_payload(payload: dict, pebblo_resp: Optional[dict]) -> None:
|
||||
"""
|
||||
Update the payload with response, prompt and context from Pebblo response.
|
||||
|
||||
Args:
|
||||
payload (dict): Payload to be updated.
|
||||
pebblo_resp (Optional[dict]): Response from Pebblo server.
|
||||
"""
|
||||
if pebblo_resp:
|
||||
# Update response, prompt and context from pebblo response
|
||||
response = payload.get("response", {})
|
||||
response.update(pebblo_resp.get("retrieval_data", {}).get("response", {}))
|
||||
response.pop("data", None)
|
||||
prompt = payload.get("prompt", {})
|
||||
prompt.update(pebblo_resp.get("retrieval_data", {}).get("prompt", {}))
|
||||
prompt.pop("data", None)
|
||||
context = payload.get("context", [])
|
||||
for context_data in context:
|
||||
context_data.pop("doc", None)
|
||||
else:
|
||||
payload["response"] = {}
|
||||
payload["prompt"] = {}
|
||||
payload["context"] = []
|
||||
|
||||
def build_prompt_qa_payload(
|
||||
self,
|
||||
app_name: str,
|
||||
retriever: VectorStoreRetriever,
|
||||
question: str,
|
||||
answer: str,
|
||||
auth_context: Optional[AuthContext],
|
||||
docs: List[Document],
|
||||
prompt_entities: Dict[str, Any],
|
||||
prompt_time: str,
|
||||
prompt_gov_enabled: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Build the QA payload for the prompt.
|
||||
|
||||
Args:
|
||||
app_name (str): Name of the app.
|
||||
retriever (VectorStoreRetriever): Retriever instance.
|
||||
question (str): Question asked in the prompt.
|
||||
answer (str): Answer generated by the model.
|
||||
auth_context (Optional[AuthContext]): Authentication context.
|
||||
docs (List[Document]): List of documents retrieved.
|
||||
prompt_entities (Dict[str, Any]): Entities present in the prompt.
|
||||
prompt_time (str): Time when the prompt was generated.
|
||||
prompt_gov_enabled (bool): Whether prompt governance is enabled.
|
||||
|
||||
Returns:
|
||||
dict: The QA payload for the prompt.
|
||||
"""
|
||||
qa = Qa(
|
||||
name=app_name,
|
||||
context=[
|
||||
Context(
|
||||
retrieved_from=doc.metadata.get(
|
||||
"full_path", doc.metadata.get("source")
|
||||
),
|
||||
doc=doc.page_content,
|
||||
vector_db=retriever.vectorstore.__class__.__name__,
|
||||
pb_checksum=doc.metadata.get("pb_checksum"),
|
||||
)
|
||||
for doc in docs
|
||||
if isinstance(doc, Document)
|
||||
],
|
||||
prompt=Prompt(
|
||||
data=question,
|
||||
entities=prompt_entities.get("entities", {}),
|
||||
entityCount=prompt_entities.get("entityCount", 0),
|
||||
prompt_gov_enabled=prompt_gov_enabled,
|
||||
),
|
||||
response=Prompt(data=answer),
|
||||
prompt_time=prompt_time,
|
||||
user=auth_context.user_id if auth_context else "unknown",
|
||||
user_identities=auth_context.user_auth
|
||||
if auth_context and hasattr(auth_context, "user_auth")
|
||||
else [],
|
||||
classifier_location=self.classifier_location,
|
||||
)
|
||||
return qa.dict(exclude_unset=True)
|
||||
|
Loading…
Reference in New Issue
Block a user