情報系大学院生の勉強メモ

主に機械学習,Python

SHAPを用いたモデルの解釈

なんでこのモデルがこのような予測をしたのかを説明する、解釈性は近年ますます注目されています。モデルの解釈を可能にするために様々な手法が提案されていますが、その手法の一つであるSHAP(SHapley Additive exPlanations)についてまとめます。今回はTitanicのデータセット、LIghtGBMモデルを用いて、SHAPで結果の解釈を行いました。

SHAPとは

NIPS2017の「A Unified Approach to Interpreting Model Predictions」で提案された手法です。

論文はこちら

SHAPはモデルの予測結果に対する各特徴量の寄与度を求めるための手法で、寄与度として協力ゲーム理論のShapley Valueを用いています。 協力ゲーム理論のShapley Valueとは簡単にいうと、複数人で協力して報酬を得たときそれを適切に分配したときのそれぞれの分配額です。 特徴量をプレイヤーと見立ててこのShapley Valueを求めることで、モデルの予測結果に対する寄与度とします。(本当はSHAP値というちょっと違うものを使う)

実験

使用したライブラリ

SHAP

github.com

GBDT

LightGBM

データセット

データセットタイタニックのデータを使用します www.kaggle.com

また今回はSHAPを使うのが目的のため、いくつか説明変数を減らして以下の変数を使います。

変数名 意味
PClass チケットクラス。1,2,3の順でランクが高い
Sex 性別
Age 年齢
SibSp 同乗している兄弟/配偶者の数
Parch 同乗している親/子供の数
Fare 料金
Embarked 出港地
Survived 生存フラグ(0=死亡、1=生存)

言わずもがなですが、生き残るかどうかを予測するタスクになります。

コード&結果

ちょくちょく端折ります。

ライブラリのインポート。今回はこれだけ。

import os
import shap
import pandas as pd
import lightgbm as lgb
from sklearn.model_selection import train_test_split

データの読み込み。 訓練データのみ使用。

train = pd.read_csv(TRAINPATH)

変数を減らして訓練とバリデーションに分ける。

use_columns = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']
X_train, X_val, y_train, y_val = train_test_split(train[use_columns], train['Survived'], test_size=0.2)

学習。パラメータは雑。

d_train = lgb.Dataset(X_train, label=y_train)
d_val = lgb.Dataset(X_val, label=y_val)

params = {
    "objective": "binary",
    "metric": {"binary_logloss",'auc'},
    "verbose": -1,
}

model = lgb.train(params, d_train, 100, valid_sets= d_val, early_stopping_rounds= 50)

SHAPのExplainerは複数あります。 今回はGBDTを用いるのでTreeExplainerを使います。

shap.initjs()
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_train)

1行目のshap.initjs()Javascriptで描画するためのおまじないです。NodeJSのインストールが必要です。

Force Plot

ForcePlotは与えられたShapValueと変数の寄与度を視覚化します。 link=logitを設定することで出力を確率に変換しています。 これによって個別の変数がどれだけ生存確率を上げたか/下げたかが見て取れます。赤色が生存確率上昇に寄与した変数、青色が生存確率現象に寄与した変数です。 1レコードをForcePlotした結果は以下のようになります。

index = 1 #何行目をforceplotするか
shap.force_plot(
    base_value=explainer.expected_value[1], 
    shap_values=shap_values[1][index,:], 
    features=X_train.iloc[index,:],
    link='logit',
    matplotlib=True
)

force_plot 年齢が2歳であることが生存確率を上げていますが、Pclass=3(一番低いチケットクラス)やSibSp=4が生存確率を下げていることが分かります。waterfall_plot()を用いても表示方法が変わるだけで同じようなことができます。

また、全レコードに対して可視化することもでき、その結果は以下のようになります。 force_plot2

特徴量の分布が似ているもの同士に並べてくれたり、SHAP ValueでSortして表示してくれたりします。

Waterfall Plot

ForcePlotの表示をわかりやすくしたものです。

値はSHAP Valueです。

index = 1
shap.waterfall_plot(
    expected_value=explainer.expected_value[1], 
    shap_values=shap_values[1][index,:], 
    features=X_train.iloc[index,:],
    show=True
)

waterfall

Dependence Plot

Dependence Plotでは横軸に実際の値、縦軸にSHAP Valueが取られています。 例えば年齢をDependence Plotしてみます。

shap.dependence_plot(
    ind="Age",
    shap_values=shap_values[1],
    features=X_train,
    interaction_index=None
)

dependence これを見ると子供であることが生存によく寄与していることなどが分かります。

Decision Plot

Decision Plotでは予測の過程を可視化することができ、より個々の影響を見ることができます。 試しに10人分plotしてみると以下のようになります。

shap.decision_plot(
    base_value=explainer.expected_value[1], 
    shap_values=shap_values[1][:10,:], 
    features=X_train.iloc[:10,:],
    link="logit",
    show=True
)

decision

どの変数でSHAP Valueが上がったのか/下がったのかが分かりやすいですねー。

Summary Plot

Summary Plot はもっと大局的に結果を見たい場合に便利です。 バイオリンプロット的なことができます。点が個々のサンプルを表し、予測結果への寄与度が大きい変数順に上から並んでいます。

shap.summary_plot(
    shap_values=shap_values[1], 
    features=X_train,
    max_display=5
)

summary

plot_type='bar'とすると、シンプルに棒グラフで表示できます。 summary_bar

注意

SHAPは、"解きたい問題の解釈" をしているわけではなく、あくまで "学習済みモデルの解釈" をしようとしています。そのためモデルが良いモデルでなかったら、解釈も良いものでないということに注意する。

まとめ

  • SHAPを用いると、「どの特徴量がどのくらい予測結果に寄与したか」によって機械学習モデルを解釈できます。

  • SHAPは協力ゲーム理論のShapley Valueの考え方を用いています。

  • SHAPのExplainerは複数あります。今回はTreeExplainerを用いて様々なプロットを行いました。