causalnex.inference.InferenceEngine

class causalnex.inference.InferenceEngine(bn)

Bases: object

An InferenceEngine provides methods to query marginals based on observations and make interventions (Do-Calculus) on a BayesianNetwork.

Example:

 # Create a Bayesian Network with a manually defined DAG
 import pandas as pd
 from causalnex.structure.structuremodel import StructureModel
 from causalnex.network import BayesianNetwork
 from causalnex.inference import InferenceEngine

 sm = StructureModel()
 sm.add_edges_from([
                    ('rush_hour', 'traffic'),
                    ('weather', 'traffic')
                    ])
 data = pd.DataFrame({
                      'rush_hour': [True, False, False, False, True, False, True],
                      'weather': ['Terrible', 'Good', 'Bad', 'Good', 'Bad', 'Bad', 'Good'],
                      'traffic': ['heavy', 'light', 'heavy', 'light', 'heavy', 'heavy', 'heavy']
                      })
 bn = BayesianNetwork(sm)
 # Inference can only be performed on a `BayesianNetwork` whose node states and CPDs have been learned
 bn = bn.fit_node_states_and_cpds(data)

 # Create an `InferenceEngine` to query marginals and make interventions
 ie = InferenceEngine(bn)
 # Query the marginals as learned from data
 ie.query()['traffic']
{'heavy': 0.7142857142857142, 'light': 0.2857142857142857}
 # Query the marginals given observations
 ie.query({'rush_hour': True, 'weather': 'Terrible'})['traffic']
{'heavy': 1.0, 'light': 0.0}
 # Make an intervention on the `BayesianNetwork`
 ie.do_intervention('rush_hour', False)
 # Query marginals on the intervened `BayesianNetwork`
 ie.query()['traffic']
{'heavy': 0.5, 'light': 0.5}
 # Reset interventions
 ie.reset_do('rush_hour')
 ie.query()['traffic']
{'heavy': 0.7142857142857142, 'light': 0.2857142857142857}
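To show where the numbers in the example come from, the sketch below recomputes them by brute-force enumeration in plain Python, without causalnex. It assumes the CPDs are maximum-likelihood counts and that a parent configuration absent from the data (rush_hour=False with weather='Terrible') gets a uniform distribution over traffic states; that second assumption is what the 0.5 result above implies. This is only a hand check, not necessarily how InferenceEngine computes marginals internally. Because rush_hour and weather are the only parents of traffic, conditioning on both simply reads one row of the traffic CPD, and do(rush_hour=False) replaces P(rush_hour) with a point mass on False.

```python
from collections import Counter
from itertools import product

# Data from the example above
rush_hour = [True, False, False, False, True, False, True]
weather = ['Terrible', 'Good', 'Bad', 'Good', 'Bad', 'Bad', 'Good']
traffic = ['heavy', 'light', 'heavy', 'light', 'heavy', 'heavy', 'heavy']

RUSH_STATES = [False, True]
WEATHER_STATES = ['Bad', 'Good', 'Terrible']
TRAFFIC_STATES = ['heavy', 'light']


def prior(values, states):
    """Maximum-likelihood distribution of a root node (relative frequencies)."""
    counts = Counter(values)
    return {s: counts[s] / len(values) for s in states}


def traffic_cpd():
    """P(traffic | rush_hour, weather) by counting.

    Assumption: an unseen parent configuration gets a uniform distribution.
    """
    cpd = {}
    for r, w in product(RUSH_STATES, WEATHER_STATES):
        rows = [t for rr, ww, t in zip(rush_hour, weather, traffic) if rr == r and ww == w]
        if rows:
            counts = Counter(rows)
            cpd[(r, w)] = {t: counts[t] / len(rows) for t in TRAFFIC_STATES}
        else:
            cpd[(r, w)] = {t: 1 / len(TRAFFIC_STATES) for t in TRAFFIC_STATES}
    return cpd


def traffic_marginal(p_rush, p_weather, cpd):
    """Sum out both parents: P(t) = sum over r, w of P(r) * P(w) * P(t | r, w)."""
    return {
        t: sum(p_rush[r] * p_weather[w] * cpd[(r, w)][t]
               for r, w in product(RUSH_STATES, WEATHER_STATES))
        for t in TRAFFIC_STATES
    }


p_rush = prior(rush_hour, RUSH_STATES)
p_weather = prior(weather, WEATHER_STATES)
cpd = traffic_cpd()

observational = traffic_marginal(p_rush, p_weather, cpd)
# do(rush_hour=False): replace P(rush_hour) by a point mass on False
intervened = traffic_marginal({False: 1.0, True: 0.0}, p_weather, cpd)

print({t: round(p, 4) for t, p in observational.items()})  # {'heavy': 0.7143, 'light': 0.2857}
print(cpd[(True, 'Terrible')])                               # {'heavy': 1.0, 'light': 0.0}
print({t: round(p, 4) for t, p in intervened.items()})     # {'heavy': 0.5, 'light': 0.5}
```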

Methods

InferenceEngine.__init__(bn) Create a new InferenceEngine from an existing BayesianNetwork.
InferenceEngine.do_intervention(node[, state]) Make an intervention on the Bayesian Network.
InferenceEngine.query([observations]) Query the BayesianNetwork for marginals given some observations.
InferenceEngine.reset_do(observation) Reset any interventions that have been applied to the node named by observation.

__init__(bn)

Create a new InferenceEngine from an existing BayesianNetwork.

The structure and the probability distributions (CPDs) of the BayesianNetwork must already have been learned before it is used for inference. The BayesianNetwork must not contain any isolated nodes.

Parameters: bn (BayesianNetwork) – Bayesian Network that inference will act on.
Raises: ValueError – if the Bayesian Network contains isolated nodes, if a variable name is invalid, or if the CPDs have not been learned yet.
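Since isolated nodes make the constructor raise, it can help to check for them beforehand. The sketch below uses a hypothetical helper, find_isolates, on plain node and edge lists; the extra node day_of_week is made up for illustration. Because StructureModel subclasses networkx.DiGraph, list(networkx.isolates(sm)) should give the same answer on a real structure.

```python
from typing import Hashable, Iterable, List, Tuple


def find_isolates(nodes: Iterable[Hashable],
                  edges: Iterable[Tuple[Hashable, Hashable]]) -> List[Hashable]:
    """Hypothetical helper: return the nodes that appear in no edge."""
    connected = {n for edge in edges for n in edge}
    return [n for n in nodes if n not in connected]


nodes = ['rush_hour', 'weather', 'traffic', 'day_of_week']  # day_of_week is made up
edges = [('rush_hour', 'traffic'), ('weather', 'traffic')]
print(find_isolates(nodes, edges))  # ['day_of_week']
```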

do_intervention(node, state=None)

Make an intervention on the Bayesian Network.

For instance:
  • do_intervention('X', 'x') sets \(P(X=x)\) to 1, and \(P(X=y)\) to 0 for every other state y;
  • do_intervention('X', {'x': 0.2, 'y': 0.8}) sets \(P(X=x)\) to 0.2, and \(P(X=y)\) to 0.8.
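The two forms of the state argument can be illustrated with a small plain-Python sketch that turns either form into a distribution over the node's states. intervention_distribution is a hypothetical helper, not part of causalnex. Two choices are assumptions of this sketch only: states missing from a dict get probability 0 (the real method may instead require every state to be listed), and a result that does not sum to 1 raises ValueError.

```python
from typing import Dict, Hashable, List, Union


def intervention_distribution(
    states: List[Hashable],
    state: Union[Hashable, Dict[Hashable, float]],
) -> Dict[Hashable, float]:
    """Hypothetical helper: map a do_intervention-style `state` to a distribution."""
    if isinstance(state, dict):
        # Dict form: use the given probabilities (sketch assumption: unlisted states get 0)
        dist = {s: float(state.get(s, 0.0)) for s in states}
    else:
        # Hashable form: all probability mass on the chosen state
        dist = {s: 1.0 if s == state else 0.0 for s in states}
    # Sketch assumption: reject anything that is not a proper distribution
    if abs(sum(dist.values()) - 1.0) > 1e-9:
        raise ValueError("intervention probabilities must sum to 1")
    return dist


print(intervention_distribution(['x', 'y'], 'x'))                     # {'x': 1.0, 'y': 0.0}
print(intervention_distribution(['x', 'y'], {'x': 0.2, 'y': 0.8}))    # {'x': 0.2, 'y': 0.8}
```

A state that is not one of the node's states leaves all probabilities at 0, so the sum check also catches that case.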
Parameters:
  • node (str) – the node that the intervention acts upon.
  • state (Union[Hashable, Dict[Hashable, float]]) – the state to set the node to:
      - if Hashable: that state's probability is set to 1, and all other states' probabilities to 0;
      - if Dict[Hashable, float]: each state's probability is set to the value given for it in the dict.
Raises: ValueError – if performing the intervention would create an isolated node.
Return type: None
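The isolated-node error can be understood through the standard graph-surgery reading of a do-intervention: the intervened node no longer depends on its parents, so its incoming edges are cut. The sketch below applies that surgery to plain edge lists and reports which nodes end up with no edges. surgery and isolates_after_do are hypothetical helpers, and this is not claimed to be InferenceEngine's exact check.

```python
from typing import Hashable, List, Tuple

Edge = Tuple[Hashable, Hashable]


def surgery(edges: List[Edge], node: Hashable) -> List[Edge]:
    """Drop every edge that points into `node` (do-operator graph surgery)."""
    return [(u, v) for u, v in edges if v != node]


def isolates_after_do(nodes: List[Hashable], edges: List[Edge], node: Hashable) -> List[Hashable]:
    """Hypothetical helper: nodes left without any edge once `node`'s parents are cut."""
    remaining = surgery(edges, node)
    connected = {n for edge in remaining for n in edge}
    return [n for n in nodes if n not in connected]


nodes = ['rush_hour', 'weather', 'traffic']
edges = [('rush_hour', 'traffic'), ('weather', 'traffic')]
print(isolates_after_do(nodes, edges, 'rush_hour'))  # []
print(isolates_after_do(nodes, edges, 'traffic'))    # ['rush_hour', 'weather', 'traffic']
```

Under this sketch, intervening on rush_hour (a root node) cuts nothing, while intervening on traffic removes both edges and leaves every node isolated.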

query(observations=None)

Query the BayesianNetwork for marginals given some observations.

Parameters: observations (Optional[Dict[str, Hashable]]) – observed states of nodes in the Bayesian Network, for instance query({"node_a": 1, "node_b": 3}). If None or {}, the marginals for all nodes in the BayesianNetwork are returned without conditioning on any observations.
Return type: Dict[str, Dict[Hashable, float]]
Returns: A dictionary of marginal probabilities of the network, keyed by node name and then by state. For instance, \(P(a=1) = 0.3, P(a=2) = 0.7\) -> {"a": {1: 0.3, 2: 0.7}}
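Because the result is a nested dictionary, ordinary dict operations are enough to post-process it. The sketch below picks the most likely state of each node and checks that each marginal sums to 1. most_likely_states and check_normalised are hypothetical helpers; the input reuses the traffic numbers from the class example (trimmed to that one node) and the {"a": ...} example above.

```python
from typing import Dict, Hashable

Marginals = Dict[str, Dict[Hashable, float]]


def most_likely_states(marginals: Marginals) -> Dict[str, Hashable]:
    """Hypothetical helper: highest-probability state of every node (first wins on ties)."""
    return {node: max(dist, key=dist.get) for node, dist in marginals.items()}


def check_normalised(marginals: Marginals, tol: float = 1e-6) -> None:
    """Hypothetical helper: raise if any node's marginal does not sum to 1."""
    for node, dist in marginals.items():
        total = sum(dist.values())
        if abs(total - 1.0) > tol:
            raise ValueError(f"marginal of {node!r} sums to {total}")


# Shaped like a query() result; values taken from the examples in this page
result = {
    "traffic": {"heavy": 0.7142857142857142, "light": 0.2857142857142857},
    "a": {1: 0.3, 2: 0.7},
}
check_normalised(result)
print(most_likely_states(result))  # {'traffic': 'heavy', 'a': 2}
```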

reset_do(observation)

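Reset any interventions that have been applied to the given node with do_intervention, so that subsequent queries use its learned distribution again.

Parameters: observation (str) – name of the node whose intervention will be reset.
Return type: None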
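One way to support reset_do is to keep the learned CPDs untouched and store interventions as per-node overrides. The class below is a hypothetical sketch of that bookkeeping in plain Python, not causalnex's implementation; CPDs are reduced to simple state-to-probability dicts for brevity, using the rush_hour prior from the class example.

```python
import copy
from typing import Dict, Hashable


class ToyInterventionStore:
    """Hypothetical sketch: learned CPDs stay untouched; interventions override per node."""

    def __init__(self, cpds: Dict[str, Dict[Hashable, float]]):
        self._original = copy.deepcopy(cpds)  # learned CPDs, never modified
        self._current = copy.deepcopy(cpds)   # CPDs that queries would use

    def do(self, node: str, dist: Dict[Hashable, float]) -> None:
        # Replace the node's CPD with a parent-free distribution
        self._current[node] = dict(dist)

    def reset_do(self, node: str) -> None:
        # Restore the learned CPD for this node only
        self._current[node] = copy.deepcopy(self._original[node])

    def cpd(self, node: str) -> Dict[Hashable, float]:
        return self._current[node]


store = ToyInterventionStore({"rush_hour": {False: 4 / 7, True: 3 / 7}})
store.do("rush_hour", {False: 1.0, True: 0.0})
print(store.cpd("rush_hour"))                               # {False: 1.0, True: 0.0}
store.reset_do("rush_hour")
print(store.cpd("rush_hour") == {False: 4 / 7, True: 3 / 7})  # True
```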