causalnex.inference.InferenceEngine

class causalnex.inference.InferenceEngine(bn)

Bases: object

An InferenceEngine provides methods to query marginals based on observations and make interventions (Do-Calculus) on a BayesianNetwork.

Example:

 # Create a Bayesian Network with a manually defined DAG
 import pandas as pd

 from causalnex.structure.structuremodel import StructureModel
 from causalnex.network import BayesianNetwork
 from causalnex.inference import InferenceEngine

 sm = StructureModel()
 sm.add_edges_from([
     ('rush_hour', 'traffic'),
     ('weather', 'traffic'),
 ])
 data = pd.DataFrame({
     'rush_hour': [True, False, False, False, True, False, True],
     'weather': ['Terrible', 'Good', 'Bad', 'Good', 'Bad', 'Bad', 'Good'],
     'traffic': ['heavy', 'light', 'heavy', 'light', 'heavy', 'heavy', 'heavy'],
 })
 bn = BayesianNetwork(sm)
 # Inference can only be performed on a `BayesianNetwork` with learned node states and CPDs
 bn = bn.fit_node_states_and_cpds(data)

 # Create an `InferenceEngine` to query marginals and make interventions
 ie = InferenceEngine(bn)
 # Query the marginals as learned from data
 ie.query()['traffic']
{'heavy': 0.7142857142857142, 'light': 0.2857142857142857}
 # Query the marginals given observations
 ie.query({'rush_hour': True, 'weather': 'Terrible'})['traffic']
{'heavy': 1.0, 'light': 0.0}
 # Make an intervention on the `BayesianNetwork`
 ie.do_intervention('rush_hour', False)
 # Query marginals on the intervened `BayesianNetwork`
 ie.query()['traffic']
{'heavy': 0.5, 'light': 0.5}
 # Reset interventions
 ie.reset_do('rush_hour')
 ie.query()['traffic']
{'heavy': 0.7142857142857142, 'light': 0.2857142857142857}
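
The numbers in the example above can be checked by hand. The following is a minimal plain-Python sketch, not the library's implementation: it estimates each distribution by counting rows, and it assumes that a parent combination never seen in the data (here `rush_hour=False`, `weather='Terrible'`) gets a uniform distribution over the `traffic` states. The helper names `traffic_cpd` and `traffic_marginal` are hypothetical.

```python
from collections import Counter
from itertools import product

# The same seven rows as in the example above: (rush_hour, weather, traffic).
rows = list(zip(
    [True, False, False, False, True, False, True],
    ["Terrible", "Good", "Bad", "Good", "Bad", "Bad", "Good"],
    ["heavy", "light", "heavy", "light", "heavy", "heavy", "heavy"],
))
n = len(rows)
rush_states = [True, False]
weather_states = ["Terrible", "Good", "Bad"]
traffic_states = ["heavy", "light"]

# Root nodes: relative frequencies in the data.
p_rush = {r: sum(row[0] == r for row in rows) / n for r in rush_states}
p_weather = {w: sum(row[1] == w for row in rows) / n for w in weather_states}


def traffic_cpd(rush, weather):
    """P(traffic | rush_hour, weather), estimated by counting rows."""
    counts = Counter(t for r, w, t in rows if r == rush and w == weather)
    total = sum(counts.values())
    if total == 0:
        # Assumption: unseen parent combinations fall back to uniform.
        return {t: 1 / len(traffic_states) for t in traffic_states}
    return {t: counts[t] / total for t in traffic_states}


def traffic_marginal(rush_distribution):
    """Sum out both parents to get the marginal of traffic."""
    marginal = {t: 0.0 for t in traffic_states}
    for r, w in product(rush_states, weather_states):
        cpd = traffic_cpd(r, w)
        for t in traffic_states:
            marginal[t] += rush_distribution[r] * p_weather[w] * cpd[t]
    return marginal


# Marginal as learned from data (compare with ie.query()['traffic']).
observational = traffic_marginal(p_rush)
# Intervention rush_hour := False replaces P(rush_hour) with a point mass.
intervened = traffic_marginal({True: 0.0, False: 1.0})
# Conditioning on both parents just reads one row of the CPD.
observed = traffic_cpd(True, "Terrible")
```

Under these assumptions, `observational['heavy']` comes out at 5/7, `intervened['heavy']` at 0.5, and `observed` puts all mass on `'heavy'`, which agrees with the outputs shown in the example.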

Methods

InferenceEngine.__delattr__(name, /)

Implement delattr(self, name).

InferenceEngine.__dir__()

Default dir() implementation.

InferenceEngine.__eq__(value, /)

Return self==value.

InferenceEngine.__format__(format_spec, /)

Default object formatter.

InferenceEngine.__ge__(value, /)

Return self>=value.

InferenceEngine.__getattribute__(name, /)

Return getattr(self, name).

InferenceEngine.__gt__(value, /)

Return self>value.

InferenceEngine.__hash__()

Return hash(self).

InferenceEngine.__init__(bn)

Creates a new InferenceEngine from an existing BayesianNetwork.

InferenceEngine.__init_subclass__

This method is called when a class is subclassed.

InferenceEngine.__le__(value, /)

Return self<=value.

InferenceEngine.__lt__(value, /)

Return self<value.

InferenceEngine.__ne__(value, /)

Return self!=value.

InferenceEngine.__new__(**kwargs)

Create and return a new object.

InferenceEngine.__reduce__()

Helper for pickle.

InferenceEngine.__reduce_ex__(protocol, /)

Helper for pickle.

InferenceEngine.__repr__()

Return repr(self).

InferenceEngine.__setattr__(name, value, /)

Implement setattr(self, name, value).

InferenceEngine.__sizeof__()

Size of object in memory, in bytes.

InferenceEngine.__str__()

Return str(self).

InferenceEngine.__subclasshook__

Abstract classes can override this to customize issubclass().

InferenceEngine._create_cpds_dict_bn(bn)

Maps CPDs in the BayesianNetwork to the required format.

InferenceEngine._create_node_function(name, args)

Creates a new function that describes a node in the BayesianNetwork.

InferenceEngine._create_node_functions()

Creates all functions required to create a BayesianNetwork.

InferenceEngine._do(observation, state)

Makes an intervention on the Bayesian Network.

InferenceEngine._generate_bbn()

Re-creates the _bbn.

InferenceEngine._generate_domains_bn(bn)

Generates domains from the Bayesian Network.

InferenceEngine._remove_disconnected_nodes(var)

Identifies and removes from the _cpds the nodes of the bbn which are part of one or more upstream subgraphs that could have been formed after a do-intervention.

InferenceEngine._single_query([observations])

Queries the BayesianNetwork for marginals given some observations.

InferenceEngine.do_intervention(node[, state])

Makes an intervention on the Bayesian Network.

InferenceEngine.query([observations, …])

Queries the BayesianNetwork for marginals given one or more observations.

InferenceEngine.reset_do(observation)

Resets any do_interventions that have been applied to the observation.

__init__(bn)

Creates a new InferenceEngine from an existing BayesianNetwork.

The structure and probability distributions are expected to have already been learned for the BayesianNetwork that is used for inference. The Bayesian Network cannot contain any isolated nodes.

Parameters

bn (BayesianNetwork) – Bayesian Network that inference will act on.

Raises

ValueError – if the Bayesian Network contains isolates, or if a variable name is invalid, or if the CPDs have not been learned yet.
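
To illustrate what "isolates" means here, the following minimal sketch lists the nodes that take part in no edge. It is plain Python on a node list and an edge list, not the check the library itself runs; the helper `find_isolates` and the extra node `holiday` are hypothetical.

```python
from typing import Hashable, Iterable, List, Tuple


def find_isolates(
    nodes: Iterable[Hashable],
    edges: Iterable[Tuple[Hashable, Hashable]],
) -> List[Hashable]:
    """Return the nodes that are neither the source nor the target of any edge."""
    connected = set()
    for source, target in edges:
        connected.add(source)
        connected.add(target)
    return [node for node in nodes if node not in connected]


# 'holiday' has no incoming or outgoing edge, so it is an isolate.
nodes = ["rush_hour", "weather", "traffic", "holiday"]
edges = [("rush_hour", "traffic"), ("weather", "traffic")]
isolates = find_isolates(nodes, edges)
```

A network for which this list is non-empty is the kind of input that the ValueError above is documented to reject.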

do_intervention(node, state=None)

Makes an intervention on the Bayesian Network.

For instance,

  • do_intervention('X', 'x') will set \(P(X=x)\) to 1 and \(P(X=y)\) to 0.

  • do_intervention('X', {'x': 0.2, 'y': 0.8}) will set \(P(X=x)\) to 0.2 and \(P(X=y)\) to 0.8.

Parameters
  • node (str) – the node that the intervention acts upon.

  • state (Union[Hashable, Dict[Hashable, float], None]) – the state or distribution to set the node to. If Hashable, the intervention sets the probability of that state to 1 and of all other states to 0. If Dict[Hashable, float], each state is set to the probability it maps to in the dict.

Raises

ValueError – if performing intervention would create an isolated node.
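
The two accepted forms of `state` can be pictured as producing one distribution over the node's states. The following is a minimal sketch under stated assumptions, not the library's code: the helper `intervention_distribution` is hypothetical, and the validation rules (matching state names, probabilities in [0, 1], sum equal to 1) are assumptions that are not documented on this page.

```python
import math
from typing import Dict, Hashable, Iterable, Union


def intervention_distribution(
    states: Iterable[Hashable],
    state: Union[Hashable, Dict[Hashable, float]],
) -> Dict[Hashable, float]:
    """Turn a single state or a state->probability dict into a full distribution."""
    states = list(states)
    if isinstance(state, dict):
        # Assumed checks: same state names, valid probabilities, total of 1.
        if set(state) != set(states):
            raise ValueError("states in the dict do not match the node's states")
        if any(p < 0 or p > 1 for p in state.values()):
            raise ValueError("probabilities must lie between 0 and 1")
        if not math.isclose(sum(state.values()), 1.0):
            raise ValueError("probabilities must sum to 1")
        return {s: float(state[s]) for s in states}
    if state not in states:
        raise ValueError("unknown state for this node")
    # Single state: probability 1 for it, 0 for every other state.
    return {s: 1.0 if s == state else 0.0 for s in states}


point_mass = intervention_distribution(["x", "y"], "x")
soft = intervention_distribution(["x", "y"], {"x": 0.2, "y": 0.8})
```

Here `point_mass` corresponds to do_intervention('X', 'x') and `soft` to do_intervention('X', {'x': 0.2, 'y': 0.8}).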

query(observations=None, parallel=False, num_cores=None)

Queries the BayesianNetwork for marginals given one or more observations.

Parameters
  • observations (Union[Dict[str, Any], List[Dict[str, Any]], None]) – one or more observations of states of nodes in the Bayesian Network.

  • parallel (bool) – if True, run the query using multiprocessing.

  • num_cores (Optional[int]) – only applicable if parallel=True. The number of cores used during multiprocessing. If num_cores is not provided, the number of processors is autodetected and used.

Return type

Union[Dict[str, Dict[Hashable, float]], List[Dict[str, Dict[Hashable, float]]]]

Returns

A dictionary or a list of dictionaries of marginal probabilities of the network.

Raises

TypeError – if observations is not None, a dictionary, or a list.
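
The relationship between the input type and the return type can be sketched as a small dispatcher: None or a single dictionary yields one result, a list yields one result per observation, and anything else raises TypeError. This is an illustration of the documented contract only. `dispatch_query` and `fake_single_query` are hypothetical names, the stand-in does not compute real marginals, and the `parallel` and `num_cores` options are left out.

```python
from typing import Any, Callable, Dict, List, Optional, Union

Observation = Dict[str, Any]


def dispatch_query(
    single_query: Callable[[Optional[Observation]], dict],
    observations: Union[Observation, List[Observation], None] = None,
) -> Union[dict, List[dict]]:
    """Run one query per observation, mirroring the documented return type."""
    if observations is None or isinstance(observations, dict):
        return single_query(observations)
    if isinstance(observations, list):
        return [single_query(obs) for obs in observations]
    raise TypeError("observations must be None, a dictionary, or a list")


def fake_single_query(obs: Optional[Observation]) -> dict:
    """Stand-in for a real query: only reports how many nodes were observed."""
    return {"n_observed": len(obs or {})}


one = dispatch_query(fake_single_query, {"rush_hour": True})
many = dispatch_query(
    fake_single_query,
    [{"rush_hour": True}, {"rush_hour": False, "weather": "Good"}],
)
```

With a real engine, the list form would return one dictionary of marginals per observation, in the same order as the input list.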

reset_do(observation)

Resets any do_interventions that have been applied to the observation.

Parameters

observation (str) – observation that will be reset.
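
One way to picture the reset is an engine that keeps the learned distributions alongside the current ones and copies the learned one back on request. The toy class below is a sketch of that idea only and is an assumption about the general approach, not the library's internals; `ToyInterventions` and its attributes are hypothetical. The starting distribution for `rush_hour` is the relative frequency from the class example (3 of 7 rows are True).

```python
import copy


class ToyInterventions:
    """Keeps learned distributions so that interventions can be undone."""

    def __init__(self, distributions):
        self._original = copy.deepcopy(distributions)
        self._current = copy.deepcopy(distributions)

    def do_intervention(self, node, distribution):
        # Overwrite the node's distribution with the intervened one.
        self._current[node] = dict(distribution)

    def reset_do(self, node):
        # Restore the distribution that was learned before any intervention.
        self._current[node] = copy.deepcopy(self._original[node])

    def distribution(self, node):
        return dict(self._current[node])


toy = ToyInterventions({"rush_hour": {True: 3 / 7, False: 4 / 7}})
toy.do_intervention("rush_hour", {True: 0.0, False: 1.0})
after_do = toy.distribution("rush_hour")
toy.reset_do("rush_hour")
after_reset = toy.distribution("rush_hour")
```

After the reset, `after_reset` equals the learned distribution again, which matches the example where the marginals return to their original values once reset_do('rush_hour') is called.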