본문 바로가기
AI 공부/딥러닝

TabNet 정리

by AI Sonny 2022. 12. 12.
728x90

TabNet이란?

TabNet은 정형데이터를 위한 딥러닝으로 DT-based 모델과 DNN의 장점을 계승시킨 모델이다.

 

특징

1. 전처리가 필요없고, 경사하강법을 사용하여 end-to-end 학습에 유연하게 적용이 가능하다.

2. Sequential attention을 사용하여 feature 선택의 이유를 추적할 수 있게하여 interpretability를 확보했다.

3. 다른 도메인의 회귀와 분류 데이터셋에서 매우 높은 성능을 보인다.

4. 정형 데이터셋에서 비정형 사전학습이 성능을 크게 향상시킬 수 있다.

 

사용방법

TabNetClassifierTabNetRegressor 

from pytorch_tabnet.tab_model import TabNetClassifier, TabNetRegressor

clf = TabNetClassifier()  #TabNetRegressor()
clf.fit(
  X_train, Y_train,
  eval_set=[(X_valid, y_valid)]
)
preds = clf.predict(X_test)

 

TabNetMultiTaskClassifier 

 

from pytorch_tabnet.multitask import TabNetMultiTaskClassifier
clf = TabNetMultiTaskClassifier()
clf.fit(
  X_train, Y_train,
  eval_set=[(X_valid, y_valid)]
)
preds = clf.predict(X_test)

 

eval_metric

  • binary classification metrics : 'auc', 'accuracy', 'balanced_accuracy', 'logloss'
  • multiclass classification : 'accuracy', 'balanced_accuracy', 'logloss'
  • regression: 'mse', 'mae', 'rmse', 'rmsle'

TabNet을 실제로 대회에서 써봣는데 시간이 매우 오래걸렸다.

 

따라서 데이터의 양이  많지 않으면 머신러닝으로 하는 것이 시간이나 효율적인 측면에서 좋은 것 같다.

 

그리고 TabNet은 스케일링을 따로 해줄 필요가 없는 것이 신기하였다.

 

그리고 학습하기 전에는 꼭 넘파이로(to_numpy()) 변환해주어야 한다.

 

코드는 깃허브에서 확인이 가능하다.

 

 

GitHub - DoItSon/Dacon

Contribute to DoItSon/Dacon development by creating an account on GitHub.

github.com

 

728x90

'AI 공부 > 딥러닝' 카테고리의 다른 글

pytorch 오류  (0) 2022.12.23
평가 지표  (0) 2022.11.07
딥러닝 (RNN,LSTM,GRU)  (1) 2022.10.03
딥러닝 (손실함수)  (0) 2022.10.02
딥러닝 (pytorch)  (0) 2022.10.02

댓글