{ "cells": [ { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "from xai_agg import *\n", "\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import accuracy_score, roc_auc_score\n", "from sklearn.ensemble import RandomForestClassifier\n", "\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from IPython.core.display import display, HTML\n", "\n", "def display_side_by_side(dfs: list[pd.DataFrame], captions: list[str] = []):\n", " \"\"\"Display tables side by side to save vertical space\n", " Input:\n", " dfs: list of pandas.DataFrame\n", " captions: list of table captions\n", " \"\"\"\n", " output = \"\"\n", " for i, df in enumerate(dfs):\n", " caption = captions[i] if i < len(captions) else \"\"\n", " \n", " output += df.style.set_table_attributes(\"style='display:inline'\").set_caption(f\"{caption}\")._repr_html_()\n", " output += \"\\xa0\\xa0\\xa0\"\n", " display(HTML(output))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Preprocess the data\n", "1. One-hot-encode categorical variables, making sure the one-hot-encoded column names are in the format \"[FEATURE]_[CATEGORY]\"\n", "2. Make sure all column names are valid python identifiers" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0AgeSexJobHousingSaving accountsChecking accountCredit amountDurationPurposeCredit Risk
0067male2ownNaNlittle11696radio/TV1
1122female2ownlittlemoderate595148radio/TV2
2249male1ownlittleNaN209612education1
3345male2freelittlelittle788242furniture/equipment1
4453male2freelittlelittle487024car2
\n", "
" ], "text/plain": [ " Unnamed: 0 Age Sex Job Housing Saving accounts Checking account \\\n", "0 0 67 male 2 own NaN little \n", "1 1 22 female 2 own little moderate \n", "2 2 49 male 1 own little NaN \n", "3 3 45 male 2 free little little \n", "4 4 53 male 2 free little little \n", "\n", " Credit amount Duration Purpose Credit Risk \n", "0 1169 6 radio/TV 1 \n", "1 5951 48 radio/TV 2 \n", "2 2096 12 education 1 \n", "3 7882 42 furniture/equipment 1 \n", "4 4870 24 car 2 " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0AgeJobCredit amountDurationCredit Risk
count954.000000954.000000954.000000954.000000954.000000954.000000
mean476.50000035.5010481.9098533279.11215920.7809221.302935
std275.54037811.3796680.6496812853.31515812.0464830.459768
min0.00000019.0000000.000000250.0000004.0000001.000000
25%238.25000027.0000002.0000001360.25000012.0000001.000000
50%476.50000033.0000002.0000002302.50000018.0000001.000000
75%714.75000042.0000002.0000003975.25000024.0000002.000000
max953.00000075.0000003.00000018424.00000072.0000002.000000
\n", "
" ], "text/plain": [ " Unnamed: 0 Age Job Credit amount Duration \\\n", "count 954.000000 954.000000 954.000000 954.000000 954.000000 \n", "mean 476.500000 35.501048 1.909853 3279.112159 20.780922 \n", "std 275.540378 11.379668 0.649681 2853.315158 12.046483 \n", "min 0.000000 19.000000 0.000000 250.000000 4.000000 \n", "25% 238.250000 27.000000 2.000000 1360.250000 12.000000 \n", "50% 476.500000 33.000000 2.000000 2302.500000 18.000000 \n", "75% 714.750000 42.000000 2.000000 3975.250000 24.000000 \n", "max 953.000000 75.000000 3.000000 18424.000000 72.000000 \n", "\n", " Credit Risk \n", "count 954.000000 \n", "mean 1.302935 \n", "std 0.459768 \n", "min 1.000000 \n", "25% 1.000000 \n", "50% 1.000000 \n", "75% 2.000000 \n", "max 2.000000 " ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 954 entries, 0 to 953\n", "Data columns (total 11 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 Unnamed: 0 954 non-null int64 \n", " 1 Age 954 non-null int64 \n", " 2 Sex 954 non-null object\n", " 3 Job 954 non-null int64 \n", " 4 Housing 954 non-null object\n", " 5 Saving accounts 779 non-null object\n", " 6 Checking account 576 non-null object\n", " 7 Credit amount 954 non-null int64 \n", " 8 Duration 954 non-null int64 \n", " 9 Purpose 954 non-null object\n", " 10 Credit Risk 954 non-null int64 \n", "dtypes: int64(6), object(5)\n", "memory usage: 82.1+ KB\n" ] }, { "data": { "text/plain": [ "None" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Unique values of the categorical features:\n", "\t- Sex: ['male' 'female']\n", "\t- Housing: ['own' 'free' 'rent']\n", "\t- Saving accounts: [nan 'little' 'quite rich' 'rich' 'moderate']\n", "\t- Checking account: ['little' 'moderate' nan 'rich']\n", "\t- Purpose: ['radio/TV' 'education' 'furniture/equipment' 'car' 'business'\n", " 'domestic appliances' 
'repairs' 'vacation/others']\n" ] } ], "source": [ "original_data = pd.read_csv('../data/german_credit_data_updated.csv')\n", "\n", "# Dataset overview - German Credit Risk (from Kaggle):\n", "# 1. Age (numeric)\n", "# 2. Sex (text: male, female)\n", "# 3. Job (numeric: 0 - unskilled and non-resident, 1 - unskilled and resident, 2 - skilled, 3 - highly skilled)\n", "# 4. Housing (text: own, rent, or free)\n", "# 5. Saving accounts (text: little, moderate, quite rich, rich; may be missing)\n", "# 6. Checking account (text: little, moderate, rich; may be missing)\n", "# 7. Credit amount (numeric, in DM - Deutsche Mark)\n", "# 8. Duration (numeric, in months)\n", "# 9. Purpose (text: car, furniture/equipment, radio/TV, domestic appliances, repairs, education, business, vacation/others)\n", "# 10. Credit Risk (target, numeric: 1 or 2)\n", "\n", "display(original_data.head())\n", "display(original_data.describe())\n", "display(original_data.info())\n", "\n", "# Display the unique values of the categorical features:\n", "print('Unique values of the categorical features:')\n", "for col in original_data.select_dtypes(include='object'):\n", " print(f'\\t- {col}: {original_data[col].unique()}')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Categorical features: Index(['Sex', 'Job', 'Housing', 'Saving accounts', 'Checking account',\n", " 'Purpose'],\n", " dtype='object')\n", "Numerical features: Index(['Age', 'Credit amount', 'Duration'], dtype='object')\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AgeCredit_amountDurationCredit_RiskSex_femaleSex_maleJob_highlyskilledJob_skilledJob_unskilled_nonresidentJob_unskilled_resident...Checking_account_noneChecking_account_richPurpose_businessPurpose_carPurpose_domestic_appliancesPurpose_educationPurpose_furniture_equipmentPurpose_radio_TVPurpose_repairsPurpose_vacation_others
067116960010100...0000000100
1225951481100100...0000000100
2492096120010001...1000010000
3457882420010100...0000001000
4534870241010100...0001000000
\n", "

5 rows × 30 columns

\n", "
" ], "text/plain": [ " Age Credit_amount Duration Credit_Risk Sex_female Sex_male \\\n", "0 67 1169 6 0 0 1 \n", "1 22 5951 48 1 1 0 \n", "2 49 2096 12 0 0 1 \n", "3 45 7882 42 0 0 1 \n", "4 53 4870 24 1 0 1 \n", "\n", " Job_highlyskilled Job_skilled Job_unskilled_nonresident \\\n", "0 0 1 0 \n", "1 0 1 0 \n", "2 0 0 0 \n", "3 0 1 0 \n", "4 0 1 0 \n", "\n", " Job_unskilled_resident ... Checking_account_none Checking_account_rich \\\n", "0 0 ... 0 0 \n", "1 0 ... 0 0 \n", "2 1 ... 1 0 \n", "3 0 ... 0 0 \n", "4 0 ... 0 0 \n", "\n", " Purpose_business Purpose_car Purpose_domestic_appliances \\\n", "0 0 0 0 \n", "1 0 0 0 \n", "2 0 0 0 \n", "3 0 0 0 \n", "4 0 1 0 \n", "\n", " Purpose_education Purpose_furniture_equipment Purpose_radio_TV \\\n", "0 0 0 1 \n", "1 0 0 1 \n", "2 1 0 0 \n", "3 0 1 0 \n", "4 0 0 0 \n", "\n", " Purpose_repairs Purpose_vacation_others \n", "0 0 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", "[5 rows x 30 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 954 entries, 0 to 953\n", "Data columns (total 30 columns):\n", " # Column Non-Null Count Dtype\n", "--- ------ -------------- -----\n", " 0 Age 954 non-null int64\n", " 1 Credit_amount 954 non-null int64\n", " 2 Duration 954 non-null int64\n", " 3 Credit_Risk 954 non-null int64\n", " 4 Sex_female 954 non-null int64\n", " 5 Sex_male 954 non-null int64\n", " 6 Job_highlyskilled 954 non-null int64\n", " 7 Job_skilled 954 non-null int64\n", " 8 Job_unskilled_nonresident 954 non-null int64\n", " 9 Job_unskilled_resident 954 non-null int64\n", " 10 Housing_free 954 non-null int64\n", " 11 Housing_own 954 non-null int64\n", " 12 Housing_rent 954 non-null int64\n", " 13 Saving_accounts_little 954 non-null int64\n", " 14 Saving_accounts_moderate 954 non-null int64\n", " 15 Saving_accounts_none 954 non-null int64\n", " 16 Saving_accounts_quite_rich 954 non-null int64\n", " 17 Saving_accounts_rich 954 
non-null int64\n", " 18 Checking_account_little 954 non-null int64\n", " 19 Checking_account_moderate 954 non-null int64\n", " 20 Checking_account_none 954 non-null int64\n", " 21 Checking_account_rich 954 non-null int64\n", " 22 Purpose_business 954 non-null int64\n", " 23 Purpose_car 954 non-null int64\n", " 24 Purpose_domestic_appliances 954 non-null int64\n", " 25 Purpose_education 954 non-null int64\n", " 26 Purpose_furniture_equipment 954 non-null int64\n", " 27 Purpose_radio_TV 954 non-null int64\n", " 28 Purpose_repairs 954 non-null int64\n", " 29 Purpose_vacation_others 954 non-null int64\n", "dtypes: int64(30)\n", "memory usage: 223.7 KB\n" ] }, { "data": { "text/plain": [ "None" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "preprocessed_data = original_data.copy()\n", "\n", "# For savings and checking accounts, we will replace the missing values with 'none':\n", "preprocessed_data['Saving accounts'].fillna('none', inplace=True)\n", "preprocessed_data['Checking account'].fillna('none', inplace=True)\n", "\n", "# Dropping index column:\n", "preprocessed_data.drop(columns=['Unnamed: 0'], inplace=True)\n", "\n", "# Map the numeric Job codes to descriptive labels, so that Job is treated as categorical and one-hot-encoded by pd.get_dummies below:\n", "preprocessed_data[\"Job\"] = preprocessed_data[\"Job\"].map({0: 'unskilled_nonresident', 1: 'unskilled_resident',\n", " 2: 'skilled', 3: 'highlyskilled'})\n", "\n", "categorical_features = preprocessed_data.select_dtypes(include='object').columns\n", "numerical_features = preprocessed_data.select_dtypes(include='number').columns.drop('Credit Risk')\n", "print(f'Categorical features: {categorical_features}')\n", "print(f'Numerical features: {numerical_features}')\n", "\n", "preprocessed_data = pd.get_dummies(preprocessed_data, columns=categorical_features, dtype='int64')\n", "\n", "# Remapping the target variable to 0 and 1:\n", "preprocessed_data['Credit Risk'] = preprocessed_data['Credit Risk'].map({1: 0, 2: 1})\n", "\n", "# Make sure all column names are 
valid python identifiers (important for pd.query() calls):\n", "preprocessed_data.columns = preprocessed_data.columns.str.replace(' ', '_')\n", "preprocessed_data.columns = preprocessed_data.columns.str.replace('/', '_')\n", "\n", "display(preprocessed_data.head())\n", "display(preprocessed_data.info())" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "y = preprocessed_data['Credit_Risk']\n", "X = preprocessed_data.drop(columns='Credit_Risk')\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.7696335078534031\n", "ROC AUC: 0.6830357142857143\n" ] } ], "source": [ "clf = RandomForestClassifier(random_state=42)\n", "clf.fit(X_train, y_train)\n", "\n", "y_pred = clf.predict(X_test)\n", "\n", "print(f'Accuracy: {accuracy_score(y_test, y_pred)}')\n", "print(f'ROC AUC: {roc_auc_score(y_test, y_pred)}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Applying the Aggregate Explainer" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "agg_explainer = AggregatedExplainer(\n", " explainer_types=[LimeWrapper, ShapTabularTreeWrapper, AnchorWrapper], # Wrapped explainers whose explanations will be aggregated\n", " model=clf, X_train=X_train, categorical_feature_names=categorical_features, # Model and training data\n", " metrics=['nrc', 'sensitivity_spearman', 'faithfulness_corr'], # Metrics to be considered for the aggregation\n", " noise_gen_args={'encoding_dim': 5, 'epochs': 500}, # Arguments passed to the autoencoder noisy data generator\n", " evaluator_args={\"debug\": False} # Arguments passed to the evaluator class \n", ") " ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Feature importance scores:
 featurescore
0Duration0.809114
1Purpose_furniture_equipment0.568681
2Checking_account_none0.497863
3Age0.232698
4Checking_account_little0.111593
5Credit_amount0.041399
6Checking_account_moderate0.039992
7Housing_free0.027987
8Sex_female0.018932
9Saving_accounts_moderate0.015219
10Sex_male0.014925
11Job_highlyskilled0.009129
12Housing_own0.008308
13Purpose_car0.008169
14Saving_accounts_little0.006912
15Purpose_radio_TV0.006733
16Job_skilled0.006117
17Saving_accounts_none0.006015
18Job_unskilled_resident0.005921
19Housing_rent0.004658
20Job_unskilled_nonresident0.004482
21Saving_accounts_quite_rich0.002573
22Purpose_education0.002123
23Purpose_repairs0.001111
24Purpose_business0.000986
25Purpose_vacation_others0.000942
26Checking_account_rich0.000937
27Saving_accounts_rich0.000645
28Purpose_domestic_appliances0.000545
\n", "   \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Feature importance ranking:
 featurerank
0Duration1
1Purpose_furniture_equipment2
2Checking_account_none3
3Age4
4Checking_account_little5
5Credit_amount6
6Checking_account_moderate6
7Housing_free7
8Sex_female7
9Saving_accounts_moderate8
10Sex_male8
11Job_highlyskilled8
12Housing_own8
13Purpose_car8
14Saving_accounts_little8
15Purpose_radio_TV8
16Job_skilled8
17Saving_accounts_none8
18Job_unskilled_resident8
19Housing_rent9
20Job_unskilled_nonresident9
21Saving_accounts_quite_rich9
22Purpose_education9
23Purpose_repairs9
24Purpose_business9
25Purpose_vacation_others9
26Checking_account_rich9
27Saving_accounts_rich9
28Purpose_domestic_appliances9
\n", "   " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Apply the aggregate explainer on a sample instance:\n", "sample_idx = 0\n", "agg_explanation = agg_explainer.explain_instance(X_test.iloc[sample_idx])\n", "\n", "display_side_by_side([agg_explanation, get_ranked_explanation(agg_explanation)], captions=['Feature importance scores:', 'Feature importance ranking:'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Get information on the aggregate explainer's last explanation\n", "With the `get_last_explanation_info()` method, you can get a dataframe that contains each of the aggregated explanation models' performances on each of the metrics used to evaluate them. You are also given the weight each explanation model got from the MCDM algorithm, which is passed on to the rank aggregation step." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
nrcsensitivity_spearmanfaithfulness_corrweight
LimeWrapper42.5045470.8391130.1101580.533547
ShapTabularTreeWrapper43.5312260.9642050.1670300.438026
AnchorWrapper42.4917090.5706680.4488990.585786
\n", "
" ], "text/plain": [ " nrc sensitivity_spearman faithfulness_corr \\\n", "LimeWrapper 42.504547 0.839113 0.110158 \n", "ShapTabularTreeWrapper 43.531226 0.964205 0.167030 \n", "AnchorWrapper 42.491709 0.570668 0.448899 \n", "\n", " weight \n", "LimeWrapper 0.533547 \n", "ShapTabularTreeWrapper 0.438026 \n", "AnchorWrapper 0.585786 " ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "agg_explainer.get_last_explanation_info()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
LIME explanation:
 featurescore
0Checking_account_none0.060765
1Duration0.056665
2Checking_account_little0.035393
3Age0.028167
4Checking_account_moderate0.017343
5Housing_own0.013374
6Saving_accounts_little0.009622
7Credit_amount0.008884
8Housing_rent0.008133
9Sex_male0.007520
10Purpose_radio_TV0.006844
11Purpose_car0.006059
12Saving_accounts_none0.005942
13Housing_free0.005901
14Sex_female0.004537
15Saving_accounts_rich0.004407
16Purpose_education0.003162
17Job_skilled0.002704
18Saving_accounts_moderate0.002680
19Purpose_vacation_others0.002339
20Checking_account_rich0.002247
21Job_unskilled_nonresident0.001711
22Purpose_repairs0.001627
23Purpose_furniture_equipment0.001377
24Purpose_domestic_appliances0.001172
25Job_highlyskilled0.001144
26Saving_accounts_quite_rich0.001032
27Job_unskilled_resident0.000776
28Purpose_business0.000121
\n", "   \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SHAP explanation:
 featurescore
0Duration0.051965
1Checking_account_none0.048818
2Age0.044427
3Checking_account_little0.030740
4Checking_account_moderate0.025005
5Credit_amount0.018809
6Saving_accounts_moderate0.011132
7Purpose_furniture_equipment0.009065
8Sex_female0.007021
9Purpose_car0.006882
10Housing_free0.006844
11Saving_accounts_none0.004482
12Sex_male0.004218
13Job_unskilled_resident0.004134
14Saving_accounts_little0.003987
15Job_highlyskilled0.002708
16Saving_accounts_quite_rich0.001809
17Purpose_education0.001744
18Job_skilled0.001720
19Housing_own0.001668
20Purpose_repairs0.001397
21Purpose_vacation_others0.000760
22Purpose_business0.000601
23Saving_accounts_rich0.000569
24Checking_account_rich0.000486
25Purpose_radio_TV0.000356
26Housing_rent0.000178
27Purpose_domestic_appliances0.000000
28Job_unskilled_nonresident0.000000
\n", "   \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Anchor explanation:
 featurescore
0Purpose_furniture_equipment0.813893
1Age0.503277
2Duration0.433814
3Sex_female0.313237
4Housing_own0.296199
5Checking_account_little0.283093
6Saving_accounts_rich0.000000
7Purpose_repairs0.000000
8Purpose_radio_TV0.000000
9Purpose_education0.000000
10Purpose_domestic_appliances0.000000
11Purpose_car0.000000
12Purpose_business0.000000
13Checking_account_rich0.000000
14Checking_account_none0.000000
15Checking_account_moderate0.000000
16Saving_accounts_none0.000000
17Saving_accounts_quite_rich0.000000
18Credit_amount0.000000
19Saving_accounts_moderate0.000000
20Saving_accounts_little0.000000
21Housing_rent0.000000
22Housing_free0.000000
23Job_unskilled_resident0.000000
24Job_unskilled_nonresident0.000000
25Job_skilled0.000000
26Job_highlyskilled0.000000
27Sex_male0.000000
28Purpose_vacation_others0.000000
\n", "   " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display_side_by_side(agg_explainer.last_explanation_components, captions=['LIME explanation:', 'SHAP explanation:', 'Anchor explanation:'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Evaluating the aggregate explainer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The ExplanationModelEvaluator Class\n", "This class holds the definitions of all the metrics used to evaluate the explanation models. The aggregate explainer maintains an instance of this class and uses its evaluations in the aggregation process. It is designed to work with any explainer that follows the interface and behavior conventions of the `explainers.py` file." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using the internal ExplanationModelEvaluator instance\n", "Before it can be used, the ExplanationModelEvaluator class must be instantiated and its `init()` method must be called. This process is somewhat time-consuming: one of the metrics defined by this class relies on a noisy variation of the training data, and generating it requires training an autoencoder with TensorFlow.\n", "\n", "Doing this yourself is usually unnecessary, though. The `AggregatedExplainer` class maintains its own instance of the ExplanationModelEvaluator class, and that instance can be reused directly." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# ++ Usual instantiation of the ExplanationModelEvaluator class:\n", "#\n", "# evaluator = ExplanationModelEvaluator(clf, X_train, categorical_features)\n", "# evaluator.init() # Takes some time to train the autoencoder\n", "\n", "# ++ Or, grab the one maintained by the AggregatedExplainer:\n", "evaluator = agg_explainer.xai_evaluator" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [WORKAROUND] Applying the sensitivity metric to the aggregate explainer:\n", "One of the metrics defined in the ExplanationModelEvaluator class is the sensitivity metric. It works by creating several new instances of the explanation model being evaluated, each fit to a different noisy variation of the training data. Because this is very slow, the `sensitivity()` method uses multiprocessing to distribute the workload. This causes a problem when evaluating the sensitivity of the aggregate explainer itself. The aggregate explainer may use the sensitivity metric internally to perform the aggregation, so a child process would have to spawn further child processes. This is usually not allowed; for example, daemonic `multiprocessing` worker processes cannot have children.\n", "\n", "For now, to apply the sensitivity metric to the aggregate explainer, you must use a variant of the metric that does the calculation without multiprocessing. This sequential version of `sensitivity()` is provided by the `_sensitivity_sequential()` method." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9400656814449916" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "evaluator._sensitivity_sequential(\n", " agg_explainer, \n", " X_test.iloc[sample_idx],\n", " extra_explainer_params={ # Must specify everything the explainer needs to be instantiated\n", " \"explainer_types\": [LimeWrapper, ShapTabularTreeWrapper, AnchorWrapper],\n", " \"evaluator\": agg_explainer.xai_evaluator # Remember to reuse the same evaluator instance, otherwise the autoencoder will be retrained for every iteration\n", " },\n", " iterations=3,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Full evaluation of the aggregate explainer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's one way of evaluating the aggregate explainer and comparing it to the explainers whose explanations were aggregated. In this example, the aggregate explainer is evaluated with the same metrics it used internally to evaluate each of its component models. The `get_last_explanation_info()` method retrieves the metrics that were already calculated internally, so they aren't calculated twice." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "faithfulness = evaluator.faithfullness_correlation(agg_explainer, X_test.iloc[sample_idx])\n", "sensitivity = evaluator._sensitivity_sequential( # sequential version of sensitivity must be used at this time\n", " agg_explainer, X_test.iloc[sample_idx],\n", " extra_explainer_params={\n", " \"explainer_types\": [LimeWrapper, ShapTabularTreeWrapper, AnchorWrapper],\n", " \"evaluator\": agg_explainer.xai_evaluator\n", " },\n", " iterations=10\n", " )\n", "nrc = evaluator.nrc(agg_explainer, X_test.iloc[sample_idx])\n", "\n", "metrics = agg_explainer.get_last_explanation_info().drop(columns='weight')\n", "\n", "metrics.at[AggregatedExplainer.__name__, 'faithfulness_corr'] = faithfulness\n", "metrics.at[AggregatedExplainer.__name__, 'sensitivity_spearman'] = sensitivity\n", "metrics.at[AggregatedExplainer.__name__, 'nrc'] = nrc" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
nrcsensitivity_spearmanfaithfulness_corr
LimeWrapper46.1526200.8592120.355147
ShapTabularTreeWrapper42.6482010.9548430.154678
AnchorWrapper18.4428140.6686670.079319
AggregatedExplainer44.5794870.9137440.320685
\n", "
" ], "text/plain": [ " nrc sensitivity_spearman faithfulness_corr\n", "LimeWrapper 46.152620 0.859212 0.355147\n", "ShapTabularTreeWrapper 42.648201 0.954843 0.154678\n", "AnchorWrapper 18.442814 0.668667 0.079319\n", "AggregatedExplainer 44.579487 0.913744 0.320685" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "metrics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Using the `xai_agg.exp_utils.evaluate_aggregate_explainer()` function\n", "Utility function that evaluates the aggregate explainer under different settings. For each of the aggregate explainer's parameters (explainer components, MCDM algorithm, aggregation algorithm, metric set), the function accepts a list of candidate values; it iterates over every combination of those values, evaluates each configuration on `n_instances` instances, and returns the results as a list of lists of DataFrames (one DataFrame per evaluated instance, and one list of DataFrames per configuration), together with a metadata dictionary describing the evaluated instance indexes and configurations." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from xai_agg.exp_utils import evaluate_aggregate_explainer\n", "\n", "results, metadata = evaluate_aggregate_explainer(\n", " clf, X_train, X_test, categorical_features, # Model and data\n", " explainer_components_sets=[[LimeWrapper, ShapTabularTreeWrapper, AnchorWrapper]], # Wrapped explainer sets to be tested\n", " mcdm_algs=[pymcdm.methods.TOPSIS()], # MCDM algorithms to be tested\n", " aggregation_algs=[\"wsum\"], # Aggregation algorithms to be tested\n", " metrics_sets=[['nrc', 'sensitivity_spearman', 'faithfulness_corr']], # Metric sets to be tested\n", " n_instances=1, # Number of instances per setting to run the evaluation on\n", " mp_jobs=5 # Number of jobs to run in parallel (DECREASE THIS VALUE WHEN LOW RAM IS AVAILABLE)\n", ")" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[[ nrc sensitivity_spearman faithfulness_corr\n", " 
LimeWrapper 45.304455 0.838916 0.182748\n", " ShapTabularTreeWrapper 44.518230 1.000000 0.240986\n", " AnchorWrapper 35.929599 0.616926 0.326659\n", " AggregateExplainer 48.324269 0.881232 0.286450]]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'indexes': array([110]),\n", " 'configs': [{'explainer_components': [xai_agg.explainers.LimeWrapper,\n", " xai_agg.explainers.ShapTabularTreeWrapper,\n", " xai_agg.explainers.AnchorWrapper],\n", " 'metrics': ['nrc', 'sensitivity_spearman', 'faithfulness_corr'],\n", " 'mcdm_alg': ,\n", " 'aggregation_alg': 'wsum'}]}" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "experiment_run = ExperimentRun(metadata, results)\n", "\n", "display(experiment_run.results)\n", "display(experiment_run.metadata)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
nrcsensitivity_spearmanfaithfulness_corr
AggregateExplainer48.3242690.8812320.286450
AnchorWrapper35.9295990.6169260.326659
LimeWrapper45.3044550.8389160.182748
ShapTabularTreeWrapper44.5182301.0000000.240986
\n", "
" ], "text/plain": [ " nrc sensitivity_spearman faithfulness_corr\n", "AggregateExplainer 48.324269 0.881232 0.286450\n", "AnchorWrapper 35.929599 0.616926 0.326659\n", "LimeWrapper 45.304455 0.838916 0.182748\n", "ShapTabularTreeWrapper 44.518230 1.000000 0.240986" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get mean results for a specific setting:\n", "\n", "desired_setting = 0\n", "get_expconfig_mean_results(experiment_run, desired_setting)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }