XAI - (3) 국소적 대리 모형(local surrogate model) - LIME(Local Interpretable Model-agnostic Explanations)

목차

    관련 글 목록

    XAI - (1) 그래프를 이용한 방법 - PDP(Partial  Dependence Plot)

    XAI - (2) 대리 모형(surrogate model)을 이용한 방법 - global surrogate model

    XAI - (3) 국소적 대리 모형(local surrogate model) - LIME(Local Interpretable Model-agnostic Explanations)


    0. Recap

    이전 글: XAI - (2) 대리 모형(surrogate model)을 이용한 방법 - global surrogate model

    이전 글에서는 전역적 대리 모형(global surrogate model)에 대해 알아보았다. 전역적 대리 모형은 관측치의 전부 또는 일부, 입력변수의 전부 또는 일부를 활용하여 블랙박스 모형을 잘 모방하는 해석 가능한 모형을 학습시킨 것이다. 블랙박스 모형 전체를 대신할 모형을 만드는 것이기 때문에 전역적(global)이라고 하는 것인데 해석 가능한 모형으로 블랙박스 모형을 얼마나 잘 모방할 수 있을지에 따라 대리 모형의 신뢰도가 달라지게 된다. 대리 모형의 신뢰도 평가는 보통 \(R^2\) 값을 통해 이루어지는데 이 값이 얼마 이상이어야 좋다고 할 수 있는지는 절대적인 기준이 있는게 아니어서 분야에 따라, 데이터에 따라, 분석을 하는 상황에 따라 천차만별일 수 있다. 그래도 \(R^2\) 값을 확인해 보는 것이 터무니 없는 설명을 하는 대리 모형을 걸러낼 수 있다는 점에서는 상당히 유용하다고 할 수 있다.


    1. 국소적 대리 모형(local surrogate model) - LIME

    예측 모형을 재개발 하다보면 기존 모형(이하 ASIS 모형)은 잘 맞추는데 새로 만든 모형(이하 TOBE 모형)에서는 예측을 잘 못하는 경우가 발생할 수 있다. 이때 모델러는 TOBE 모형이 잘 맞추지 못하는 특정 관측값을 모형이 왜 그렇게 예측하였는지 궁금해할 가능성이 높다. 이런 경우에 사용해볼 수 있는 것이 국소적 대리 모형이다. 국소적 대리 모형(local surrogate model)은 블랙박스 모형의 개별 관측치를 설명하기 위한 모형을 뜻한다. LIME(Local Interpretable Model-agnostic Explations)은 다양한 종류의 데이터(Tabular, Text, Image)에 대해 국소적 대리 모형을 생성하기 위해 제안된 방법론이다[각주:1].

    LIME의 핵심 가정은 비선형적인 패턴을 학습한 머신러닝 모형이라도 국소적으로 보면 선형 모형으로 설명할 수 있다는 것이다.

    github.com/marcotcr/lime

    위의 그림에서 분홍색과 파랑색 음영으로 구분되는 영역은 모형의 decision function을 나타내며, 비선형적인 패턴을 보이고 있다. 볼드 표시된 빨간색 십자가는 우리가 설명하고 싶어하는 데이터 포인트이다. decision boundary 주변에 보이는 크기가 다른 십자가와 동그라미는 샘플링된 데이터이고, 관심있는 데이터 포인트와의 거리에 따른 가중치를 크기로 나타낸 것이다. 결과적으로 가중치가 부여된 샘플링 데이터로 선형 모형(회색 점선으로 표시됨)을 학습시켜서 관심있는 데이터 포인트를 설명하는데 사용한다. 이 선형 모형은 전역적으로 볼 때는 아닐 수 있지만 관심있는 데이터 포인트 근방에서는 원래 모형을 잘 근사하는 모형으로 볼 수 있다.

    LIME은 블랙박스 모형에 raw data를 변형하여 투입했을 때 어떤 현상이 발생하는지를 통해 블랙박스 모형의 특정 예측을 설명한다. LIME을 통해 국소적 대리 모형을 학습시키는 과정은 다음과 같다.

    [Step.1] 블랙박스 모형의 예측 결과를 확인한 후 설명이 필요해보이는 관측값 선택한다.

    [Step.2] 입력변수 raw data를 변형시켜서 얻은 샘플들을 블랙박스 모형의 입력변수로 투입하여 예측값을 얻는다.

    [Step.3] 관심있는 관측값에 대한 근접성(proximity)에 따라 Step.2에서 얻은 샘플들에 가중치를 부여한다.

    [Step.4] Step.2에서 얻은 샘플을 입력변수로 해당 샘플을 투입하여 얻어진 블랙박스 모형의 예측값을 타겟변수로 하여 Step.3에서 구한 가중치가 적용된 해석 가능한 모형을 학습시킨다.

    [Step.5] Step.4에서 학습된 모형을 해석하여 관심있는 관측값에 대한 블랙박스 모형의 예측을 설명한다.

     

    [Step.2]에서 raw data를 변형시키는 방법은 데이터의 종류에 따라 다르다.

    • Tabula data
      • 연속형 변수: 해당 변수의 평균과 분산을 파라미터로 가지는 정규분포(\(N(\mu_{p}, \sigma_{p}^2)\) such that \(\mu_{p}, \sigma_{p}^2\) are the mean and variance of input variable \(x_p\))에서 샘플링. python lime 패키지에서의 코드 구현은 \(N(0,1)\)에서 샘플을 뽑아 표준편차를 곱하고 평균을 더하는 역산을 수행하는 것으로 되어 있음. LimeTabularExplainer 함수에서 sample_around_instance=True로 설정하면 평균을 더하는 대신 해당 인스턴스의 변수값을 더해줌.
      • 범주형 변수: 카테고리 별 분포에 따라 샘플링한 후 설명하고자 하는 관측값의 카테고리와 일치하면 1, 일치하지 않으면 0의 값을 갖는 이진 변수로 변환
    • Text data: 임의의 단어 마스킹
    • Image data: 인접 픽셀(super pixels) 마스킹

    [Step.3]에서 커널을 따로 지정해주지 않으면 exponential kernel을 사용한다.

            if kernel_width is None:
                kernel_width = np.sqrt(training_data.shape[1]) * .75
            kernel_width = float(kernel_width)
    
            if kernel is None:
                def kernel(d, kernel_width):
                    return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))

    2. Tabular Data에 LIME 적용하기

    LIME paper의 예제가 텍스트와 이미지 데이터로 되어 있는 것을 보면 Tabular 데이터 보다는 텍스트나 이미지 데이터에 적용하려고 만들어진 방법론인 것 같다. 대리 모형을 학습시키기 위한 데이터셋을 만드는 과정도 Tabular 데이터에 대해서는 뭔가 그럴듯하긴 한데 이렇게 해도 되나 싶은 생각이 든다. 하지만 평소에 텍스트나 이미지 데이터를 다룰일이 없기도 하고 크게 관심이 없어서 Tabular Data에 대한 적용 예시를 작성해보려고 한다. 

    이번 글에서도 지난글과 마찬가지로 캐글의 IEEE-CIS Fraud Detection 데이터셋을 사용하여 예제코드를 작성하였다.

    # 필요한 모듈 임포트
    import pandas as pd
    import numpy as np
    from sklearn.model_selection import train_test_split
    from lime.lime_tabular import LimeTabularExplainer
    from sklearn.preprocessing import StandardScaler
    from sklearn.preprocessing import OneHotEncoder
    
    # 데이터 전처리
    train_id = pd.read_csv('train_identity.csv')
    train_tr = pd.read_csv('train_transaction.csv')
    n_null_train_id = train_id.isnull().sum()
    n_null_train_tr = train_tr.isnull().sum()
    train_id = train_id[list(n_null_train_id[n_null_train_id==0].index)]
    train_tr = train_tr[list(n_null_train_tr[n_null_train_tr==0].index)]
    
    X_train = pd.merge(train_tr.drop(columns=['isFraud']), train_id, how='left', on='TransactionID')
    X_train['id_01'] = X_train['id_01'].fillna(0)
    X_train['id_12'] = X_train['id_12'].fillna('NULL')
    X_train = X_train.drop(columns=['TransactionID', 'TransactionDT'])
    
    cat_cols = X_train.select_dtypes('O').columns.tolist()
    num_cols = X_train.select_dtypes('number').columns.tolist()
    
    X_tr, X_vl, y_tr, y_vl = train_test_split(X_train, train_tr.isFraud, train_size=0.8)
    del train_id, train_tr, X_train
    
    # one-hot encoding for categorical variables
    ohe = OneHotEncoder(sparse=False)
    ohe.fit(X_tr[cat_cols])
    cat_cols_pp = []
    for i in range(len(cat_cols)):
        cat_cols_pp+=[cat_cols[i]+'_'+cate for cate in ohe.categories_[i]]
        
    X_tr_ohe = ohe.transform(X_tr[cat_cols])
    X_tr_ohe = pd.DataFrame(X_tr_ohe, 
                            index=X_tr.index,
                            columns=cat_cols_pp)
    X_vl_ohe = ohe.transform(X_vl[cat_cols])
    X_vl_ohe = pd.DataFrame(X_vl_ohe,
                            index=X_vl.index,
                            columns=cat_cols_pp)
    
    X_tr_pp = pd.concat([X_tr[num_cols], X_tr_ohe], axis=1)
    X_vl_pp = pd.concat([X_vl[num_cols], X_vl_ohe], axis=1)
    
    del X_tr_ohe, X_vl_ohe

     

    lime 패키지에서 사용하는 모듈에 대한 사용법은 공식문서를 참고하면 된다.

    처음에는 전처리를 알아서 해주는 줄 알고 전처리하기 전 데이터를 넣어서 실행해보았는데, training 데이터를 통으로 standard scaler에 넣는 과정이 있어서 이렇게 하면 ValueError를 반환한다.

    explainer = LimeTabularExplainer(training_data = X_tr.values, # numpy array 형태로 넣어줘야함
                                     mode='classification',
                                     training_labels = y_tr,
                                     feature_names = X_tr.columns,
                                     class_names=['Non-Fraud', 'Fraud'],
                                     categorical_features=[i for i, col in enumerate(X_tr.columns) if col in cat_cols], 
                                     categorical_names=cat_cols, 
                                     discretize_continuous = True,
                                     sample_around_instance = True)
    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    ~\AppData\Local\Temp/ipykernel_26632/4224315680.py in <module>
    ----> 1 explainer = LimeTabularExplainer(training_data = X_tr.values, # numpy array 형태로 넣어줘야함
          2                                  mode='classification',
          3                                  training_labels = y_tr,
          4                                  feature_names = X_tr.columns,
          5                                  class_names=['Non-Fraud', 'Fraud'],
    
    ~\anaconda3\envs\iml\lib\site-packages\lime\lime_tabular.py in __init__(self, training_data, mode, training_labels, feature_names, categorical_features, categorical_names, kernel_width, kernel, verbose, class_names, feature_selection, discretize_continuous, discretizer, sample_around_instance, random_state, training_data_stats)
        256         # Though set has no role to play if training data stats are provided
        257         self.scaler = sklearn.preprocessing.StandardScaler(with_mean=False)
    --> 258         self.scaler.fit(training_data)
        259         self.feature_values = {}
        260         self.feature_frequencies = {}
    
    ~\anaconda3\envs\iml\lib\site-packages\sklearn\preprocessing\_data.py in fit(self, X, y, sample_weight)
        728         # Reset internal state before fitting
        729         self._reset()
    --> 730         return self.partial_fit(X, y, sample_weight)
        731 
        732     def partial_fit(self, X, y=None, sample_weight=None):
    
    ~\anaconda3\envs\iml\lib\site-packages\sklearn\preprocessing\_data.py in partial_fit(self, X, y, sample_weight)
        764         """
        765         first_call = not hasattr(self, "n_samples_seen_")
    --> 766         X = self._validate_data(X, accept_sparse=('csr', 'csc'),
        767                                 estimator=self, dtype=FLOAT_DTYPES,
        768                                 force_all_finite='allow-nan', reset=first_call)
    
    ~\anaconda3\envs\iml\lib\site-packages\sklearn\base.py in _validate_data(self, X, y, reset, validate_separately, **check_params)
        419             out = X
        420         elif isinstance(y, str) and y == 'no_validation':
    --> 421             X = check_array(X, **check_params)
        422             out = X
        423         else:
    
    ~\anaconda3\envs\iml\lib\site-packages\sklearn\utils\validation.py in inner_f(*args, **kwargs)
         61             extra_args = len(args) - len(all_args)
         62             if extra_args <= 0:
    ---> 63                 return f(*args, **kwargs)
         64 
         65             # extra_args > 0
    
    ~\anaconda3\envs\iml\lib\site-packages\sklearn\utils\validation.py in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator)
        671                     array = array.astype(dtype, casting="unsafe", copy=False)
        672                 else:
    --> 673                     array = np.asarray(array, order=order, dtype=dtype)
        674             except ComplexWarning as complex_warning:
        675                 raise ValueError("Complex data not supported\n"
    
    ValueError: could not convert string to float: 'R'

    explainer = LimeTabularExplainer(training_data = X_tr_pp.values, # numpy array 형태로 넣어줘야함
                                     mode='classification',
                                     training_labels = y_tr,
                                     feature_names = X_tr_pp.columns,
                                     class_names=['Non-Fraud', 'Fraud'],
                                     categorical_features=[i for i, col in enumerate(X_tr_pp.columns) if col in cat_cols_pp], 
                                     categorical_names=cat_cols_pp, 
                                     discretize_continuous=True,
                                     sample_around_instance=True)

    discretize_continuous=True 옵션을 사용하면 연속형 변수를 구간화시켜서 대리모형을 학습한다. 구간화 방법을 따로 설정해줄 수도 있는데 디폴트 옵션은 사분위수를 기준으로 자르는 것이다(discretizer='quartile'). 연속형 변수를 구간화 시키는 것이 설명할 때 더 쉽기 때문에 보통 대리모형을 사용하거나 설명의 목적으로 모형을 만들때는 연속형 변수를 구간화 시키는 경우가 많은 것 같다.

    sample_around_instance=True 옵션을 사용하면 설명하고자 하는 관측치를 기준으로 Normal centering을 수행한다. 관측치를 중심으로 샘플을 뿌리는 것이 더 합리적이라는 생각이 들어 이 옵션을 사용하였다.

    y_vl[y_vl==1].head()
    
    124982    1
    311817    1
    162251    1
    404372    1
    494069    1
    Name: isFraud, dtype: int64

    설명할 관측치는 타겟이 Fraud인 건들 중에서 하나를 골랐다.

    _explainer = explainer.explain_instance(X_vl_pp.loc[124982,:], mdl.predict_proba, num_features=10)

    결과를 확인하려면 explaniner로 show_in_notebook 메소드를 사용하면 된다. 주피터 노트북에 결과를 요약하여 보여주는 함수인데 다크 테마에서는 결과물이 잘 보이지 않는 단점이 있다.

    _ = _explainer.as_pyplot_figure(label=1)

    as_pyplot_figure함수를 사용하면 대리 모형의 계수값을 수평 막대그래프로 볼 수 있다.

    LIME 방법론의 단점 중 하나는 같은 관측값에 대해서도 대리모형을 만들때마다 결과값이 크게 달라질 수 있다는 것이다.

    # 다시 실행
    _explainer = explainer.explain_instance(X_vl_pp.loc[124982,:], mdl.predict_proba, num_features=10)

    이 문제를 보완하는 가장 쉬운 방법은 샘플 사이즈를 키우는 것이다. 샘플 사이즈는 explain_instance 메소드에서 num_samples를 조절해주면 된다. 디폴트 값은 5000인데 10배인 50000으로 설정하여 다시 대리모형을 학습시켜보면

    # num_samples의 디폴트 값은 5000
    _explainer = explainer.explain_instance(X_vl_pp.loc[124982,:], 
                                            mdl.predict_proba, 
                                            num_features=10,
                                            num_samples=50000)

    # 다시 실행
    _explainer = explainer.explain_instance(X_vl_pp.loc[124982,:], 
                                            mdl.predict_proba, 
                                            num_features=10,
                                            num_samples=50000)

    샘플 사이즈를 늘려서 대리모형을 학습시키면 계수 값이 약간 다르긴 하지만 거의 비슷한 패턴을 보이는 것을 알 수 있다.


    3. 마치며

    모형 전체를 설명하려고 하지 않고, '설명하고 싶은 관측값을 정해서 그 주변을 잘 설명하는 대리모형을 만든다'라는 LIME의 컨셉은 어느 정도 설득력이 있어보인다. 또한 모형에 관계없이 사용할 수 있는 방법(model-agnostic method)이고 Tabular, Text, Image에 모두 적용할 수 있기 때문에 확장성이 좋다. 하지만 Tabular 데이터에 대한 샘플 생성시 변수간 상관관계를 고려하지 않고 정규분포에서 샘플링하는 문제가 있고[각주:2], 시뮬레이션을 시도할 때마다 설명이 크게 달라질 수 있다는 문제도 있다[각주:3]. 또한, 관측치에 대한 인접 영역을 정의하기 위해 데이터가 바뀔때마다 커널 폭을 튜닝해줘야 한다는 점도 단점이라면 단점일 수 있다.

    이번 글에서는 python으로 예제 코드를 작성하였는데 R에서의 적용이 궁금하다면 여기를 참고해보면 될 것 같다.

    1. Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin (2016), "Why Should I Trust You? Explaining the Predictions of Any Classifier", ACM's Conference on Knowledge Discovery and Data Mining, KDD2016 [본문으로]
    2. 석사 논문을 준비 할 때 변수간 상관관계를 고려해서 샘플링하는 방법을 적용해보려고 했는데 너무 마이너한 개선점이라는 피드백을 받고 다른 주제로 방향을 튼 적이 있다. [본문으로]
    3. 위의 예시에서 살펴본 것처럼 샘플 사이즈를 키우면 어느 정도 안정화시킬 수 있긴 하다. [본문으로]