|
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
import logging
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
|
|
|
|
from copy import deepcopy
|
|
|
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
|
|
|
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
|
|
|
|
|
|
|
|
|
|
|
|
import aiohttp
|
|
|
|
import aiohttp
|
|
|
@ -12,8 +13,6 @@ from aiohttp import ServerTimeoutError
|
|
|
|
from pydantic import BaseModel, Field, root_validator
|
|
|
|
from pydantic import BaseModel, Field, root_validator
|
|
|
|
from requests.exceptions import Timeout
|
|
|
|
from requests.exceptions import Timeout
|
|
|
|
|
|
|
|
|
|
|
|
from langchain.tools.powerbi.prompt import SCHEMA_ERROR_RESPONSE, UNAUTHORIZED_RESPONSE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
BASE_URL = os.getenv("POWERBI_BASE_URL", "https://api.powerbi.com/v1.0/myorg")
|
|
|
|
BASE_URL = os.getenv("POWERBI_BASE_URL", "https://api.powerbi.com/v1.0/myorg")
|
|
|
@ -63,28 +62,30 @@ class PowerBIDataset(BaseModel):
|
|
|
|
@property
|
|
|
|
@property
|
|
|
|
def headers(self) -> Dict[str, str]:
|
|
|
|
def headers(self) -> Dict[str, str]:
|
|
|
|
"""Get the token."""
|
|
|
|
"""Get the token."""
|
|
|
|
from azure.core.exceptions import ClientAuthenticationError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
token = None
|
|
|
|
|
|
|
|
if self.token:
|
|
|
|
if self.token:
|
|
|
|
token = self.token
|
|
|
|
return {
|
|
|
|
|
|
|
|
"Content-Type": "application/json",
|
|
|
|
|
|
|
|
"Authorization": "Bearer " + self.token,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
from azure.core.exceptions import ( # pylint: disable=import-outside-toplevel
|
|
|
|
|
|
|
|
ClientAuthenticationError,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if self.credential:
|
|
|
|
if self.credential:
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
token = self.credential.get_token(
|
|
|
|
token = self.credential.get_token(
|
|
|
|
"https://analysis.windows.net/powerbi/api/.default"
|
|
|
|
"https://analysis.windows.net/powerbi/api/.default"
|
|
|
|
).token
|
|
|
|
).token
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
|
|
|
"Content-Type": "application/json",
|
|
|
|
|
|
|
|
"Authorization": "Bearer " + token,
|
|
|
|
|
|
|
|
}
|
|
|
|
except Exception as exc: # pylint: disable=broad-exception-caught
|
|
|
|
except Exception as exc: # pylint: disable=broad-exception-caught
|
|
|
|
raise ClientAuthenticationError(
|
|
|
|
raise ClientAuthenticationError(
|
|
|
|
"Could not get a token from the supplied credentials."
|
|
|
|
"Could not get a token from the supplied credentials."
|
|
|
|
) from exc
|
|
|
|
) from exc
|
|
|
|
if not token:
|
|
|
|
|
|
|
|
raise ClientAuthenticationError("No credential or token supplied.")
|
|
|
|
raise ClientAuthenticationError("No credential or token supplied.")
|
|
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
|
|
|
"Content-Type": "application/json",
|
|
|
|
|
|
|
|
"Authorization": "Bearer " + token,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_table_names(self) -> Iterable[str]:
|
|
|
|
def get_table_names(self) -> Iterable[str]:
|
|
|
|
"""Get names of tables available."""
|
|
|
|
"""Get names of tables available."""
|
|
|
|
return self.table_names
|
|
|
|
return self.table_names
|
|
|
@ -116,10 +117,12 @@ class PowerBIDataset(BaseModel):
|
|
|
|
return self.table_names
|
|
|
|
return self.table_names
|
|
|
|
|
|
|
|
|
|
|
|
def _get_tables_todo(self, tables_todo: List[str]) -> List[str]:
|
|
|
|
def _get_tables_todo(self, tables_todo: List[str]) -> List[str]:
|
|
|
|
for table in tables_todo:
|
|
|
|
"""Get the tables that still need to be queried."""
|
|
|
|
|
|
|
|
todo = deepcopy(tables_todo)
|
|
|
|
|
|
|
|
for table in todo:
|
|
|
|
if table in self.schemas:
|
|
|
|
if table in self.schemas:
|
|
|
|
tables_todo.remove(table)
|
|
|
|
todo.remove(table)
|
|
|
|
return tables_todo
|
|
|
|
return todo
|
|
|
|
|
|
|
|
|
|
|
|
def _get_schema_for_tables(self, table_names: List[str]) -> str:
|
|
|
|
def _get_schema_for_tables(self, table_names: List[str]) -> str:
|
|
|
|
"""Create a string of the table schemas for the supplied tables."""
|
|
|
|
"""Create a string of the table schemas for the supplied tables."""
|
|
|
@ -135,19 +138,20 @@ class PowerBIDataset(BaseModel):
|
|
|
|
tables_requested = self._get_tables_to_query(table_names)
|
|
|
|
tables_requested = self._get_tables_to_query(table_names)
|
|
|
|
tables_todo = self._get_tables_todo(tables_requested)
|
|
|
|
tables_todo = self._get_tables_todo(tables_requested)
|
|
|
|
for table in tables_todo:
|
|
|
|
for table in tables_todo:
|
|
|
|
|
|
|
|
if " " in table and not table.startswith("'") and not table.endswith("'"):
|
|
|
|
|
|
|
|
table = f"'{table}'"
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
result = self.run(
|
|
|
|
result = self.run(
|
|
|
|
f"EVALUATE TOPN({self.sample_rows_in_table_info}, {table})"
|
|
|
|
f"EVALUATE TOPN({self.sample_rows_in_table_info}, {table})"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
except Timeout:
|
|
|
|
except Timeout:
|
|
|
|
_LOGGER.warning("Timeout while getting table info for %s", table)
|
|
|
|
_LOGGER.warning("Timeout while getting table info for %s", table)
|
|
|
|
|
|
|
|
self.schemas[table] = "unknown"
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
except Exception as exc: # pylint: disable=broad-exception-caught
|
|
|
|
except Exception as exc: # pylint: disable=broad-exception-caught
|
|
|
|
if "bad request" in str(exc).lower():
|
|
|
|
_LOGGER.warning("Error while getting table info for %s: %s", table, exc)
|
|
|
|
return SCHEMA_ERROR_RESPONSE
|
|
|
|
self.schemas[table] = "unknown"
|
|
|
|
if "unauthorized" in str(exc).lower():
|
|
|
|
continue
|
|
|
|
return UNAUTHORIZED_RESPONSE
|
|
|
|
|
|
|
|
return str(exc)
|
|
|
|
|
|
|
|
self.schemas[table] = json_to_md(result["results"][0]["tables"][0]["rows"])
|
|
|
|
self.schemas[table] = json_to_md(result["results"][0]["tables"][0]["rows"])
|
|
|
|
return self._get_schema_for_tables(tables_requested)
|
|
|
|
return self._get_schema_for_tables(tables_requested)
|
|
|
|
|
|
|
|
|
|
|
@ -158,19 +162,20 @@ class PowerBIDataset(BaseModel):
|
|
|
|
tables_requested = self._get_tables_to_query(table_names)
|
|
|
|
tables_requested = self._get_tables_to_query(table_names)
|
|
|
|
tables_todo = self._get_tables_todo(tables_requested)
|
|
|
|
tables_todo = self._get_tables_todo(tables_requested)
|
|
|
|
for table in tables_todo:
|
|
|
|
for table in tables_todo:
|
|
|
|
|
|
|
|
if " " in table and not table.startswith("'") and not table.endswith("'"):
|
|
|
|
|
|
|
|
table = f"'{table}'"
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
result = await self.arun(
|
|
|
|
result = await self.arun(
|
|
|
|
f"EVALUATE TOPN({self.sample_rows_in_table_info}, {table})"
|
|
|
|
f"EVALUATE TOPN({self.sample_rows_in_table_info}, {table})"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
except ServerTimeoutError:
|
|
|
|
except ServerTimeoutError:
|
|
|
|
_LOGGER.warning("Timeout while getting table info for %s", table)
|
|
|
|
_LOGGER.warning("Timeout while getting table info for %s", table)
|
|
|
|
|
|
|
|
self.schemas[table] = "unknown"
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
except Exception as exc: # pylint: disable=broad-exception-caught
|
|
|
|
except Exception as exc: # pylint: disable=broad-exception-caught
|
|
|
|
if "bad request" in str(exc).lower():
|
|
|
|
_LOGGER.warning("Error while getting table info for %s: %s", table, exc)
|
|
|
|
return SCHEMA_ERROR_RESPONSE
|
|
|
|
self.schemas[table] = "unknown"
|
|
|
|
if "unauthorized" in str(exc).lower():
|
|
|
|
continue
|
|
|
|
return UNAUTHORIZED_RESPONSE
|
|
|
|
|
|
|
|
return str(exc)
|
|
|
|
|
|
|
|
self.schemas[table] = json_to_md(result["results"][0]["tables"][0]["rows"])
|
|
|
|
self.schemas[table] = json_to_md(result["results"][0]["tables"][0]["rows"])
|
|
|
|
return self._get_schema_for_tables(tables_requested)
|
|
|
|
return self._get_schema_for_tables(tables_requested)
|
|
|
|
|
|
|
|
|
|
|
|