728x90
TabNet이란?
TabNet은 정형데이터를 위한 딥러닝으로 DT-based 모델과 DNN의 장점을 계승시킨 모델이다.
특징
1. 전처리가 필요없고, 경사하강법을 사용하여 end-to-end 학습에 유연하게 적용이 가능하다.
2. Sequential attention을 사용하여 feature 선택의 이유를 추적할 수 있게하여 interpretability를 확보했다.
3. 다른 도메인의 회귀와 분류 데이터셋에서 매우 높은 성능을 보인다.
4. 정형 데이터셋에서 비정형 사전학습이 성능을 크게 향상시킬 수 있다.
사용방법
TabNetClassifier나 TabNetRegressor
from pytorch_tabnet.tab_model import TabNetClassifier, TabNetRegressor
clf = TabNetClassifier() #TabNetRegressor()
clf.fit(
X_train, Y_train,
eval_set=[(X_valid, y_valid)]
)
preds = clf.predict(X_test)
TabNetMultiTaskClassifier
from pytorch_tabnet.multitask import TabNetMultiTaskClassifier
clf = TabNetMultiTaskClassifier()
clf.fit(
X_train, Y_train,
eval_set=[(X_valid, y_valid)]
)
preds = clf.predict(X_test)
eval_metric
- binary classification metrics : 'auc', 'accuracy', 'balanced_accuracy', 'logloss'
- multiclass classification : 'accuracy', 'balanced_accuracy', 'logloss'
- regression: 'mse', 'mae', 'rmse', 'rmsle'
TabNet을 실제로 대회에서 써봣는데 시간이 매우 오래걸렸다.
따라서 데이터의 양이 많지 않으면 머신러닝으로 하는 것이 시간이나 효율적인 측면에서 좋은 것 같다.
그리고 TabNet은 스케일링을 따로 해줄 필요가 없는 것이 신기하였다.
그리고 학습하기 전에는 꼭 넘파이로(to_numpy()) 변환해주어야 한다.
코드는 깃허브에서 확인이 가능하다.
GitHub - DoItSon/Dacon
Contribute to DoItSon/Dacon development by creating an account on GitHub.
github.com
728x90
'AI 공부 > 딥러닝' 카테고리의 다른 글
pytorch 오류 (0) | 2022.12.23 |
---|---|
평가 지표 (0) | 2022.11.07 |
딥러닝 (RNN,LSTM,GRU) (1) | 2022.10.03 |
딥러닝 (손실함수) (0) | 2022.10.02 |
딥러닝 (pytorch) (0) | 2022.10.02 |
댓글