causalnex.evaluation.classification_report

causalnex.evaluation.classification_report(bn, data, node)

Build a report showing the main classification metrics.

Parameters:
  • bn (BayesianNetwork) – fitted model used to compute the classification report.
  • data (pd.DataFrame) – test data used for predictions; it must also contain the true values of node.
  • node (str) – name of the variable to generate the report for.
Return type:

Dict

Returns:

Dictionary containing the precision, recall, F1 score and support for each class of node, together with the averages described below.

The reported averages include the micro average (computed from the total true positives, false negatives and false positives across all labels), the macro average (the unweighted mean of the per-label scores), the weighted average (the mean of the per-label scores weighted by support) and the sample average (only for multilabel classification).
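
The three single-label averages can be illustrated with a small, dependency-free sketch. It is not causalnex's implementation: the helper names (safe_div, f1, report_averages) are hypothetical, the labels are made up so that the per-class numbers match the example report further down, and the sample average is left out because it only applies to multilabel data.

```python
from collections import Counter


def safe_div(num, den):
    """Return num / den, or 0.0 when the denominator is zero."""
    return num / den if den else 0.0


def f1(precision, recall):
    """Harmonic mean of precision and recall."""
    return safe_div(2 * precision * recall, precision + recall)


def report_averages(y_true, y_pred):
    """Hypothetical helper: per-class and averaged (precision, recall, F1)."""
    labels = sorted(set(y_true) | set(y_pred))
    pairs = list(zip(y_true, y_pred))
    support = Counter(y_true)  # number of true instances per label

    counts, per_class = {}, {}
    for label in labels:
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        counts[label] = (tp, fp, fn)
        precision, recall = safe_div(tp, tp + fp), safe_div(tp, tp + fn)
        per_class[label] = (precision, recall, f1(precision, recall))

    # Micro: pool TP, FP and FN over all labels, then compute the scores once.
    tp, fp, fn = (sum(c[i] for c in counts.values()) for i in range(3))
    micro_p, micro_r = safe_div(tp, tp + fp), safe_div(tp, tp + fn)
    micro = (micro_p, micro_r, f1(micro_p, micro_r))

    # Macro: unweighted mean of each per-label score.
    macro = tuple(
        sum(scores[i] for scores in per_class.values()) / len(labels)
        for i in range(3)
    )

    # Weighted: per-label scores weighted by their support.
    total = sum(support.values())
    weighted = tuple(
        sum(per_class[label][i] * support[label] for label in labels) / total
        for i in range(3)
    )
    return per_class, micro, macro, weighted


# Hypothetical true and predicted labels for a 'traffic'-like variable.
y_true = ["light", "heavy", "heavy", "light"]
y_pred = ["light", "heavy", "heavy", "heavy"]
per_class, micro, macro, weighted = report_averages(y_true, y_pred)
```

With these labels both classes have a support of 2, so the weighted and macro averages coincide. For single-label data the pooled false positives and false negatives are both the number of wrong predictions, so micro precision, recall and F1 all equal the accuracy (here 3 of 4 correct).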

Note that in binary classification, recall of the positive class is also known as “sensitivity”; recall of the negative class is “specificity”.
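
As a sketch of that relationship, the hypothetical helper below computes both quantities directly from a binary confusion count. The function name and the choice of "heavy" as the positive class are assumptions for illustration only.

```python
def sensitivity_specificity(y_true, y_pred, positive):
    """Hypothetical helper returning (sensitivity, specificity) for a binary task.

    Sensitivity is the recall of the positive class; specificity is the
    recall of the negative class (any label other than `positive`).
    """
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    tn = sum(t != positive and p != positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    sensitivity = tp / (tp + fn) if tp + fn else 0.0
    specificity = tn / (tn + fp) if tn + fp else 0.0
    return sensitivity, specificity


# Hypothetical labels, treating "heavy" traffic as the positive class.
sens, spec = sensitivity_specificity(
    ["light", "heavy", "heavy", "light"],
    ["light", "heavy", "heavy", "heavy"],
    positive="heavy",
)
```

With these labels sensitivity is 1.0 and specificity is 0.5, the same values reported as the recall of traffic_heavy and traffic_light in the example output.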

Example:

 import pandas as pd
 from causalnex.evaluation import classification_report
 from causalnex.network import BayesianNetwork
 from causalnex.structure import StructureModel

 # Both rush hour and weather are parents of traffic
 sm = StructureModel()
 sm.add_edges_from([
     ('rush_hour', 'traffic'),
     ('weather', 'traffic'),
 ])
 bn = BayesianNetwork(sm)

 # Training data: learn each node's states and its conditional probability table
 data = pd.DataFrame({
     'rush_hour': [True, False, False, False, True, False, True],
     'weather': ['Terrible', 'Good', 'Bad', 'Good', 'Bad', 'Bad', 'Good'],
     'traffic': ['heavy', 'light', 'heavy', 'light', 'heavy', 'heavy', 'heavy'],
 })
 bn = bn.fit_node_states_and_cpds(data)

 # Held-out data: the report compares predicted and actual 'traffic' values
 test_data = pd.DataFrame({
     'rush_hour': [False, False, True, True],
     'weather': ['Good', 'Bad', 'Good', 'Bad'],
     'traffic': ['light', 'heavy', 'heavy', 'light'],
 })
 classification_report(bn, test_data, "traffic")
{'precision': {
    'macro avg': 0.8333333333333333,
    'micro avg': 0.75,
    'traffic_heavy': 0.6666666666666666,
    'traffic_light': 1.0,
    'weighted avg': 0.8333333333333333
  },
 'recall': {
    'macro avg': 0.75,
    'micro avg': 0.75,
    'traffic_heavy': 1.0,
    'traffic_light': 0.5,
    'weighted avg': 0.75
  },
 'f1-score': {
    'macro avg': 0.7333333333333334,
    'micro avg': 0.75,
    'traffic_heavy': 0.8,
    'traffic_light': 0.6666666666666666,
    'weighted avg': 0.7333333333333334
  },
 'support': {
    'macro avg': 4,
    'micro avg': 4,
    'traffic_heavy': 2,
    'traffic_light': 2,
    'weighted avg': 4
  }}
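
The returned value is a nested dictionary keyed first by metric and then by class label (prefixed with the node name) or average. A minimal sketch of reading it, assuming pandas is available; the dictionary is copied from the example output above rather than recomputed, and the variable names are illustrative.

```python
import pandas as pd

# Report copied from the example output (metric -> label or average -> value).
report = {
    "precision": {"macro avg": 0.8333333333333333, "micro avg": 0.75,
                  "traffic_heavy": 0.6666666666666666, "traffic_light": 1.0,
                  "weighted avg": 0.8333333333333333},
    "recall": {"macro avg": 0.75, "micro avg": 0.75, "traffic_heavy": 1.0,
               "traffic_light": 0.5, "weighted avg": 0.75},
    "f1-score": {"macro avg": 0.7333333333333334, "micro avg": 0.75,
                 "traffic_heavy": 0.8, "traffic_light": 0.6666666666666666,
                 "weighted avg": 0.7333333333333334},
    "support": {"macro avg": 4, "micro avg": 4, "traffic_heavy": 2,
                "traffic_light": 2, "weighted avg": 4},
}

# Direct lookup: metric first, then class label.
light_recall = report["recall"]["traffic_light"]

# Tabular view: outer keys become columns, inner keys become the row index.
table = pd.DataFrame(report)
per_class = table.loc[["traffic_heavy", "traffic_light"],
                      ["precision", "recall", "f1-score"]]
```

Turning the dictionary into a DataFrame gives one row per class or average and one column per metric, which is convenient for comparing several nodes or models side by side.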