causalnex.evaluation.classification_report
causalnex.evaluation.classification_report(bn, data, node)
Build a report showing the main classification metrics.
- Parameters
bn (BayesianNetwork) – fitted model used to compute the classification report.
data (pd.DataFrame) – test data used for predictions; it must include the column named by node, holding the true values.
node (str) – name of the variable to generate report for.
- Return type
Dict
- Returns
Dictionary summarizing the precision, recall, F1 score and support for each class.
The reported averages include the micro average (each metric computed once from the true positives, false negatives and false positives pooled across all labels), the macro average (the unweighted mean of the per-label metrics), the weighted average (the mean of the per-label metrics weighted by support) and the sample average (only for multilabel classification).
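To make the three averages concrete, here is a minimal, dependency-free sketch that computes them from lists of true and predicted labels. It is an illustration under stated assumptions, not causalnex's implementation: the labels `y_true` and `y_pred` and the helpers `safe_div`, `f1` and `averaged_report` are hypothetical, and the four predictions were chosen so that the per-class figures match the example output for `traffic`.

```python
# Minimal sketch (not causalnex's implementation): per-label precision,
# recall, F1 and support, plus micro, macro and weighted averages.

# Hypothetical labels, chosen to reproduce the per-class figures of the
# example output (heavy: precision 2/3, recall 1.0; light: 1.0, 0.5).
y_true = ["light", "heavy", "heavy", "light"]
y_pred = ["light", "heavy", "heavy", "heavy"]

METRICS = ("precision", "recall", "f1-score")


def safe_div(num, den):
    """Return num / den, or 0.0 when the denominator is zero."""
    return num / den if den else 0.0


def f1(precision, recall):
    """Harmonic mean of precision and recall."""
    return safe_div(2 * precision * recall, precision + recall)


def averaged_report(y_true, y_pred):
    pairs = list(zip(y_true, y_pred))
    labels = sorted(set(y_true) | set(y_pred))
    per_label = {}
    tp_sum = fp_sum = fn_sum = 0
    for label in labels:
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        tp_sum, fp_sum, fn_sum = tp_sum + tp, fp_sum + fp, fn_sum + fn
        prec = safe_div(tp, tp + fp)
        rec = safe_div(tp, tp + fn)
        per_label[label] = {"precision": prec, "recall": rec,
                            "f1-score": f1(prec, rec), "support": tp + fn}

    total = sum(row["support"] for row in per_label.values())
    report = dict(per_label)

    # Micro: pool TP/FP/FN across labels, then compute each metric once.
    micro_p = safe_div(tp_sum, tp_sum + fp_sum)
    micro_r = safe_div(tp_sum, tp_sum + fn_sum)
    report["micro avg"] = {"precision": micro_p, "recall": micro_r,
                           "f1-score": f1(micro_p, micro_r), "support": total}

    # Macro: unweighted mean of the per-label metrics.
    macro = {m: sum(row[m] for row in per_label.values()) / len(per_label)
             for m in METRICS}
    report["macro avg"] = {**macro, "support": total}

    # Weighted: per-label metrics weighted by each label's support.
    weighted = {m: safe_div(sum(row[m] * row["support"] for row in per_label.values()), total)
                for m in METRICS}
    report["weighted avg"] = {**weighted, "support": total}
    return report


report = averaged_report(y_true, y_pred)
for key in ("heavy", "light", "micro avg", "macro avg", "weighted avg"):
    print(key, {m: round(v, 4) for m, v in report[key].items()})
```

Because both labels have the same support here, the macro and weighted averages coincide (precision 5/6, about 0.8333), as they do in the example output.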
Note that in binary classification, recall of the positive class is also known as “sensitivity”; recall of the negative class is “specificity”.
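A minimal sketch of that identity for a binary problem, using hypothetical labels in which 1 marks the positive class and 0 the negative class. The helper names `recall_of` and `sensitivity_specificity` are illustrative only: the first computes per-class recall, the second computes sensitivity TP / (TP + FN) and specificity TN / (TN + FP) directly from confusion-matrix counts, and the two agree.

```python
# Minimal sketch: for binary labels, recall of the positive class equals
# sensitivity TP / (TP + FN), and recall of the negative class equals
# specificity TN / (TN + FP).

# Hypothetical binary labels: 1 = positive class, 0 = negative class.
y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 0, 1, 0, 0, 1, 0, 0]


def recall_of(label, y_true, y_pred):
    """Share of samples truly labelled `label` that were predicted as `label`."""
    members = [p for t, p in zip(y_true, y_pred) if t == label]
    return sum(p == label for p in members) / len(members) if members else 0.0


def sensitivity_specificity(y_true, y_pred):
    """Compute both rates directly from confusion-matrix counts."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == 1 and p == 1 for t, p in pairs)
    fn = sum(t == 1 and p == 0 for t, p in pairs)
    tn = sum(t == 0 and p == 0 for t, p in pairs)
    fp = sum(t == 0 and p == 1 for t, p in pairs)
    return tp / (tp + fn), tn / (tn + fp)


sensitivity, specificity = sensitivity_specificity(y_true, y_pred)
print(sensitivity, recall_of(1, y_true, y_pred))  # 0.5 0.5
print(specificity, recall_of(0, y_true, y_pred))  # 0.75 0.75
```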
Example:
```python
import pandas as pd

from causalnex.evaluation import classification_report
from causalnex.network import BayesianNetwork
from causalnex.structure import StructureModel

# Structure: rush hour and weather both influence traffic.
sm = StructureModel()
sm.add_edges_from([
    ('rush_hour', 'traffic'),
    ('weather', 'traffic'),
])
bn = BayesianNetwork(sm)

# Training data used to learn the node states and CPDs.
data = pd.DataFrame({
    'rush_hour': [True, False, False, False, True, False, True],
    'weather': ['Terrible', 'Good', 'Bad', 'Good', 'Bad', 'Bad', 'Good'],
    'traffic': ['heavy', 'light', 'heavy', 'light', 'heavy', 'heavy', 'heavy'],
})
bn = bn.fit_node_states_and_cpds(data)

# Held-out data; the 'traffic' column holds the true labels.
test_data = pd.DataFrame({
    'rush_hour': [False, False, True, True],
    'weather': ['Good', 'Bad', 'Good', 'Bad'],
    'traffic': ['light', 'heavy', 'heavy', 'light'],
})

classification_report(bn, test_data, "traffic")
```

This returns:

```python
{'precision': {'macro avg': 0.8333333333333333,
               'micro avg': 0.75,
               'traffic_heavy': 0.6666666666666666,
               'traffic_light': 1.0,
               'weighted avg': 0.8333333333333333},
 'recall': {'macro avg': 0.75,
            'micro avg': 0.75,
            'traffic_heavy': 1.0,
            'traffic_light': 0.5,
            'weighted avg': 0.75},
 'f1-score': {'macro avg': 0.7333333333333334,
              'micro avg': 0.75,
              'traffic_heavy': 0.8,
              'traffic_light': 0.6666666666666666,
              'weighted avg': 0.7333333333333334},
 'support': {'macro avg': 4,
             'micro avg': 4,
             'traffic_heavy': 2,
             'traffic_light': 2,
             'weighted avg': 4}}
```
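The returned dictionary is keyed by metric first and class label second, as the example output shows. When a per-class view is more convenient, a small helper such as the hypothetical `by_label` below can transpose it. This is a plain-Python sketch operating on the example output, not a function provided by causalnex.

```python
# Report copied from the example output above (metric -> label -> value).
report = {
    'precision': {'macro avg': 0.8333333333333333, 'micro avg': 0.75,
                  'traffic_heavy': 0.6666666666666666, 'traffic_light': 1.0,
                  'weighted avg': 0.8333333333333333},
    'recall': {'macro avg': 0.75, 'micro avg': 0.75, 'traffic_heavy': 1.0,
               'traffic_light': 0.5, 'weighted avg': 0.75},
    'f1-score': {'macro avg': 0.7333333333333334, 'micro avg': 0.75,
                 'traffic_heavy': 0.8, 'traffic_light': 0.6666666666666666,
                 'weighted avg': 0.7333333333333334},
    'support': {'macro avg': 4, 'micro avg': 4, 'traffic_heavy': 2,
                'traffic_light': 2, 'weighted avg': 4},
}


def by_label(report):
    """Transpose a metric-major report into label -> {metric: value}."""
    rows = {}
    for metric, values in report.items():
        for label, value in values.items():
            rows.setdefault(label, {})[metric] = value
    return rows


rows = by_label(report)
print(rows['traffic_light'])
# {'precision': 1.0, 'recall': 0.5, 'f1-score': 0.6666666666666666, 'support': 2}
```

With pandas already imported, `pd.DataFrame(report)` gives a similar per-label view, with labels as the row index and metrics as columns.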