@ -1,5 +1,5 @@
""" Tools for interacting with a Power BI dataset. """
from typing import Any , Dict , Optional
from typing import Any , Dict , Optional , Tuple
from pydantic import Field , validator
@ -11,9 +11,9 @@ from langchain.chains.llm import LLMChain
from langchain . tools . base import BaseTool
from langchain . tools . powerbi . prompt import (
BAD_REQUEST_RESPONSE ,
BAD_REQUEST_RESPONSE_ESCALATED ,
DEFAULT_FEWSHOT_EXAMPLES ,
QUESTION_TO_QUERY ,
RETRY_RESPONSE ,
)
from langchain . utilities . powerbi import PowerBIDataset , json_to_md
@ -23,21 +23,39 @@ class QueryPowerBITool(BaseTool):
name = " query_powerbi "
description = """
Input to this tool is a detailed and correct DAX query , output is a result from the dataset .
If the query is not correct , an error message will be returned .
If an error is returned with Bad request in it , rewrite the query and try again .
If an error is returned with Unauthorized in it , do not try again , but tell the user to change their authentication .
Input to this tool is a detailed question about the dataset , output is a result from the dataset . It will try to answer the question using the dataset , and if it cannot , it will ask for clarification .
Example Input : " EVALUATE ROW(" count " , COUNTROWS(table1)) "
Example Input : " How many rows are in table1? "
""" # noqa: E501
llm_chain : LLMChain
powerbi : PowerBIDataset = Field ( exclude = True )
template : Optional [ str ] = QUESTION_TO_QUERY
examples : Optional [ str ] = DEFAULT_FEWSHOT_EXAMPLES
session_cache : Dict [ str , Any ] = Field ( default_factory = dict , exclude = True )
max_iterations : int = 5
class Config :
""" Configuration for this pydantic object. """
arbitrary_types_allowed = True
@validator ( " llm_chain " )
def validate_llm_chain_input_variables ( # pylint: disable=E0213
cls , llm_chain : LLMChain
) - > LLMChain :
""" Make sure the LLM chain has the correct input variables. """
if llm_chain . prompt . input_variables != [
" tool_input " ,
" tables " ,
" schemas " ,
" examples " ,
] :
raise ValueError (
" LLM chain for QueryPowerBITool must have input variables [ ' tool_input ' , ' tables ' , ' schemas ' , ' examples ' ], found %s " , # noqa: C0301 E501 # pylint: disable=C0301
llm_chain . prompt . input_variables ,
)
return llm_chain
def _check_cache ( self , tool_input : str ) - > Optional [ str ] :
""" Check if the input is present in the cache.
@ -45,88 +63,106 @@ class QueryPowerBITool(BaseTool):
if not present return None . """
if tool_input not in self . session_cache :
return None
if self . session_cache [ tool_input ] == BAD_REQUEST_RESPONSE :
self . session_cache [ tool_input ] = BAD_REQUEST_RESPONSE_ESCALATED
return self . session_cache [ tool_input ]
def _run (
self ,
tool_input : str ,
run_manager : Optional [ CallbackManagerForToolRun ] = None ,
* * kwargs : Any ,
) - > str :
""" Execute the query, return the results or an error message. """
if cache := self . _check_cache ( tool_input ) :
return cache
try :
self . session_cache [ tool_input ] = self . powerbi . run ( command = tool_input )
except Exception as exc : # pylint: disable=broad-except
if " bad request " in str ( exc ) . lower ( ) :
self . session_cache [ tool_input ] = BAD_REQUEST_RESPONSE
elif " unauthorized " in str ( exc ) . lower ( ) :
self . session_cache [
tool_input
] = " Unauthorized. Try changing your authentication, do not retry. "
else :
self . session_cache [ tool_input ] = str ( exc )
return self . session_cache [ tool_input ]
if " results " in self . session_cache [ tool_input ] :
self . session_cache [ tool_input ] = json_to_md (
self . session_cache [ tool_input ] [ " results " ] [ 0 ] [ " tables " ] [ 0 ] [ " rows " ]
query = self . llm_chain . predict (
tool_input = tool_input ,
tables = self . powerbi . get_table_names ( ) ,
schemas = self . powerbi . get_schemas ( ) ,
examples = self . examples ,
)
except Exception as exc : # pylint: disable=broad-except
self . session_cache [ tool_input ] = f " Error on call to LLM: { exc } "
return self . session_cache [ tool_input ]
if (
" error " in self . session_cache [ tool_input ]
and " pbi.error " in self . session_cache [ tool_input ] [ " error " ]
and " details " in self . session_cache [ tool_input ] [ " error " ] [ " pbi.error " ]
) :
self . session_cache [
tool_input
] = f ' { BAD_REQUEST_RESPONSE } Error was { self . session_cache [ tool_input ] [ " error " ] [ " pbi.error " ] [ " details " ] [ 0 ] [ " detail " ] } ' # noqa: E501
if query == " I cannot answer this " :
self . session_cache [ tool_input ] = query
return self . session_cache [ tool_input ]
self . session_cache [
tool_input
] = f ' { BAD_REQUEST_RESPONSE } Error was { self . session_cache [ tool_input ] [ " error " ] } ' # noqa: E501
pbi_result = self . powerbi . run ( command = query )
result , error = self . _parse_output ( pbi_result )
iterations = kwargs . get ( " iterations " , 0 )
if error and iterations < self . max_iterations :
return self . _run (
tool_input = RETRY_RESPONSE . format (
tool_input = tool_input , query = query , error = error
) ,
run_manager = run_manager ,
iterations = iterations + 1 ,
)
self . session_cache [ tool_input ] = (
result if result else BAD_REQUEST_RESPONSE . format ( error = error )
)
return self . session_cache [ tool_input ]
async def _arun (
self ,
tool_input : str ,
run_manager : Optional [ AsyncCallbackManagerForToolRun ] = None ,
* * kwargs : Any ,
) - > str :
""" Execute the query, return the results or an error message. """
if cache := self . _check_cache ( tool_input ) :
return cache
try :
self . session_cache [ tool_input ] = await self . powerbi . arun ( command = tool_input )
except Exception as exc : # pylint: disable=broad-except
if " bad request " in str ( exc ) . lower ( ) :
self . session_cache [ tool_input ] = BAD_REQUEST_RESPONSE
elif " unauthorized " in str ( exc ) . lower ( ) :
self . session_cache [
tool_input
] = " Unauthorized. Try changing your authentication, do not retry. "
else :
self . session_cache [ tool_input ] = str ( exc )
return self . session_cache [ tool_input ]
if " results " in self . session_cache [ tool_input ] :
self . session_cache [ tool_input ] = json_to_md (
self . session_cache [ tool_input ] [ " results " ] [ 0 ] [ " tables " ] [ 0 ] [ " rows " ]
query = await self . llm_chain . apredict (
tool_input = tool_input ,
tables = self . powerbi . get_table_names ( ) ,
schemas = self . powerbi . get_schemas ( ) ,
examples = self . examples ,
)
except Exception as exc : # pylint: disable=broad-except
self . session_cache [ tool_input ] = f " Error on call to LLM: { exc } "
return self . session_cache [ tool_input ]
if (
" error " in self . session_cache [ tool_input ]
and " pbi.error " in self . session_cache [ tool_input ] [ " error " ]
and " details " in self . session_cache [ tool_input ] [ " error " ] [ " pbi.error " ]
) :
self . session_cache [
tool_input
] = f ' { BAD_REQUEST_RESPONSE } Error was { self . session_cache [ tool_input ] [ " error " ] [ " pbi.error " ] [ " details " ] [ 0 ] [ " detail " ] } ' # noqa: E501
if query == " I cannot answer this " :
self . session_cache [ tool_input ] = query
return self . session_cache [ tool_input ]
self . session_cache [
tool_input
] = f ' { BAD_REQUEST_RESPONSE } Error was { self . session_cache [ tool_input ] [ " error " ] } ' # noqa: E501
pbi_result = await self . powerbi . arun ( command = query )
result , error = self . _parse_output ( pbi_result )
iterations = kwargs . get ( " iterations " , 0 )
if error and iterations < self . max_iterations :
return await self . _arun (
tool_input = RETRY_RESPONSE . format (
tool_input = tool_input , query = query , error = error
) ,
run_manager = run_manager ,
iterations = iterations + 1 ,
)
self . session_cache [ tool_input ] = (
result if result else BAD_REQUEST_RESPONSE . format ( error = error )
)
return self . session_cache [ tool_input ]
def _parse_output (
self , pbi_result : Dict [ str , Any ]
) - > Tuple [ Optional [ str ] , Optional [ str ] ] :
""" Parse the output of the query to a markdown table. """
if " results " in pbi_result :
return json_to_md ( pbi_result [ " results " ] [ 0 ] [ " tables " ] [ 0 ] [ " rows " ] ) , None
if " error " in pbi_result :
if (
" pbi.error " in pbi_result [ " error " ]
and " details " in pbi_result [ " error " ] [ " pbi.error " ]
) :
return None , pbi_result [ " error " ] [ " pbi.error " ] [ " details " ] [ 0 ] [ " detail " ]
return None , pbi_result [ " error " ]
return None , " Unknown error "
class InfoPowerBITool ( BaseTool ) :
""" Tool for getting metadata about a PowerBI Dataset. """
@ -188,64 +224,3 @@ class ListPowerBITool(BaseTool):
) - > str :
""" Get the names of the tables. """
return " , " . join ( self . powerbi . get_table_names ( ) )
class InputToQueryTool ( BaseTool ) :
""" Use an LLM to parse the question to a DAX query. """
name = " question_to_query_powerbi "
description = """
Use this tool to create the DAX query from a question , the input is a fully formed question related to the powerbi dataset . Always use this tool before executing a query with query_powerbi !
Example Input : " How many records are in table1? "
""" # noqa: E501
llm_chain : LLMChain
powerbi : PowerBIDataset = Field ( exclude = True )
template : Optional [ str ] = QUESTION_TO_QUERY
examples : Optional [ str ] = DEFAULT_FEWSHOT_EXAMPLES
class Config :
""" Configuration for this pydantic object. """
arbitrary_types_allowed = True
@validator ( " llm_chain " )
def validate_llm_chain_input_variables ( # pylint: disable=E0213
cls , llm_chain : LLMChain
) - > LLMChain :
""" Make sure the LLM chain has the correct input variables. """
if llm_chain . prompt . input_variables != [
" tool_input " ,
" tables " ,
" schemas " ,
" examples " ,
] :
raise ValueError (
" LLM chain for InputToQueryTool must have input variables [ ' tool_input ' , ' tables ' , ' schemas ' , ' examples ' ] " # noqa: C0301 E501 # pylint: disable=C0301
)
return llm_chain
def _run (
self ,
tool_input : str ,
run_manager : Optional [ CallbackManagerForToolRun ] = None ,
) - > str :
""" Use the LLM to check the query. """
return self . llm_chain . predict (
tool_input = tool_input ,
tables = self . powerbi . get_table_names ( ) ,
schemas = self . powerbi . get_schemas ( ) ,
examples = self . examples ,
)
async def _arun (
self ,
tool_input : str ,
run_manager : Optional [ AsyncCallbackManagerForToolRun ] = None ,
) - > str :
return await self . llm_chain . apredict (
tool_input = tool_input ,
tables = self . powerbi . get_table_names ( ) ,
schemas = self . powerbi . get_schemas ( ) ,
examples = self . examples ,
)