from __future__ import annotations  # allows pydantic model to reference itself

import re
from typing import Any, Optional, Union

import duckdb
import pandas as pd
from langchain.graphs.networkx_graph import NetworkxEntityGraph
from pydantic import BaseModel, Field, PrivateAttr, root_validator, validator

from langchain_experimental.cpal.constants import Constant

# NOTE: pydantic uses class docstrings and field descriptions in `.schema()`,
# so the model docstrings below are kept verbatim; extra documentation is
# written as `#` comments or method docstrings instead.


class NarrativeModel(BaseModel):
    """
    Represent the narrative input as three story elements.
    """

    story_outcome_question: str
    story_hypothetical: str
    story_plot: str  # causal stack of operations

    @validator("*", pre=True)
    def empty_str_to_none(cls, v: Any) -> Any:
        """Empty strings are not allowed.

        Runs before type coercion on every field and maps ``""`` to ``None``;
        because all fields are required ``str``, the ``None`` is then rejected
        by pydantic, so an empty story element fails validation.
        """
        if v == "":
            return None
        return v


# One node of the causal graph: `code` is Python executed during forward
# propagation, `depends_on` lists the names of its parent entities.
class EntityModel(BaseModel):
    name: str = Field(description="entity name")
    code: str = Field(description="entity actions")
    value: float = Field(description="entity initial value")
    # Kept as `default=[]` so the JSON schema still reports the default;
    # pydantic copies field defaults per instance, so the list is not shared.
    depends_on: list[str] = Field(default=[], description="ancestor entities")

    # TODO: generalize to multivariate math
    # TODO: acyclic graph

    class Config:
        # StoryModel mutates entities in place; re-validate those assignments.
        validate_assignment = True

    @validator("name")
    def lower_case_name(cls, v: str) -> str:
        """Normalize the entity name to lower case."""
        v = v.lower()
        return v

    @validator("depends_on", each_item=True)
    def lower_case_depends_on(cls, v: str) -> str:
        """Lower-case each parent name so it matches the normalized `name`.

        Without this, a parent spelled with capitals would become a separate
        graph node that no entity is named after.
        """
        return v.lower()


class CausalModel(BaseModel):
    attribute: str = Field(description="name of the attribute to be calculated")
    entities: list[EntityModel] = Field(description="entities in the story")

    # TODO: root validate each `entity.depends_on` using system's entity names


class EntitySettingModel(BaseModel):
    """
    Initial conditions for an entity

    {"name": "bud", "attribute": "pet_count", "value": 12}
    """

    name: str = Field(description="name of the entity")
    attribute: str = Field(description="name of the attribute to be calculated")
    value: float = Field(description="entity's attribute value (calculated)")

    @validator("name")
    def lower_case_transform(cls, v: str) -> str:
        """Normalize the name so it compares equal to `EntityModel.name`."""
        v = v.lower()
        return v


class SystemSettingModel(BaseModel):
    """
    Initial global conditions for the system.

    {"parameter": "interest_rate", "value": .05}
    """

    parameter: str
    value: float


class InterventionModel(BaseModel):
    """
    aka initial conditions

    >>> intervention.dict()
    {
        entity_settings: [
            {"name": "bud", "attribute": "pet_count", "value": 12},
            {"name": "pat", "attribute": "pet_count", "value": 0},
        ],
        system_settings: None,
    }
    """

    entity_settings: list[EntitySettingModel]
    system_settings: Optional[list[SystemSettingModel]] = None

    # NOTE: the method name is historical (it does not lower-case anything);
    # it is kept so existing references to the attribute keep working.
    @validator("system_settings")
    def lower_case_name(
        cls, v: Optional[list[SystemSettingModel]]
    ) -> Optional[list[SystemSettingModel]]:
        """Reject system settings, which are not supported yet.

        Raises:
            NotImplementedError: if any system settings are supplied.
        """
        if v is not None:
            raise NotImplementedError("system_setting is not implemented yet")
        return v


class QueryModel(BaseModel):
    """translate a question about the story outcome into a programmatic expression"""

    question: str = Field(alias=Constant.narrative_input.value)  # input
    expression: str  # output, part of llm completion
    llm_error_msg: str  # output, part of llm completion
    # result of the executed query: a DataFrame on success, else an error message
    _result_table: Union[str, pd.DataFrame] = PrivateAttr()


class ResultModel(BaseModel):
    question: str = Field(alias=Constant.narrative_input.value)  # input
    _result_table: str = PrivateAttr()  # result of the executed query


# Runs the whole causal computation on construction: applies the intervention,
# builds and sorts the dependency graph, forward-propagates entity values and
# executes the query. NOTE: it mutates the passed-in `causal_operations` and
# `intervention` objects in place.
class StoryModel(BaseModel):
    # no default is given, so pydantic treats these fields as required
    causal_operations: Any = Field(required=True)
    intervention: Any = Field(required=True)
    query: Any = Field(required=True)
    _outcome_table: Optional[pd.DataFrame] = PrivateAttr(default=None)
    _networkx_wrapper: Any = PrivateAttr(default=None)

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._compute()

    # TODO: when langchain adopts pydantic.v2 replace w/ `__post_init__`
    # misses hints github.com/pydantic/pydantic/issues/1729#issuecomment-1300576214

    # skip_on_failure: if a field already failed validation it is absent from
    # `values`; let pydantic report that error instead of raising a KeyError.
    @root_validator(skip_on_failure=True)
    def check_intervention_is_valid(cls, values: dict) -> dict:
        """Ensure every intervention setting names an existing entity.

        Raises:
            ValueError: if a setting refers to an unknown entity name.
        """
        valid_names = [e.name for e in values["causal_operations"].entities]
        for setting in values["intervention"].entity_settings:
            if setting.name not in valid_names:
                error_msg = f"""
                    Hypothetical question has an invalid entity name.
                    `{setting.name}` not in `{valid_names}`
                """
                raise ValueError(error_msg)
        return values

    def _block_back_door_paths(self) -> None:
        """Cut intervened entities off from their parents and make them no-ops."""
        # stop intervention entities from depending on others
        intervention_entities = {
            entity_setting.name for entity_setting in self.intervention.entity_settings
        }
        for entity in self.causal_operations.entities:
            if entity.name in intervention_entities:
                entity.depends_on = []
                entity.code = "pass"

    def _set_initial_conditions(self) -> None:
        """Copy each intervention value onto the entity with the same name."""
        # a dict keeps the last setting per name, matching sequential overwrites
        settings = {
            entity_setting.name: entity_setting.value
            for entity_setting in self.intervention.entity_settings
        }
        for entity in self.causal_operations.entities:
            if entity.name in settings:
                entity.value = settings[entity.name]

    def _make_graph(self) -> None:
        """Build the parent -> child dependency graph and prune unlinked entities."""
        self._networkx_wrapper = NetworkxEntityGraph()
        for entity in self.causal_operations.entities:
            for parent_name in entity.depends_on:
                self._networkx_wrapper._graph.add_edge(
                    parent_name, entity.name, relation=entity.code
                )

        # TODO: is it correct to drop entities with no impact on the outcome (?)
        # compute the node set once instead of re-sorting per entity
        graph_nodes = set(self._networkx_wrapper.get_topological_sort())
        self.causal_operations.entities = [
            entity
            for entity in self.causal_operations.entities
            if entity.name in graph_nodes
        ]

    def _sort_entities(self) -> None:
        """Order entities topologically so parents are evaluated before children."""
        # order the sequence of causal actions
        sorted_nodes = self._networkx_wrapper.get_topological_sort()
        # O(1) position lookup instead of list.index() per comparison key
        position = {name: index for index, name in enumerate(sorted_nodes)}
        self.causal_operations.entities.sort(key=lambda x: position[x.name])

    def _forward_propagate(self) -> None:
        """Execute each entity's code in order and tabulate the final values."""
        entity_scope = {
            entity.name: entity for entity in self.causal_operations.entities
        }
        for entity in self.causal_operations.entities:
            if entity.code == "pass":
                continue
            # SECURITY: `entity.code` is arbitrary Python (presumably
            # LLM-generated) run with this module's globals; entities are
            # exposed as local names. Do not run untrusted stories unsandboxed.
            # gist.github.com/dean0x7d/df5ce97e4a1a05be4d56d1378726ff92
            exec(entity.code, globals(), entity_scope)
        row_values = [entity.dict() for entity in entity_scope.values()]
        self._outcome_table = pd.DataFrame(row_values)

    def _run_query(self) -> None:
        """Run the query's SQL against the outcome table.

        Stores the resulting DataFrame, or an error message string, in
        `self.query._result_table`.

        Raises:
            ValueError: if the query carries a non-empty `llm_error_msg`.
        """

        def humanize_sql_error_msg(error: str) -> str:
            pattern = r"column\s+(.*?)\s+not found"
            col_match = re.search(pattern, error)
            if col_match:
                return (
                    "SQL error: "
                    + col_match.group(1)
                    + " is not an attribute in your story!"
                )
            else:
                return str(error)

        if self.query.llm_error_msg != "":
            msg = "LLM maybe failed to translate question to SQL query."
            raise ValueError(
                {
                    "question": self.query.question,
                    "llm_error_msg": self.query.llm_error_msg,
                    "msg": msg,
                }
            )

        try:
            # `df` looks unused, but DuckDB's replacement scan resolves a table
            # named `df` in the SQL to this local DataFrame; keep it in this
            # frame and do not rename it.
            df = self._outcome_table  # noqa
            # SECURITY: `expression` is LLM output; DuckDB SQL can also read
            # local files, so treat it as untrusted code.
            query_result = duckdb.sql(self.query.expression).df()
            self.query._result_table = query_result
        except duckdb.BinderException as e:
            self.query._result_table = humanize_sql_error_msg(str(e))
        except Exception as e:
            # deliberate best-effort: surface any SQL failure as the result text
            self.query._result_table = str(e)

    def _compute(self) -> None:
        """Run the full pipeline; each step depends on the previous one."""
        self._block_back_door_paths()
        self._set_initial_conditions()
        self._make_graph()
        self._sort_entities()
        self._forward_propagate()
        self._run_query()

    def print_debug_report(self) -> None:
        """Pretty-print the outcome table, the query and its result."""
        report = {
            "outcome": self._outcome_table,
            "query": self.query.dict(),
            "result": self.query._result_table,
        }
        from pprint import pprint

        pprint(report)