langchain/libs/experimental/langchain_experimental/cpal/models.py
2023-07-21 18:44:32 -07:00

246 lines
8.1 KiB
Python

from __future__ import annotations # allows pydantic model to reference itself
import re
from typing import Any, Optional, Union
import duckdb
import pandas as pd
from langchain.graphs.networkx_graph import NetworkxEntityGraph
from pydantic import BaseModel, Field, PrivateAttr, root_validator, validator
from langchain_experimental.cpal.constants import Constant
class NarrativeModel(BaseModel):
    """
    Represent the narrative input as three story elements.
    """

    # The three parts the raw story is split into.
    story_outcome_question: str
    story_hypothetical: str
    story_plot: str  # the causal stack of operations

    @validator("*", pre=True)
    def empty_str_to_none(cls, v: Any) -> Union[str, None]:
        """Turn an empty string into ``None`` before type validation.

        Runs on every field (``"*"``) ahead of the ``str`` checks
        (``pre=True``). Converting ``""`` to ``None`` makes the required
        ``str`` fields reject blank story elements instead of accepting them.
        """
        return None if v == "" else v
class EntityModel(BaseModel):
    name: str = Field(description="entity name")
    code: str = Field(description="entity actions")
    value: float = Field(description="entity initial value")
    # pydantic copies field defaults per instance, so the literal [] is not
    # shared between entities.
    depends_on: list[str] = Field(default=[], description="ancestor entities")

    # TODO: generalize to multivariate math
    # TODO: acyclic graph

    class Config:
        # Re-validate on attribute assignment. StoryModel mutates entities in
        # place (e.g. `entity.value = ...`), so those writes are also checked
        # and `name` stays lower-cased.
        validate_assignment = True

    @validator("name")
    def lower_case_name(cls, v: str) -> str:
        """Normalize the entity name to lower case."""
        return v.lower()
class CausalModel(BaseModel):
    # Causal structure extracted from the story: the attribute being calculated
    # and the entities that take part in the story.
    attribute: str = Field(description="name of the attribute to be calculated")
    entities: list[EntityModel] = Field(description="entities in the story")
    # TODO: root validate each `entity.depends_on` using system's entity names
class EntitySettingModel(BaseModel):
    """
    Initial conditions for an entity
    {"name": "bud", "attribute": "pet_count", "value": 12}
    """

    name: str = Field(description="name of the entity")
    attribute: str = Field(description="name of the attribute to be calculated")
    value: float = Field(description="entity's attribute value (calculated)")

    @validator("name")
    def lower_case_transform(cls, v: str) -> str:
        """Lower-case the entity name so it matches EntityModel's normalized names."""
        return v.lower()
class SystemSettingModel(BaseModel):
    """
    Initial global conditions for the system.
    {"parameter": "interest_rate", "value": .05}
    """

    # NOTE: InterventionModel's `system_settings` validator currently raises
    # NotImplementedError for any non-None value, so these are not yet usable.
    parameter: str  # name of the global parameter
    value: float  # value assigned to that parameter
class InterventionModel(BaseModel):
    """
    aka initial conditions
    >>> intervention.dict()
    {
        entity_settings: [
            {"name": "bud", "attribute": "pet_count", "value": 12},
            {"name": "pat", "attribute": "pet_count", "value": 0},
        ],
        system_settings: None,
    }
    """

    # Per-entity initial values applied by StoryModel before simulation.
    entity_settings: list[EntitySettingModel]
    # Global settings; accepted by the schema but rejected by the validator below.
    system_settings: Optional[list[SystemSettingModel]] = None

    @validator("system_settings")
    def lower_case_name(
        cls, v: Optional[list[SystemSettingModel]]
    ) -> Optional[list[SystemSettingModel]]:
        """Reject system settings until they are supported.

        The method name is a misnomer (nothing is lower-cased). It is kept
        unchanged so the registered validator keeps its identity.

        Args:
            v: the parsed ``system_settings`` value.

        Returns:
            ``v`` unchanged when it is ``None``.

        Raises:
            NotImplementedError: if any system settings are supplied.
        """
        if v is not None:
            raise NotImplementedError("system_setting is not implemented yet")
        return v
class QueryModel(BaseModel):
    """translate a question about the story outcome into a programmatic expression"""

    question: str = Field(alias=Constant.narrative_input.value)  # input
    expression: str  # output, part of llm completion
    llm_error_msg: str  # output, part of llm completion
    # Result of executing `expression`. StoryModel._run_query stores the
    # DataFrame produced by duckdb on success, or an error message string when
    # the query fails. The PrivateAttr has no default, so the attribute stays
    # unset until a query has been run.
    _result_table: Union[str, pd.DataFrame] = PrivateAttr()
class ResultModel(BaseModel):
    # Pairs the narrative question with the result of its executed query.
    # NOTE(review): `_result_table` is annotated `str`, but the analogous
    # QueryModel._result_table is assigned a pandas DataFrame on success by
    # StoryModel._run_query. Confirm what is stored here before relying on the type.
    question: str = Field(alias=Constant.narrative_input.value)  # input
    _result_table: str = PrivateAttr()  # result of the executed query
class StoryModel(BaseModel):
    # Ties the chain outputs together. Constructing an instance eagerly runs
    # the whole pipeline (see `_compute`):
    #   1. apply the intervention;
    #   2. build and sort the causal graph;
    #   3. simulate the story;
    #   4. run the SQL query.
    # NOTE(review): the fields are typed `Any`, but the code below uses them as
    # a CausalModel, an InterventionModel and a QueryModel respectively.
    causal_operations: Any = Field(required=True)
    intervention: Any = Field(required=True)
    query: Any = Field(required=True)
    _outcome_table: pd.DataFrame = PrivateAttr(default=None)
    _networkx_wrapper: Any = PrivateAttr(default=None)

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self._compute()

    # TODO: when langchain adopts pydantic.v2 replace w/ `__post_init__`
    # misses hints github.com/pydantic/pydantic/issues/1729#issuecomment-1300576214

    @root_validator
    def check_intervention_is_valid(cls, values: dict) -> dict:
        """Ensure every intervened entity exists in the causal model.

        Raises:
            ValueError: if ``causal_operations`` or ``intervention`` is missing,
                or if an intervention names an entity not present in
                ``causal_operations.entities``. pydantic reports it as a
                validation error.
        """
        # Use .get(): an omitted field may be absent from `values` or None.
        # Indexing directly would escape as a bare KeyError/AttributeError,
        # which pydantic does not convert into a validation error.
        causal_operations = values.get("causal_operations")
        intervention = values.get("intervention")
        if causal_operations is None or intervention is None:
            raise ValueError(
                "StoryModel requires both `causal_operations` and `intervention`."
            )
        valid_names = [e.name for e in causal_operations.entities]
        for setting in intervention.entity_settings:
            if setting.name not in valid_names:
                error_msg = (
                    "\nHypothetical question has an invalid entity name.\n"
                    f"`{setting.name}` not in `{valid_names}`\n"
                )
                raise ValueError(error_msg)
        return values

    def _block_back_door_paths(self) -> None:
        """Detach intervened entities from their causes.

        An intervened entity's value is fixed by the intervention. Its
        dependencies are therefore cleared, and its code becomes ``"pass"``,
        which ``_forward_propagate`` skips.
        """
        intervention_entities = {
            entity_setting.name for entity_setting in self.intervention.entity_settings
        }
        for entity in self.causal_operations.entities:
            if entity.name in intervention_entities:
                entity.depends_on = []
                entity.code = "pass"

    def _set_initial_conditions(self) -> None:
        """Copy each intervention value onto every entity with the same name."""
        for entity_setting in self.intervention.entity_settings:
            for entity in self.causal_operations.entities:
                if entity.name == entity_setting.name:
                    entity.value = entity_setting.value

    def _make_graph(self) -> None:
        """Build the parent -> child dependency graph and prune entities outside it.

        Only edges are added, so an entity with no parents and no children
        never becomes a graph node and is dropped from
        ``causal_operations.entities``.
        """
        self._networkx_wrapper = NetworkxEntityGraph()
        for entity in self.causal_operations.entities:
            for parent_name in entity.depends_on:
                self._networkx_wrapper._graph.add_edge(
                    parent_name, entity.name, relation=entity.code
                )
        # TODO: is it correct to drop entities with no impact on the outcome (?)
        # Sort once and test membership against a set, not once per entity.
        graph_nodes = set(self._networkx_wrapper.get_topological_sort())
        self.causal_operations.entities = [
            entity
            for entity in self.causal_operations.entities
            if entity.name in graph_nodes
        ]

    def _sort_entities(self) -> None:
        """Order entities topologically so causes are evaluated before effects."""
        sorted_nodes = self._networkx_wrapper.get_topological_sort()
        # Precompute positions: `list.index` inside the sort key is O(n) per entity.
        position = {node: idx for idx, node in enumerate(sorted_nodes)}
        self.causal_operations.entities.sort(key=lambda x: position[x.name])

    def _forward_propagate(self) -> None:
        """Execute each entity's code in order and tabulate the resulting entities."""
        # Each entity is exposed to the executed code under its own name.
        entity_scope = {
            entity.name: entity for entity in self.causal_operations.entities
        }
        for entity in self.causal_operations.entities:
            if entity.code == "pass":
                continue
            # SECURITY: `entity.code` comes from the CausalModel, presumably
            # LLM-generated, and is run with `exec` against this module's
            # globals. It can execute arbitrary Python, so only use this with
            # trusted model output.
            # gist.github.com/dean0x7d/df5ce97e4a1a05be4d56d1378726ff92
            exec(entity.code, globals(), entity_scope)
        # Names bound by the executed code itself (e.g. temporaries) also land
        # in `entity_scope`; only entity objects become outcome rows.
        row_values = [
            value.dict()
            for value in entity_scope.values()
            if isinstance(value, EntityModel)
        ]
        self._outcome_table = pd.DataFrame(row_values)

    def _run_query(self) -> None:
        """Run the query's SQL expression against the outcome table.

        On success ``query._result_table`` holds a DataFrame. On query failure
        it holds an error message string (best-effort, so the message can be
        reported).

        Raises:
            ValueError: if the LLM reported an error while producing the query.
        """

        def humanize_sql_error_msg(error: str) -> str:
            # Rewrite duckdb's "column ... not found" binder message into a
            # story-level hint; any other message passes through unchanged.
            pattern = r"column\s+(.*?)\s+not found"
            col_match = re.search(pattern, error)
            if col_match:
                return (
                    "SQL error: "
                    + col_match.group(1)
                    + " is not an attribute in your story!"
                )
            else:
                return str(error)

        if self.query.llm_error_msg == "":
            try:
                # duckdb resolves the table name `df` in the SQL expression to
                # this local DataFrame (replacement scan); the name is load-bearing.
                df = self._outcome_table  # noqa
                query_result = duckdb.sql(self.query.expression).df()
                self.query._result_table = query_result
            except duckdb.BinderException as e:
                self.query._result_table = humanize_sql_error_msg(str(e))
            except Exception as e:
                # Deliberate best-effort: surface the error text as the result.
                self.query._result_table = str(e)
        else:
            msg = "LLM maybe failed to translate question to SQL query."
            raise ValueError(
                {
                    "question": self.query.question,
                    "llm_error_msg": self.query.llm_error_msg,
                    "msg": msg,
                }
            )

    def _compute(self) -> None:
        """Run the pipeline; each step depends on state left by the previous one."""
        self._block_back_door_paths()
        self._set_initial_conditions()
        self._make_graph()
        self._sort_entities()
        self._forward_propagate()
        self._run_query()

    def print_debug_report(self) -> None:
        """Pretty-print the outcome table, the query and its result."""
        report = {
            "outcome": self._outcome_table,
            "query": self.query.dict(),
            "result": self.query._result_table,
        }
        from pprint import pprint

        pprint(report)