Source code for causalnex.inference.inference

# Copyright 2019-2020 QuantumBlack Visual Analytics Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
#     or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the implementation of ``InferenceEngine``.

``InferenceEngine`` provides tools to make inferences based on interventions and observations.
"""
import copy
import inspect
import math
import re
import types
from typing import Any, Callable, Dict, Hashable, List, Optional, Tuple, Union

import networkx as nx
import pandas as pd
from pathos import multiprocessing

from causalnex.ebaybbn import build_bbn
from causalnex.network import BayesianNetwork


class InferenceEngine:
    """
    An ``InferenceEngine`` provides methods to query marginals based on observations and
    make interventions (Do-Calculus) on a ``BayesianNetwork``.

    Example:
    ::
        >>> # Create a Bayesian Network with a manually defined DAG
        >>> import pandas as pd
        >>> from causalnex.structure.structuremodel import StructureModel
        >>> from causalnex.network import BayesianNetwork
        >>> from causalnex.inference import InferenceEngine
        >>>
        >>> sm = StructureModel()
        >>> sm.add_edges_from([
        >>>                    ('rush_hour', 'traffic'),
        >>>                    ('weather', 'traffic')
        >>>                    ])
        >>> data = pd.DataFrame({
        >>>                      'rush_hour': [True, False, False, False, True, False, True],
        >>>                      'weather': ['Terrible', 'Good', 'Bad', 'Good', 'Bad', 'Bad', 'Good'],
        >>>                      'traffic': ['heavy', 'light', 'heavy', 'light', 'heavy', 'heavy', 'heavy']
        >>>                      })
        >>> bn = BayesianNetwork(sm)
        >>> # Inference can only be performed on the `BayesianNetwork` with learned nodes states and CPDs
        >>> bn = bn.fit_node_states_and_cpds(data)
        >>>
        >>> # Create an `InferenceEngine` to query marginals and make interventions
        >>> ie = InferenceEngine(bn)
        >>> # Query the marginals as learned from data
        >>> ie.query()['traffic']
        {'heavy': 0.7142857142857142, 'light': 0.2857142857142857}
        >>> # Query the marginals given observations
        >>> ie.query({'rush_hour': True, 'weather': 'Terrible'})['traffic']
        {'heavy': 1.0, 'light': 0.0}
        >>> # Make an intervention on the `BayesianNetwork`
        >>> ie.do_intervention('rush_hour', False)
        >>> # Query marginals on the intervened `BayesianNetwork`
        >>> ie.query()['traffic']
        {'heavy': 0.5, 'light': 0.5}
        >>> # Reset interventions
        >>> ie.reset_do('rush_hour')
        >>> ie.query()['traffic']
        {'heavy': 0.7142857142857142, 'light': 0.2857142857142857}
    """

    def __init__(self, bn: BayesianNetwork):
        """
        Creates a new ``InferenceEngine`` from an existing ``BayesianNetwork``.

        It is expected that structure and probability distribution has already been learned
        for the ``BayesianNetwork`` that is to be used for inference.
        The Bayesian Network is expected not to contain any isolated nodes
        (NOTE(review): this constructor itself does not check for isolates).

        Args:
            bn: Bayesian Network that inference will act on.

        Raises:
            ValueError: if a variable name is invalid,
            or if the CPDs have not been learned yet.
        """
        # node names are later used as function argument names, hence the restriction
        bad_nodes = [node for node in bn.nodes if not re.match("^[0-9a-zA-Z_]+$", node)]

        if bad_nodes:
            raise ValueError(
                "Variable names must match ^[0-9a-zA-Z_]+$ - please fix the "
                f"following nodes: {bad_nodes}"
            )

        if not bn.cpds:
            raise ValueError(
                "Bayesian Network does not contain any CPDs. You should fit CPDs "
                "before doing inference (see `BayesianNetwork.fit_cpds`)."
            )

        self._cpds = None
        # CPDs of nodes cut off from the intervened node by a do-intervention
        self._detached_cpds = {}
        # marginals of the un-intervened network, computed lazily on first intervention
        self._baseline_marginals = None

        self._create_cpds_dict_bn(bn)
        self._generate_domains_bn(bn)
        self._generate_bbn()

    def _single_query(
        self,
        observations: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Dict[Hashable, float]]:
        """
        Queries the ``BayesianNetwork`` for marginals given some observations.

        Args:
            observations: observed states of nodes in the Bayesian Network.
                For instance, query({"node_a": 1, "node_b": 3})
                If None or {}, the marginals for all nodes in the ``BayesianNetwork`` are returned.

        Returns:
            A dictionary of marginal probabilities of the network.
            For instance, :math:`P(a=1) = 0.3, P(a=2) = 0.7` -> {a: {1: 0.3, 2: 0.7}}
        """
        bbn_results = (
            self._bbn.query(**observations) if observations else self._bbn.query()
        )
        results = {node: {} for node in self._cpds}

        # the bbn returns a flat mapping keyed by (node, state); regroup it per node
        for (node, state), prob in bbn_results.items():
            results[node][state] = prob

        # the detached nodes are set to the baseline marginals based on original CPDs
        for node in self._detached_cpds:
            results[node] = self._baseline_marginals[node]

        return results

    def query(
        self,
        observations: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
        parallel: bool = False,
        num_cores: Optional[int] = None,
    ) -> Union[
        Dict[str, Dict[Hashable, float]],
        List[Dict[str, Dict[Hashable, float]]],
    ]:
        """
        Queries the ``BayesianNetwork`` for marginals given one or more observations.

        Args:
            observations: one or more observations of states of nodes in the Bayesian Network.
            parallel: if True, run the query using multiprocessing
                (only used when ``observations`` is a list).
            num_cores: only applicable if parallel=True. The number of cores used during multiprocessing.
                If num_cores is not provided, number of processors will be autodetected and used

        Returns:
            A dictionary or a list of dictionaries of marginal probabilities of the network.

        Raises:
            TypeError: if observations is neither None nor a dictionary nor a list
        """
        if observations is not None and not isinstance(observations, (dict, list)):
            raise TypeError("Expecting observations to be a dict, list or None")

        if isinstance(observations, list):
            if parallel:
                with multiprocessing.Pool(num_cores) as p:
                    result = p.map(self._single_query, observations)
            else:
                result = [self._single_query(obs) for obs in observations]
        else:  # dictionary or None
            result = self._single_query(observations)

        return result

    def _do(self, observation: str, state: Dict[Hashable, float]):
        """
        Makes an intervention on the Bayesian Network.

        Args:
            observation: observation that the intervention is on.
            state: mapping of state -> probability.

        Raises:
            ValueError: if states do not match original states of the node, or probabilities do not sum to 1.
        """
        if not math.isclose(sum(state.values()), 1.0):
            raise ValueError("The cpd for the provided observation must sum to 1")

        if max(state.values()) > 1.0 or min(state.values()) < 0:
            raise ValueError(
                "The cpd for the provided observation must be between 0 and 1"
            )

        if not set(state.keys()) == set(self._cpds_original[observation]):
            expected = set(self._cpds_original[observation])
            found = set(state.keys())
            raise ValueError(
                f"The cpd states do not match expected states: expected {expected}, found {found}"
            )

        # an intervened node has no parents: every state maps the empty condition () to its probability
        self._cpds[observation] = {s: {(): p} for s, p in state.items()}

    def do_intervention(
        self,
        node: str,
        state: Optional[Union[Hashable, Dict[Hashable, float]]] = None,
    ):
        """
        Makes an intervention on the Bayesian Network.

        For instance,
            `do_intervention('X', 'x')` will set :math:`P(X=x)` to 1, and :math:`P(X=y)` to 0
            `do_intervention('X', {'x': 0.2, 'y': 0.8})` will set :math:`P(X=x)` to 0.2, and :math:`P(X=y)` to 0.8

        Args:
            node: the node that the intervention acts upon.
            state: state to update the node to.
                - if Hashable: the intervention updates the state to 1, and all other states to 0;
                - if Dict[Hashable, float]: update states to all state -> probability in the dict.

        Raises:
            ValueError: if performing intervention would create an isolated node,
                or if the resulting distribution is rejected by ``_do``
                (unknown states, or probabilities that are out of range or do not sum to 1).
        """
        # the first argument of every node function is the node itself, the remaining ones
        # are its parents; the intervened node must be a parent of at least one node
        if not any(
            node in inspect.getargs(f.__code__)[0][1:]
            for f in self._node_functions.values()
        ):
            raise ValueError(
                "Do calculus cannot be applied because it would result in an isolate"
            )

        # initialise baseline marginals if not done previously
        if self._baseline_marginals is None:
            self._baseline_marginals = self._single_query(None)

        # Any non-dict value is a single target state (str, bool, int, ...), not only ints:
        # expand it to a distribution with all probability mass on that state.
        if not isinstance(state, dict):
            state = {s: float(s == state) for s in self._cpds[node]}

        self._do(node, state)

        # check for presence of separate subgraph after do-intervention
        self._remove_disconnected_nodes(node)
        self._generate_bbn()

    def reset_do(self, observation: str):
        """
        Resets any do_interventions that have been applied to the observation.

        All CPDs that were detached by earlier interventions are restored as well.

        Args:
            observation: observation that will be reset.
        """
        self._cpds[observation] = self._cpds_original[observation]

        for node, cpd in self._detached_cpds.items():
            self._cpds[node] = cpd

        self._detached_cpds = {}
        self._generate_bbn()

    def _generate_bbn(self):
        """Re-creates the _bbn."""
        self._node_functions = self._create_node_functions()
        self._bbn = build_bbn(
            list(self._node_functions.values()),
            domains=self._domains,
        )

    def _generate_domains_bn(self, bn: BayesianNetwork):
        """
        Generates domains from Bayesian network

        Stores, for every variable, the list of its CPD index values in ``self._domains``.

        Args:
            bn: Bayesian network
        """
        self._domains = {
            variable: list(cpd.index.values) for variable, cpd in bn.cpds.items()
        }

    def _create_cpds_dict_bn(self, bn: BayesianNetwork):
        """
        Maps CPDs in the ``BayesianNetwork`` to required format:

        >>> {"observation":
        >>>     {"state":
        >>>         {(("condition1_observation", "condition1_state"), ("conditionN_observation", "conditionN_state")):
        >>>             "probability"
        >>>         }
        >>>     }

        For example, :math:`P( Colour=red | Make=fender, Model=stratocaster) = 0.4`:
        >>> {"colour":
        >>>     {"red":
        >>>         {(("make", "fender"), ("model", "stratocaster")):
        >>>             0.4
        >>>         }
        >>>     }
        >>> }

        The result is stored in ``self._cpds``; a deep copy is kept in
        ``self._cpds_original`` so that interventions can be reset later.

        Args:
            bn: Bayesian network
        """
        lookup = {
            variable: {
                state: {
                    tuple(zip(cpd.columns.names, parent_value)): cpd.loc[state][
                        parent_value
                    ]
                    # presumably iterates the CPD's column labels, i.e. one entry per
                    # combination of parent states - TODO confirm against pandas docs
                    for parent_value in pd.MultiIndex.from_frame(cpd).names
                }
                for state in cpd.index.values
            }
            for variable, cpd in bn.cpds.items()
        }
        self._cpds = lookup
        self._cpds_original = copy.deepcopy(self._cpds)

    def _create_node_function(self, name: str, args: Tuple[str]):
        """
        Creates a new function that describes a node in the ``BayesianNetwork``.

        The returned function takes ``args`` as its positional parameters (the node itself
        first, then its parents) and looks up the matching probability in ``self._cpds``.

        Args:
            name: name given to the generated function.
            args: argument names of the generated function.

        Returns:
            The generated node function.
        """

        def template() -> float:
            """Template node function."""
            # use inspection to determine arguments to the function
            # initially there are none present, but caller will add appropriate arguments to the function
            # getargvalues was "inadvertently marked as deprecated in Python 3.5"
            # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
            arg_spec = inspect.getargvalues(inspect.currentframe())

            # pragma: no cover
            return self._cpds[arg_spec.args[0]][  # target name
                arg_spec.locals[arg_spec.args[0]]
            ][  # target state
                tuple((arg, arg_spec.locals[arg]) for arg in arg_spec.args[1:])
            ]  # conditions

        code = template.__code__
        # co_posonlyargcount only exists on interpreters that support positional-only arguments
        pos_count = (
            [code.co_posonlyargcount] if hasattr(code, "co_posonlyargcount") else []
        )
        # Rebuild the template's code object with `args` as its parameters.
        # NOTE(review): the positional layout of ``types.CodeType`` is interpreter-version
        # specific; confirm this call against the Python versions that must be supported.
        template.__code__ = types.CodeType(
            len(args),
            *pos_count,
            code.co_kwonlyargcount,
            len(args),
            code.co_stacksize,
            code.co_flags,
            code.co_code,
            code.co_consts,
            code.co_names,
            args,
            code.co_filename,
            name,
            code.co_firstlineno,
            code.co_lnotab,
            code.co_freevars,
            code.co_cellvars,
        )
        template.__name__ = name
        return template

    def _create_node_functions(self) -> Dict[str, Callable]:
        """
        Creates all functions required to create a ``BayesianNetwork``.

        Returns:
            Dictionary of node functions
        """
        node_functions = {}

        for node, states in self._cpds.items():
            # since we only need condition names, which are consistent across all states,
            # then we can inspect the 0th element
            states_conditions = next(iter(states.values()))

            # take any state, and get its conditions
            state_conditions = next(iter(states_conditions.keys()))
            condition_nodes = [n for n, v in state_conditions]

            node_args = tuple([node] + condition_nodes)  # type: Tuple[str]
            node_function = self._create_node_function(f"f_{node}", node_args)
            node_functions[node] = node_function

        return node_functions

    def _remove_disconnected_nodes(self, var: str):
        """
        Identifies and removes from the _cpds the nodes of the bbn which are
        part of one or more upstream subgraphs that could have been formed
        after a do-intervention.

        Uses the attribute _cpds to determine the parents of each node.
        Leverages networkX `weakly_connected_component` method to identify the
        subgraphs.

        For instance, the network A -> B -> C -> D -> E would be split into
        two sub networks (A -> B) and (C -> D -> E) if we intervene on
        node C.

        The removed CPDs are kept in ``self._detached_cpds``.

        Args:
            var: variable we have intervened on
        """
        # construct graph from CPDs
        g = nx.DiGraph()

        for node, states in self._cpds.items():
            sample_state = next(iter(states.values()))
            parents = next(iter(sample_state.keys()))
            g.add_node(node)  # add nodes as there could be isolates

            for parent, _ in parents:
                g.add_edge(parent, node)

        # remove nodes in subgraphs which do not contain the intervention node
        for sub_graph in nx.weakly_connected_components(g):
            if var not in sub_graph:
                for node in sub_graph:
                    self._detached_cpds[node] = self._cpds[node]
                    self._cpds.pop(node)