๋ฌ๋์คํผ์ฆ ์์ ์ ๋ฆฌ
< ์ด์ ๊ธ >
https://silvercoding.tistory.com/67
[Bank Marketing๋ฐ์ดํฐ ๋ถ์] 2. python ๋ถ์คํ Boosting, XGBoost ์ฌ์ฉ
๋ฌ๋์คํผ์ฆ ์์ ์ ๋ฆฌ < ์ด์ ๊ธ > https://silvercoding.tistory.com/66 https://silvercoding.tistory.com/65 https://silvercoding.tistory.com/64 https://silvercoding.tistory.com/63?category=967543 https..
silvercoding.tistory.com
'๊ฒฐ๋ก ์ด ๋ฌด์์ธ์ง' ๋ฅผ ์ค๋ช
ํ๋ ๊ฒ์ ๋ฐ์ดํฐ์ฌ์ด์ธํฐ์คํธ๋ก์์ ์ค์ํ ์
๋ฌด์ด๋ค.
์์ธก ๊ฒฐ๊ณผ๋ง ๋ณด๊ณ ๋ ๋ชจ๋ธ์ด ์ด๋ค ํจํด์ ์ด์ฉํ์ฌ ์์ธก์ ์คํํ๊ฒ ๋์๋์ง, ์ ๊ทธ๋ ๊ฒ ์์ธกํ๋์ง ์ค๋ช
ํ ์ ์๋ค. ๊ทธ๋ ๊ฒ ๋๋ฉด ๋ค๋ฅธ ๋ถ์ผ์ ํ์
์๋ค์ ์ ๋ขฐ๋ฅผ ์๊ฒ๋ ๊ฒ์ด๋ค.
๋น์ฆ๋์ค์ ๊ด์ ์์ ์๋ฅผ ๋ค์ด๋ณธ๋ค. ๋จธ์ ๋ฌ๋์ ํตํ์ฌ ์ํ ํฅํ์ฑ์ ์ ์์ธกํ๋ ํ๋ก์ ํธ์์ ํฅํ ์คํจ๋ผ๋ ์์ธก์ด ๋์๋ค๊ณ ํ์ ๋, ์ด๋ป๊ฒ ํฅํ์คํจ๋ฅผ ๋ง์ ๊ฒ์ด๋๊ณ ์ง๋ฌธ์ด ๋ค์ด์ฌ ์๋ ์๋ค. ๊ธฐ์กด์ ์ทจ์ฝ์ ์ ๋ณด์ํ์ง ๋ชปํ๋ค๋ฉด ๋น์ฆ๋์ค์ ๊ด์ ์์ ์๋ฏธ๊ฐ ์๋ค.
๋ฐ๋ผ์ ๊ฒฐ๊ณผ๋ฅผ ์ค๋ช ํ ์ ์๋ ๊ฒ์ ์์ฃผ ์ค์ํ๋ค. ์ด ๋ ๋ณ์์ค์๋๋ฅผ ํ์ฉํ ์ ์๋ค. ์์ธก์ ํฐ ์ํฅ์ ๋ฏธ์น ๋ณ์์, ํน์ ๋ณ์๊ฐ ์ด๋ป๊ฒ ์ํฅ์ ๋ฏธ์ณค๋์ง ์ฌ์ธํ๊ฒ ํ์ธํด๋ณผ ์ ์๋ค.
๋ณ์์ค์๋
- ๋ชจ๋ธ์ ํ์ฉํ input ๋ณ์ ์ค์์ ์ด๋ค ๊ฒ์ด target ๊ฐ์ ๊ฐ์ฅ ํฐ ์ํฅ์ ๋ฏธ์ณค๋?
- ํด๋น ์ค์๋๋ฅผ ์์นํ์ํจ ๊ฒ
- treeํ ๋ชจ๋ธ (์์ฌ๊ฒฐ์ ๋๋ฌด, ๋๋คํฌ๋ ์คํธ) ์์ ๊ณ์ฐ ๊ฐ๋ฅ
์ด์ ๊ธ์ treeํ ๋ชจ๋ธ์ธ random forest์ xgboost์์ ๋ณ์์ค์๋ ๊ณ์ฐ์ ์คํํ์๋ค.
์์ฌ๊ฒฐ์ ๋๋ฌด์์์ ๋ณ์์ค์๋
- ํด๋น input ๋ณ์๊ฐ ์์ฌ๊ฒฐ์ ๋๋ฌด์ ๊ตฌ์ถ์์ ์ผ๋ง๋ ๋ง์ด ์ฐ์ด๋
- ํด๋น ๋ณ์๋ฅผ ๊ธฐ์ค์ผ๋ก ๋ถ๊ธฐ๋ฅผ ํ์ ๋ ๊ฐ ๊ตฌ๊ฐ์ ๋ณต์ก๋๊ฐ ์ผ๋ง๋ ์ค์ด๋๋๊ฐ?
shapley ๊ฐ
: ๊ฐ ๋ณ์๊ฐ ์์ธก ๊ฒฐ๊ณผ๋ฌผ์ ์ฃผ๋ ์ํฅ๋ ฅ์ ํฌ๊ธฐ
: ํด๋น ๋ณ์๊ฐ ์ด๋ค ์ํฅ์ ์ฃผ๋๊ฐ
(์) ์ถ๊ตฌ ์ ์ A , ์ํ ํ B
- ๊ฐ ์ ์๊ฐ ํ ์ฑ์ ์ ์ฃผ๋ ์ํฅ๋ ฅ ํฌํค
- ํด๋น ์ ์๊ฐ ์ด๋ ํ ์ํฅ์ ์ฃผ๋๊ฐ
- (์ ์ A๊ฐ ์๋ ํ B์ ์น๋ฅ ) - (์ ์ A๊ฐ ์๋ ํ B์ ์น๋ฅ = 7%
shap value ์ค์ต
shap value ์ค์ต์ ์ค์ ์ ๋๊ธฐ ์ํด Xgboost ํ์ต๊น์ง ์ ์ ํ๋ ๊ทธ๋๋ก ์คํํด์ค๋ค.
๋ฐ์ดํฐ ๋ถ๋ฌ์ค๊ธฐ
import os
import pandas as pd
import numpy as np
os.chdir('./data') # ๋ณธ์ธ ๊ฒฝ๋ก
data = pd.read_csv("bank-additional-full.csv", sep = ";")
์ด์ ๊ธ์์ ์ฌ์ฉํ์๋ ์๊ธ ๊ฐ์ ์ฌ๋ถ ๋ฐ์ดํฐ์ ์ด๋ค.
data = pd.get_dummies(data, columns = ['job','marital','education','default','housing','loan','contact','month','day_of_week','poutcome'])
๋ฒ์ฃผํ ๋ณ์๋ฅผ get_dummies๋ฅผ ์ด์ฉํ์ฌ ์ํซ์ธ์ฝ๋ฉ ํด์ค๋ค.
data['y'].value_counts()
๋ถ๋ฅ ๋ชจ๋ธ์ด๊ธฐ ๋๋ฌธ์ ๋ชฉํ๋ณ์๋ ๋น์ฐํ ๋ฒ์ฃผํ ๋ณ์๋ก ๋์ด์๋ค.
data['y'] = np.where( data['y'] == 'no', 0, 1)
ํ์ง๋ง shap value ํจํค์ง๋ ๋ชฉํ๋ณ์๊ฐ ์์นํ์ด์ด์ผ ์ ์๋ํ๊ธฐ ๋๋ฌธ์ ์์นํ ์์ผ์ค๋ค.
Xgboost ํ์ต
input_var = ['age', 'duration', 'campaign', 'pdays', 'previous', 'emp.var.rate',
'cons.price.idx', 'cons.conf.idx', 'euribor3m', 'nr.employed',
'job_admin.', 'job_blue-collar', 'job_entrepreneur', 'job_housemaid',
'job_management', 'job_retired', 'job_self-employed', 'job_services',
'job_student', 'job_technician', 'job_unemployed', 'job_unknown',
'marital_divorced', 'marital_married', 'marital_single',
'marital_unknown', 'education_basic.4y', 'education_basic.6y',
'education_basic.9y', 'education_high.school', 'education_illiterate',
'education_professional.course', 'education_university.degree',
'education_unknown', 'default_no', 'default_unknown', 'default_yes',
'housing_no', 'housing_unknown', 'housing_yes', 'loan_no',
'loan_unknown', 'loan_yes', 'contact_cellular', 'contact_telephone',
'month_apr', 'month_aug', 'month_dec', 'month_jul', 'month_jun',
'month_mar', 'month_may', 'month_nov', 'month_oct', 'month_sep',
'day_of_week_fri', 'day_of_week_mon', 'day_of_week_thu',
'day_of_week_tue', 'day_of_week_wed', 'poutcome_failure',
'poutcome_nonexistent', 'poutcome_success']
y ์ปฌ๋ผ์ ์ ์ธํ ์ธํ๋ณ์๋ฅผ ๋ฆฌ์คํธ์ ๋ชจ๋ ๋ด์์ค๋ค.
from xgboost import XGBRegressor
์์นํ์ผ๋ก ์์ธก์ ์งํํ๊ธฐ ์ํด XBGRegressor ํ๊ท๋ชจ๋ธ์ ์ํฌํธ ํด์ค๋ค.
xgb = XGBRegressor( n_estimators = 300, learning_rate=0.1 )
xgb.fit(data[input_var], data['y'])
Xgboost ํ์ต์ ์งํํ๋ค.
Shap Value ์์
import shap
shap ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ import ํด์ค๋ค.
(1) ๋ณ์์ค์๋
explainer = shap.TreeExplainer(xgb)
shap_values = explainer.shap_values( data[input_var] )
shap.TreeExplainer์ ์ธ์์ ํ์ตํ ๋ชจ๋ธ xgb๋ฅผ ๋ฃ์ด ๊ฐ์ฒด๋ฅผ ์ ์ฅํด์ค๋ค. ๊ทธ๋ค์ explainer.shap_values์ ์ธ์์ ๋ฐ์ดํฐ์ ์ ์ธํ๊ฐ์ ๋ฃ์ด์ค๋ค.
shap.summary_plot( shap_values , data[input_var] , plot_type="bar" )
shap.summary_plot์ ์ฌ์ฉํ์ฌ ๋ณ์์ค์๋ ๊ทธ๋ํ๋ฅผ ๊ทธ๋ ค์ค๋ค. ๊ฐ์ฅ ๋์ ๋ณ์๋ duration์ด๋ค. duration์ ์ ํ์๊ฐ์ด๋ค. ์ ํ์๊ฐ์ ๊ธธ์ด๊ฐ ์ด ๋ชจ๋ธ์ ์์ธก์ ๊ฐ์ฅ ์ํฅ์ ๋ง์ด ๋ฏธ์น๋ค๋ ์๋ฏธ์ด๋ค.
(2) dependence plot
: ํน์ input ๋ณ์์ target ๋ณ์์์ ๊ด๊ณ๋ฅผ ํํํ๋ ๊ฒ
: ์ ์ ๊ฐ๊ฐ์ row๋ฅผ ์๋ฏธ(๋ฐ์ดํฐ ํ๊ฐ), ํ๊ฒ๋ณ์์ ๋ฏธ์น ์ํฅ = y
: ํด๋น ๋ณ์๊ฐ ์ด๋ป๊ฒ ์ํฅ์ ๋ฏธ์ณค๋์ง ์ฌ์ธํ๊ฒ ๋ณผ ์ ์๋ค.
shap.dependence_plot( 'duration' , shap_values , data[input_var] )
duration์ ๊ทธ๋ํ๋ฅผ ๋ณด๋ฉด duration์ ๋๋ถ๋ถ์ด 3000 ๋ฏธ๋ง์ ์กด์ฌํ๊ณ , ๊ทธ ์ค์์๋ duration์ด 50์ด์์ฏค ๋๋ฉด ์ข์ ์ํฅ๋ ฅ์ ๋ผ์ณ 1์ผ ๊ฐ๋ฅ์ฑ์ด ๋์์ง๋ค๊ณ ํด์๋๋ค. (shpa value for duration์ด 0๋ณด๋ค ํฐ ๋ฐ์ดํฐ๊ฐ ๋ง์)
shap.dependence_plot( 'nr.employed' , shap_values , data[input_var] )
5020์ฏค ๋๋ ์ง์ ์์ ์ํฅ๋ ฅ์ด ์์๊ฐ ๋๋ค. ๊ทธ๋ฆฌ๊ณ 5100์ด ๋์ด๊ฐ๊ณ ๋ ์์์ ์ํฅ๋ ฅ๋ฐ์ ์๋ค. (-> 0์ผ ๊ฐ๋ฅ์ฑ์ด ๋์) ๊ทธ ์ด์ ์๋ ์ํฅ๋ ฅ์ด ๋์ผ๋ฏ๋ก ์ข์ ์ํฅ๋ ฅ์ ๋ผ์น๋ค. (-> 1์ผ ๊ฐ๋ฅ์ฑ์ด ๋์)
shap.dependence_plot( 'euribor3m' , shap_values , data[input_var] )
์์์ ์์๊ฐ ๋น์ทํ๊ฒ ๋ถํฌ๋์ด์๋ ๊ฒ ๊ฐ์ ๋ณด์ธ๋ค. ์ด ์ค์์ ์์๊ฐ ์ผ๋ง ์๊ณ ์์๊ฐ ๋ง์ ๊ตฌ๊ฐ์ ์ฐพ์๋ณด๋ฉด 1.3~1.4 - 2, 4-5 ๊ฐ ์๋ค. ํด๋น ๊ตฌ๊ฐ์ผ ๋ 1์ผ ๊ฐ๋ฅ์ฑ์ด ๋๋ค๊ณ ํด์ํ ์ ์๋ค.
shap.dependence_plot( 'cons.conf.idx' , shap_values , data[input_var] )
์ ์ฒด์ ์ผ๋ก ์์๋ฅผ ์ด๋ฃจ๊ณ ์์์ ์ ์ ์๋ค. -45์ดํ์ผ ๋๋ 1์ผ ๊ฐ๋ฅ์ฑ์ด ๋์์ง๋ค๊ณ ํด์ํ ์ ์๋ค.
shap.dependence_plot( 'pdays' , shap_values , data[input_var] )
pdays๊ฐ 0์ผ๋ ๋๋ค์์ ๋ฐ์ดํฐ๊ฐ 1์ผ ๊ฐ๋ฅ์ฑ์ด ๋์์ง ๊ฒ์ด๋ผ ์์ํ ์ ์๋ค.
(3) force plot
: ํน์ ๊ฐ์ด ์ด๋ป๊ฒ ์์ธก๋์๋์ง๋ฅผ ์๊ฐํ
prediction = xgb.predict(data[input_var])
data['pred'] = prediction
shap.initjs()
shap.force_plot( explainer.expected_value , shap_values[41187] , data[input_var].iloc[41187] )
411187๋ฒ์งธ ๋ฐ์ดํฐ๋ 0.09๊ฐ ๋์๋๋ฐ, ๋จ์ด๋จ๋ฆฌ๋ ๋ณ์์ ์ฌ๋ฆฌ๋ ๋ณ์๊ฐ ๊ณจ๊ณ ๋ฃจ ๋ถํฌ๋์ด ์๋ค.
shap.force_plot( explainer.expected_value , shap_values[0] , data[input_var].iloc[41187] )
0์ ๊ฑฐ์ ๊ฐ๊น๊ฒ ์์ธก๋ 0๋ฒ์งธ ๋ฐ์ดํฐ๋ ๊ฑฐ์ ๋ชจ๋ ๋ณ์๊ฐ ์์์ ์ํฅ๋ ฅ์ ๋ผ์น ๊ฒ์ ๋ณผ ์ ์๋ค.
41183๋ฒ์งธ ๋ฐ์ดํฐ๋ ์์ ์ํฅ๋ ฅ์ด ํจ์ฌ ๋์ ๊ฒ์ ๋ณผ ์ ์๋ค. ๋ฐ๋ผ์ 0.88์ ๊ฒฐ๊ณผ๊ฐ ๋์๊ณ , ์ ๋ต์ 1๋ก, ๊ทผ์ ํ๊ฒ ๋งํ๋ค.
์ด๋ ๊ฒ shap ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ ๋ณ์๊ฐ ์์ธก์ ์ด๋ ํ ์ํฅ์ ๋ฏธ์ณค๋์ง ์ฌ์ธํ๊ฒ ์์๋ณผ ์ ์์๋ค.