7 분 소요

PCA 활용

default of credit card clients dataset

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# header=1: take the column names from the sheet's second row; the first row is skipped.
df = pd.read_excel('./datasets/pca_credit_card.xls', header=1)
# Transpose the first three records so all 25 columns can be read top-to-bottom.
df.head(n=3).transpose()
0 1 2
ID 1 2 3
LIMIT_BAL 20000 120000 90000
SEX 2 2 2
EDUCATION 2 2 2
MARRIAGE 1 2 2
AGE 24 26 34
PAY_0 2 -1 0
PAY_2 2 2 0
PAY_3 -1 0 0
PAY_4 -1 0 0
PAY_5 -2 0 0
PAY_6 -2 2 0
BILL_AMT1 3913 2682 29239
BILL_AMT2 3102 1725 14027
BILL_AMT3 689 2682 13559
BILL_AMT4 0 3272 14331
BILL_AMT5 0 3455 14948
BILL_AMT6 0 3261 15549
PAY_AMT1 0 0 1518
PAY_AMT2 689 1000 1500
PAY_AMT3 0 1000 1000
PAY_AMT4 0 1000 1000
PAY_AMT5 0 0 1000
PAY_AMT6 0 2000 5000
default payment next month 1 1 0
  • ID: ID of each client
  • LIMIT_BAL: Amount of given credit in NT dollars (includes individual and family/supplementary credit)
  • SEX: Gender (1=male, 2=female)
  • EDUCATION: (1=graduate school, 2=university, 3=high school, 4=others, 5=unknown, 6=unknown)
  • MARRIAGE: Marital status (1=married, 2=single, 3=others)
  • AGE: Age in years
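  • PAY_0: Repayment status in September, 2005 (-1=pay duly, 1=payment delay for one month, 2=payment delay for two months, … 8=payment delay for eight months, 9=payment delay for nine months and above; the values -2 and 0 also occur in the data, as in the sample rows above, but are not defined by this scale)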
  • PAY_2: Repayment status in August, 2005 (scale same as above)
  • PAY_3: Repayment status in July, 2005 (scale same as above)
  • PAY_4: Repayment status in June, 2005 (scale same as above)
  • PAY_5: Repayment status in May, 2005 (scale same as above)
  • PAY_6: Repayment status in April, 2005 (scale same as above)
  • BILL_AMT1: Amount of bill statement in September, 2005 (NT dollar)
  • BILL_AMT2: Amount of bill statement in August, 2005 (NT dollar)
  • BILL_AMT3: Amount of bill statement in July, 2005 (NT dollar)
  • BILL_AMT4: Amount of bill statement in June, 2005 (NT dollar)
  • BILL_AMT5: Amount of bill statement in May, 2005 (NT dollar)
  • BILL_AMT6: Amount of bill statement in April, 2005 (NT dollar)
  • PAY_AMT1: Amount of previous payment in September, 2005 (NT dollar)
  • PAY_AMT2: Amount of previous payment in August, 2005 (NT dollar)
  • PAY_AMT3: Amount of previous payment in July, 2005 (NT dollar)
  • PAY_AMT4: Amount of previous payment in June, 2005 (NT dollar)
  • PAY_AMT5: Amount of previous payment in May, 2005 (NT dollar)
  • PAY_AMT6: Amount of previous payment in April, 2005 (NT dollar)
  • default payment next month: Default payment (1=yes, 0=no)
# Summarize the frame: row count, per-column dtype and non-null counts, memory usage.
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 30000 entries, 0 to 29999
Data columns (total 25 columns):
 #   Column                      Non-Null Count  Dtype
---  ------                      --------------  -----
 0   ID                          30000 non-null  int64
 1   LIMIT_BAL                   30000 non-null  int64
 2   SEX                         30000 non-null  int64
 3   EDUCATION                   30000 non-null  int64
 4   MARRIAGE                    30000 non-null  int64
 5   AGE                         30000 non-null  int64
 6   PAY_0                       30000 non-null  int64
 7   PAY_2                       30000 non-null  int64
 8   PAY_3                       30000 non-null  int64
 9   PAY_4                       30000 non-null  int64
 10  PAY_5                       30000 non-null  int64
 11  PAY_6                       30000 non-null  int64
 12  BILL_AMT1                   30000 non-null  int64
 13  BILL_AMT2                   30000 non-null  int64
 14  BILL_AMT3                   30000 non-null  int64
 15  BILL_AMT4                   30000 non-null  int64
 16  BILL_AMT5                   30000 non-null  int64
 17  BILL_AMT6                   30000 non-null  int64
 18  PAY_AMT1                    30000 non-null  int64
 19  PAY_AMT2                    30000 non-null  int64
 20  PAY_AMT3                    30000 non-null  int64
 21  PAY_AMT4                    30000 non-null  int64
 22  PAY_AMT5                    30000 non-null  int64
 23  PAY_AMT6                    30000 non-null  int64
 24  default payment next month  30000 non-null  int64
dtypes: int64(25)
memory usage: 5.7 MB
# Pairwise Pearson correlation between every (all-integer) column, ID and target included.
corr_matrix = df.corr(method='pearson')
corr_matrix
ID LIMIT_BAL SEX EDUCATION MARRIAGE AGE PAY_0 PAY_2 PAY_3 PAY_4 ... BILL_AMT4 BILL_AMT5 BILL_AMT6 PAY_AMT1 PAY_AMT2 PAY_AMT3 PAY_AMT4 PAY_AMT5 PAY_AMT6 default payment next month
ID 1.000000 0.026179 0.018497 0.039177 -0.029079 0.018678 -0.030575 -0.011215 -0.018494 -0.002735 ... 0.040351 0.016705 0.016730 0.009742 0.008406 0.039151 0.007793 0.000652 0.003000 -0.013952
LIMIT_BAL 0.026179 1.000000 0.024755 -0.219161 -0.108139 0.144713 -0.271214 -0.296382 -0.286123 -0.267460 ... 0.293988 0.295562 0.290389 0.195236 0.178408 0.210167 0.203242 0.217202 0.219595 -0.153520
SEX 0.018497 0.024755 1.000000 0.014232 -0.031389 -0.090874 -0.057643 -0.070771 -0.066096 -0.060173 ... -0.021880 -0.017005 -0.016733 -0.000242 -0.001391 -0.008597 -0.002229 -0.001667 -0.002766 -0.039961
EDUCATION 0.039177 -0.219161 0.014232 1.000000 -0.143464 0.175061 0.105364 0.121566 0.114025 0.108793 ... -0.000451 -0.007567 -0.009099 -0.037456 -0.030038 -0.039943 -0.038218 -0.040358 -0.037200 0.028006
MARRIAGE -0.029079 -0.108139 -0.031389 -0.143464 1.000000 -0.414170 0.019917 0.024199 0.032688 0.033122 ... -0.023344 -0.025393 -0.021207 -0.005979 -0.008093 -0.003541 -0.012659 -0.001205 -0.006641 -0.024339
AGE 0.018678 0.144713 -0.090874 0.175061 -0.414170 1.000000 -0.039447 -0.050148 -0.053048 -0.049722 ... 0.051353 0.049345 0.047613 0.026147 0.021785 0.029247 0.021379 0.022850 0.019478 0.013890
PAY_0 -0.030575 -0.271214 -0.057643 0.105364 0.019917 -0.039447 1.000000 0.672164 0.574245 0.538841 ... 0.179125 0.180635 0.176980 -0.079269 -0.070101 -0.070561 -0.064005 -0.058190 -0.058673 0.324794
PAY_2 -0.011215 -0.296382 -0.070771 0.121566 0.024199 -0.050148 0.672164 1.000000 0.766552 0.662067 ... 0.222237 0.221348 0.219403 -0.080701 -0.058990 -0.055901 -0.046858 -0.037093 -0.036500 0.263551
PAY_3 -0.018494 -0.286123 -0.066096 0.114025 0.032688 -0.053048 0.574245 0.766552 1.000000 0.777359 ... 0.227202 0.225145 0.222327 0.001295 -0.066793 -0.053311 -0.046067 -0.035863 -0.035861 0.235253
PAY_4 -0.002735 -0.267460 -0.060173 0.108793 0.033122 -0.049722 0.538841 0.662067 0.777359 1.000000 ... 0.245917 0.242902 0.239154 -0.009362 -0.001944 -0.069235 -0.043461 -0.033590 -0.026565 0.216614
PAY_5 -0.022199 -0.249411 -0.055064 0.097520 0.035629 -0.053826 0.509426 0.622780 0.686775 0.819835 ... 0.271915 0.269783 0.262509 -0.006089 -0.003191 0.009062 -0.058299 -0.033337 -0.023027 0.204149
PAY_6 -0.020270 -0.235195 -0.044008 0.082316 0.034345 -0.048773 0.474553 0.575501 0.632684 0.716449 ... 0.266356 0.290894 0.285091 -0.001496 -0.005223 0.005834 0.019018 -0.046434 -0.025299 0.186866
BILL_AMT1 0.019389 0.285430 -0.033642 0.023581 -0.023472 0.056239 0.187068 0.234887 0.208473 0.202812 ... 0.860272 0.829779 0.802650 0.140277 0.099355 0.156887 0.158303 0.167026 0.179341 -0.019644
BILL_AMT2 0.017982 0.278314 -0.031183 0.018749 -0.021602 0.054283 0.189859 0.235257 0.237295 0.225816 ... 0.892482 0.859778 0.831594 0.280365 0.100851 0.150718 0.147398 0.157957 0.174256 -0.014193
BILL_AMT3 0.024354 0.283236 -0.024563 0.013002 -0.024909 0.053710 0.179785 0.224146 0.227494 0.244983 ... 0.923969 0.883910 0.853320 0.244335 0.316936 0.130011 0.143405 0.179712 0.182326 -0.014076
BILL_AMT4 0.040351 0.293988 -0.021880 -0.000451 -0.023344 0.051353 0.179125 0.222237 0.227202 0.245917 ... 1.000000 0.940134 0.900941 0.233012 0.207564 0.300023 0.130191 0.160433 0.177637 -0.010156
BILL_AMT5 0.016705 0.295562 -0.017005 -0.007567 -0.025393 0.049345 0.180635 0.221348 0.225145 0.242902 ... 0.940134 1.000000 0.946197 0.217031 0.181246 0.252305 0.293118 0.141574 0.164184 -0.006760
BILL_AMT6 0.016730 0.290389 -0.016733 -0.009099 -0.021207 0.047613 0.176980 0.219403 0.222327 0.239154 ... 0.900941 0.946197 1.000000 0.199965 0.172663 0.233770 0.250237 0.307729 0.115494 -0.005372
PAY_AMT1 0.009742 0.195236 -0.000242 -0.037456 -0.005979 0.026147 -0.079269 -0.080701 0.001295 -0.009362 ... 0.233012 0.217031 0.199965 1.000000 0.285576 0.252191 0.199558 0.148459 0.185735 -0.072929
PAY_AMT2 0.008406 0.178408 -0.001391 -0.030038 -0.008093 0.021785 -0.070101 -0.058990 -0.066793 -0.001944 ... 0.207564 0.181246 0.172663 0.285576 1.000000 0.244770 0.180107 0.180908 0.157634 -0.058579
PAY_AMT3 0.039151 0.210167 -0.008597 -0.039943 -0.003541 0.029247 -0.070561 -0.055901 -0.053311 -0.069235 ... 0.300023 0.252305 0.233770 0.252191 0.244770 1.000000 0.216325 0.159214 0.162740 -0.056250
PAY_AMT4 0.007793 0.203242 -0.002229 -0.038218 -0.012659 0.021379 -0.064005 -0.046858 -0.046067 -0.043461 ... 0.130191 0.293118 0.250237 0.199558 0.180107 0.216325 1.000000 0.151830 0.157834 -0.056827
PAY_AMT5 0.000652 0.217202 -0.001667 -0.040358 -0.001205 0.022850 -0.058190 -0.037093 -0.035863 -0.033590 ... 0.160433 0.141574 0.307729 0.148459 0.180908 0.159214 0.151830 1.000000 0.154896 -0.055124
PAY_AMT6 0.003000 0.219595 -0.002766 -0.037200 -0.006641 0.019478 -0.058673 -0.036500 -0.035861 -0.026565 ... 0.177637 0.164184 0.115494 0.185735 0.157634 0.162740 0.157834 0.154896 1.000000 -0.053183
default payment next month -0.013952 -0.153520 -0.039961 0.028006 -0.024339 0.013890 0.324794 0.263551 0.235253 0.216614 ... -0.010156 -0.006760 -0.005372 -0.072929 -0.058579 -0.056250 -0.056827 -0.055124 -0.053183 1.000000

25 rows × 25 columns

# Show the column names of the loaded frame.
df.columns
Index(['ID', 'LIMIT_BAL', 'SEX', 'EDUCATION', 'MARRIAGE', 'AGE', 'PAY_0',
       'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'BILL_AMT1', 'BILL_AMT2',
       'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1',
       'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6',
       'default payment next month'],
      dtype='object')
# Rank every column by its Pearson correlation with the default label, strongest positive first.
target = df['default payment next month']
target_corr = df.corrwith(target)
target_corr.sort_values(ascending=False)
default payment next month    1.000000
PAY_0                         0.324794
PAY_2                         0.263551
PAY_3                         0.235253
PAY_4                         0.216614
PAY_5                         0.204149
PAY_6                         0.186866
EDUCATION                     0.028006
AGE                           0.013890
BILL_AMT6                    -0.005372
BILL_AMT5                    -0.006760
BILL_AMT4                    -0.010156
ID                           -0.013952
BILL_AMT3                    -0.014076
BILL_AMT2                    -0.014193
BILL_AMT1                    -0.019644
MARRIAGE                     -0.024339
SEX                          -0.039961
PAY_AMT6                     -0.053183
PAY_AMT5                     -0.055124
PAY_AMT3                     -0.056250
PAY_AMT4                     -0.056827
PAY_AMT2                     -0.058579
PAY_AMT1                     -0.072929
LIMIT_BAL                    -0.153520
dtype: float64
# Annotated heatmap of the correlation matrix; the large canvas keeps the 25x25 labels legible.
fig, ax = plt.subplots(figsize=(16, 16))
sns.heatmap(corr_matrix, annot=True, fmt='.2f', cmap='Blues', ax=ax)
plt.show()

png

# Re-check the column names before building the feature matrix and label vector.
df.columns
Index(['ID', 'LIMIT_BAL', 'SEX', 'EDUCATION', 'MARRIAGE', 'AGE', 'PAY_0',
       'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'BILL_AMT1', 'BILL_AMT2',
       'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1',
       'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6',
       'default payment next month'],
      dtype='object')
from sklearn.model_selection import train_test_split
# Label column; ID is a row identifier, so both it and the label are excluded from the features.
target_col = 'default payment next month'
y = df[target_col]
X = df.drop(columns=['ID', target_col])

# Hold out 20% of the rows as a test set (fixed seed for reproducibility).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
print(X_train.shape, y_train.shape)
(24000, 23) (24000,)
  • 상관관계가 높은 BILL_AMT1 ~ BILL_AMT6 6개 열을 표준화한 뒤 PCA로 2개 주성분으로 압축했을 때의 설명 분산 비율(explained variance ratio)
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# The six monthly bill-amount columns are highly correlated with each other:
# standardize them, compress to 2 principal components, and report how much
# of their variance each component keeps.
bill_columns = [f'BILL_AMT{i}' for i in range(1, 7)]

scaler = StandardScaler()
# Fixed misspelled name (was `bill_columns_sclaed`) to match `bill_columns_scaled_pca`.
bill_columns_scaled = scaler.fit_transform(X_train[bill_columns])

pca = PCA(n_components=2)
bill_columns_scaled_pca = pca.fit_transform(bill_columns_scaled)
pca.explained_variance_ratio_
array([0.90358461, 0.05178874])
  • 상관관계가 높은 PAY_0, PAY_2 ~ PAY_6 6개 열을 표준화한 뒤 PCA로 2개 주성분으로 압축했을 때의 설명 분산 비율(explained variance ratio)
# Repayment-status columns PAY_0, PAY_2..PAY_6 (the dataset has no PAY_1):
# apply the same standardize + 2-component PCA and print the variance ratios.
pay_columns = [f'PAY_{i}' for i in range(0, 7) if i != 1]

scaler = StandardScaler()
# Fixed misspelled name (was `pay_columns_sclaed`) to match `pay_columns_scaled_pca`.
pay_columns_scaled = scaler.fit_transform(X_train[pay_columns])

pca = PCA(n_components=2)
pay_columns_scaled_pca = pca.fit_transform(pay_columns_scaled)
print(pca.explained_variance_ratio_)
[0.71700378 0.11658426]
  • 원본 데이터로 모델 성능 측정하기
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

# Baseline: 3-fold cross-validated accuracy of a random forest on the 23 raw (unscaled) features.
rf_clf = RandomForestClassifier(random_state=42)
scores = cross_val_score(
    estimator=rf_clf,
    X=X_train,
    y=y_train,
    cv=3,
    scoring='accuracy',
)
scores
array([0.8155  , 0.818125, 0.810125])
  • PCA로 압축된 데이터로 모델 성능 측정하기
# Standardize all 23 training features, then project them onto the first 8 principal components.
scaler = StandardScaler()
pca = PCA(n_components=8)

X_train_scaled = scaler.fit(X_train).transform(X_train)
X_train_scaled_pca = pca.fit_transform(X_train_scaled)
X_train_scaled_pca.shape
(24000, 8)
from sklearn.pipeline import make_pipeline

# Cross-validate scaling + 8-component PCA + random forest as ONE pipeline, so
# the scaler and PCA are refit on each fold's training part only. Scoring a
# feature matrix produced by a scaler/PCA already fitted on all of X_train
# would leak validation-fold statistics into the evaluated features.
rf_clf = RandomForestClassifier(random_state=42)
pca_rf_pipeline = make_pipeline(StandardScaler(), PCA(n_components=8), rf_clf)
scores = cross_val_score(pca_rf_pipeline, X_train, y_train, scoring='accuracy', cv=3)
scores
array([0.796875, 0.795125, 0.797125])
pca.explained_variance_ratio_ # fraction of the total variance explained by each component
array([0.28465193, 0.17799685, 0.06810077, 0.06421023, 0.04460868,
       0.04197626, 0.03963272, 0.0382991 ])
# Total fraction of the (standardized) training variance retained by all components kept in `pca`.
pca.explained_variance_ratio_.sum()
0.7594765578111686
# Shape of the PCA-projected training matrix: (n_samples, n_components).
X_train_scaled_pca.shape
(24000, 8)

Reference

댓글남기기