SHAPを用いたモデルの解釈
なんでこのモデルがこのような予測をしたのかを説明する、解釈性は近年ますます注目されています。モデルの解釈を可能にするために様々な手法が提案されていますが、その手法の一つであるSHAP(SHapley Additive exPlanations)についてまとめます。今回はTitanicのデータセット、LIghtGBMモデルを用いて、SHAPで結果の解釈を行いました。
SHAPとは
NIPS2017の「A Unified Approach to Interpreting Model Predictions」で提案された手法です。
SHAPはモデルの予測結果に対する各特徴量の寄与度を求めるための手法で、寄与度として協力ゲーム理論のShapley Valueを用いています。 協力ゲーム理論のShapley Valueとは簡単にいうと、複数人で協力して報酬を得たときそれを適切に分配したときのそれぞれの分配額です。 特徴量をプレイヤーと見立ててこのShapley Valueを求めることで、モデルの予測結果に対する寄与度とします。(本当はSHAP値というちょっと違うものを使う)
実験
使用したライブラリ
SHAP
GBDT
LightGBM
データセット
データセットはタイタニックのデータを使用します www.kaggle.com
また今回はSHAPを使うのが目的のため、いくつか説明変数を減らして以下の変数を使います。
変数名 | 意味 |
---|---|
PClass | チケットクラス。1,2,3の順でランクが高い |
Sex | 性別 |
Age | 年齢 |
SibSp | 同乗している兄弟/配偶者の数 |
Parch | 同乗している親/子供の数 |
Fare | 料金 |
Embarked | 出港地 |
Survived | 生存フラグ(0=死亡、1=生存) |
言わずもがなですが、生き残るかどうかを予測するタスクになります。
コード&結果
ちょくちょく端折ります。
ライブラリのインポート。今回はこれだけ。
import os import shap import pandas as pd import lightgbm as lgb from sklearn.model_selection import train_test_split
データの読み込み。 訓練データのみ使用。
train = pd.read_csv(TRAINPATH)
変数を減らして訓練とバリデーションに分ける。
use_columns = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked'] X_train, X_val, y_train, y_val = train_test_split(train[use_columns], train['Survived'], test_size=0.2)
学習。パラメータは雑。
d_train = lgb.Dataset(X_train, label=y_train) d_val = lgb.Dataset(X_val, label=y_val) params = { "objective": "binary", "metric": {"binary_logloss",'auc'}, "verbose": -1, } model = lgb.train(params, d_train, 100, valid_sets= d_val, early_stopping_rounds= 50)
SHAPのExplainerは複数あります。
今回はGBDTを用いるのでTreeExplainer
を使います。
shap.initjs() explainer = shap.TreeExplainer(model) shap_values = explainer.shap_values(X_train)
1行目のshap.initjs()
はJavascriptで描画するためのおまじないです。NodeJSのインストールが必要です。
Force Plot
ForcePlotは与えられたShapValueと変数の寄与度を視覚化します。
link=logit
を設定することで出力を確率に変換しています。
これによって個別の変数がどれだけ生存確率を上げたか/下げたかが見て取れます。赤色が生存確率上昇に寄与した変数、青色が生存確率現象に寄与した変数です。
1レコードをForcePlotした結果は以下のようになります。
index = 1 #何行目をforceplotするか shap.force_plot( base_value=explainer.expected_value[1], shap_values=shap_values[1][index,:], features=X_train.iloc[index,:], link='logit', matplotlib=True )
年齢が2歳
であることが生存確率を上げていますが、Pclass=3
(一番低いチケットクラス)やSibSp=4
が生存確率を下げていることが分かります。waterfall_plot()
を用いても表示方法が変わるだけで同じようなことができます。
また、全レコードに対して可視化することもでき、その結果は以下のようになります。
特徴量の分布が似ているもの同士に並べてくれたり、SHAP ValueでSortして表示してくれたりします。
Waterfall Plot
ForcePlotの表示をわかりやすくしたものです。
値はSHAP Valueです。
index = 1 shap.waterfall_plot( expected_value=explainer.expected_value[1], shap_values=shap_values[1][index,:], features=X_train.iloc[index,:], show=True )
Dependence Plot
Dependence Plotでは横軸に実際の値、縦軸にSHAP Valueが取られています。 例えば年齢をDependence Plotしてみます。
shap.dependence_plot( ind="Age", shap_values=shap_values[1], features=X_train, interaction_index=None )
これを見ると子供であることが生存によく寄与していることなどが分かります。
Decision Plot
Decision Plotでは予測の過程を可視化することができ、より個々の影響を見ることができます。 試しに10人分plotしてみると以下のようになります。
shap.decision_plot( base_value=explainer.expected_value[1], shap_values=shap_values[1][:10,:], features=X_train.iloc[:10,:], link="logit", show=True )
どの変数でSHAP Valueが上がったのか/下がったのかが分かりやすいですねー。
Summary Plot
Summary Plot はもっと大局的に結果を見たい場合に便利です。 バイオリンプロット的なことができます。点が個々のサンプルを表し、予測結果への寄与度が大きい変数順に上から並んでいます。
shap.summary_plot( shap_values=shap_values[1], features=X_train, max_display=5 )
plot_type='bar'
とすると、シンプルに棒グラフで表示できます。
注意
SHAPは、"解きたい問題の解釈" をしているわけではなく、あくまで "学習済みモデルの解釈" をしようとしています。そのためモデルが良いモデルでなかったら、解釈も良いものでないということに注意する。