์๋ ํ์ธ์, ์ ๋ Python์ ๋ณ๋ ฌ ๋ฐ ๋ถ์ฐ ์ปดํจํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ธ Dask ์ ์ ์์ ๋๋ค. ์ด ์ปค๋ฎค๋ํฐ ๋ด์์ ๋ณ๋ ฌ ๊ต์ก ๋๋ ETL์ ์ํด Dask์์ XGBoost๋ฅผ ๋ฐฐํฌํ๋ ๋ฐ ํ๋ ฅํ๋ ๋ฐ ๊ด์ฌ์ด ์๋์ง ๊ถ๊ธํฉ๋๋ค.
์ด ํ๋ก์ ํธ์ ๊ด๋ จ๋ Dask์ ๋ ๊ฐ์ง ๊ตฌ์ฑ ์์๊ฐ ์์ ์ ์์ต๋๋ค.
์ฌ๊ธฐ์ ํ๋ ฅ์ ๊ด์ฌ์ด ์์ต๋๊น?
@mrocklin Dask๊ฐ sklearn๊ณผ ํตํฉ๋์ด ์๋ค๊ณ ์๊ฐํ์ต๋๋ค. sklearn ๋ํผ๊ฐ ์๋ํ๋์ง ํ์ธํ์ จ๋์?
๋ถ์ฐ ์์คํ ๊ณผ ์๋ฏธ ์๊ฒ ํตํฉํ๋ ค๋ฉด ์ผ๋ฐ์ ์ผ๋ก ๋ผ์ด๋ธ๋ฌ๋ฆฌ ์์ค์ด ์๋๋ผ ์๊ณ ๋ฆฌ์ฆ ์์ค์์ ์ํํด์ผ ํฉ๋๋ค. SKLearn๊ณผ Dask๊ฐ ์๋ก๋ฅผ ๋์ธ ์ ์๋ ๋ช ๊ฐ์ง ๋ฐฉ๋ฒ์ด ์์ง๋ง, ํน๋ณํ ๊น์ง๋ ์์ต๋๋ค.
Dask ๋ฐ์ดํฐ ํ๋ ์์ ์ข์ ์์์ด ๋ ๊ฒ์ ๋๋ค. ์ฝ๋ ๋ฒ ์ด์ค์๋ pandas ๋ฐ์ดํฐ ํ๋ ์์ด ์๋์ง ํ์ธํฉ๋๋ค. ๊ทธ๊ฒ์ด dask ๋ฐ์ดํฐ ํ๋ ์์ด ์์ํ๊ธฐ์ ์ ํฉํ ๊ณณ์ผ ์ ์์ต๋๋ค.
๋๊ตฐ๊ฐ๊ฐ ๋ฉํฐ ํ ๋ผ๋ฐ์ดํธ dask ๋ฐ์ดํฐ ํ๋ ์์ ๊ฐ์ง๊ณ ๋์ฐฉํ๋ฉด ์ด๋ป๊ฒ ๋ ๊น์? ๊ทธ๋ฅ ํ๋ค๋ก ๋ณํํ๊ณ ์งํํ์๋์? ์๋๋ฉด dask ๋ฐ์ดํฐ ํ๋ ์์ ๊ตฌ์ฑํ๋ ๋ค์ํ pandas ๋ฐ์ดํฐ ํ๋ ์์ ๊ฐ๋ฆฌํค๋ฉด์ ํด๋ฌ์คํฐ ์ ์ฒด์์ XGBoost๋ฅผ ์ง๋ฅ์ ์ผ๋ก ๋ณ๋ ฌํํ๋ ๋ฐฉ๋ฒ์ด ์์ต๋๊น?
์ฌ์ฉ์๊ฐ ๋ฐฐ์น ํฌ๊ธฐ๋ฅผ ์ง์ ํ ์ ์์ต๋๊น? ์ฌ์ฉ์๊ฐ partial_fit์ ํตํด ํํ์ ๋ฐ์ ์ ์๋ค๊ณ ์๊ฐํฉ๋๋ค.
cc @tqchen ์ ์ฝ๋์ ๋ถ์ฐ ๋ถ๋ถ์ ๋ ์ต์ํฉ๋๋ค.
xgboost์ ๋ถ์ฐ ๋ฒ์ ์ ๋ถ์ฐ ์์ ์คํ๊ธฐ์ ์ฐ๊ฒฐํ ์ ์์ผ๋ฉฐ ์ด์์ ์ผ๋ก๋ xgboost์ ๋ฐ์ดํฐ ํํฐ์ ํผ๋๋ฅผ ๊ฐ์ ธ์จ ๋ค์ ๊ณ์ํฉ๋๋ค.
@mrocklin ๊ฐ์ฅ ๊ด๋ จ์ฑ์ด ๋์ ๋ถ๋ถ์ xgboost-spark ๋ฐ xgboost-flink ๋ชจ๋์ด๋ผ๊ณ ์๊ฐํฉ๋๋ค. ์ด ๋ชจ๋์ xgboost๋ฅผ spark/flink์ mapPartition ๊ธฐ๋ฅ์ ํฌํจํฉ๋๋ค. Dask์๋ ๋น์ทํ ๊ฒ ์์ ๊ฒ ๊ฐ์์
xgboost ์ธก์ ์๊ตฌ ์ฌํญ์ XGBoost๊ฐ rabit์ ์ํ ํ๋ก์ธ์ค ๊ฐ ์ฐ๊ฒฐ์ ์ฒ๋ฆฌํ๊ณ ํด๋ผ์ด์ธํธ ์ธก์์ (๊ฐ ์์ ์ ์ฐ๊ฒฐํ๋) ์ถ์ ๊ธฐ๋ฅผ ์์ํด์ผ ํ๋ค๋ ๊ฒ์ ๋๋ค.
https://github.com/dmlc/xgboost/blob/master/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala#L112 ์์ ๊ด๋ จ ์ฝ๋๋ฅผ ์ฐธ์กฐํ์ญ์์ค.
Rabit์ ๋ค๋ฅธ ๋ถ์ฐ ์์คํ ์ ๋ด์ฅ๋๋๋ก ์ค๊ณ๋์ด ์๊ธฐ ๋๋ฌธ์ ํ์ด์ฌ ์ธก์์ ์กฐ์ ํ๋ ๊ฒ์ ๊ทธ๋ฆฌ ์ด๋ ต์ง ์์ ๊ฒ์ด๋ผ๊ณ ์๊ฐํฉ๋๋ค.
Dask์์ ๋ค๋ฅธ ๋ถ์ฐ ์์คํ ์ ์์ํ๋ ๊ฒ์ ์ผ๋ฐ์ ์ผ๋ก ๊ฝค ํ ์ ์์ต๋๋ค. ํธ์คํ ๋ถ์ฐ ์์คํ (spark/flink/dask)์์ xg-boost๋ก ๋ฐ์ดํฐ๋ฅผ ์ด๋ป๊ฒ ์ด๋ํฉ๋๊น? ์๋๋ฉด ์๊ท๋ชจ ๋ฐ์ดํฐ์ ๋ํ ๋ถ์ฐ ๊ต์ก์ ์ํ ๊ฒ์ ๋๊น?
๋ณด๋ค ๊ตฌ์ฒด์ ์ผ๋ก ๋ค์๊ณผ ๊ฐ์ ์์คํ ์ ๊ตฌ์ถํ ์์ ์ ๋๋ค.
์ด๊ฒ์ด ๋น์ ์ ๊ธฐ๋์ ์ผ์นํฉ๋๊น? ๊ด๋ จ Python API๋ฅผ ์๋ ค์ฃผ๊ธฐ ์ฝ์ต๋๊น?
์, ์ฌ๊ธฐ์์ ๊ด๋ จ ์ ๋ณด๋ฅผ ์ฐธ์กฐํ์ญ์์ค. https://github.com/dmlc/xgboost/blob/master/tests/distributed/ for python API.
์ถ๊ฐ๋ก ํด์ผ ํ ์ผ์ ๋๋ผ์ด๋ฒ ์ธก(dask๋ฅผ ๊ตฌ๋ํ๋ ์ฅ์์ผ ๊ฐ๋ฅ์ฑ์ด ์์)์์ ํ ๋ผ ์ถ์ ๊ธฐ๋ฅผ ์์ํ๋ ๊ฒ์ ๋๋ค. ์ด๊ฒ์ https://github.com/dmlc/dmlc-core ์ dmlc-submit ์คํฌ๋ฆฝํธ์์ ์ํ๋ฉ๋๋ค.
์๊ฒ ์ต๋๋ค. ์ด์ ์ ๊ฐ์๋ฅผ ์์ฑํฉ๋๋ค.
๋๋ผ์ด๋ฒ/์ค์ผ์ค๋ฌ ๋ ธ๋์์ ํ ๋ผ ์ถ์ ๊ธฐ๋ฅผ ์์ํฉ๋๋ค.
envs = {'DMLC_NUM_WORKER' : nworker,
'DMLC_NUM_SERVER' : nserver}
rabit = RabitTracker(hostIP=ip_address, nslave=num_workers)
envs.update(rabit.slave_envs())
rabit.start(args.num_workers) # manages connections in background thread
PSTracker
๋ฅผ ์์ํ๊ธฐ ์ํด ๋น์ทํ ๊ณผ์ ์ ๊ฑฐ์น ์๋ ์์ต๋๋ค. ์ด๊ฒ์ ๋์ผํ ์ค์ ์ง์ค์ ์์คํ
์ ์์ด์ผ ํฉ๋๊น ์๋๋ฉด ๋คํธ์ํฌ ๋ด์ ๋ค๋ฅธ ๊ณณ์ ์์ด์ผ ํฉ๋๊น? ์ด ๋ช ๊ฐ์ง๊ฐ ์์ด์ผ ํฉ๋๊น? ์ฌ์ฉ์๊ฐ ๊ตฌ์ฑํ ์ ์์ด์ผ ํฉ๋๊น?
๊ฒฐ๊ตญ ๋ด ํธ๋์ปค(๋ฐ pstrackers?)๊ฐ rabit ๋คํธ์ํฌ์ ๊ฐ์ ํ๊ณ ์ฐจ๋จํ๊ฒ ๋ฉ๋๋ค.
rabit.join() # join network
์์
์ ๋
ธ๋์์ ์ด๋ฌํ ํ๊ฒฝ ๋ณ์(์ผ๋ฐ dask ์ฑ๋์ ํตํด ์ด๋ํจ)๋ฅผ ๋ก์ปฌ ํ๊ฒฝ์ผ๋ก ๋คํํด์ผ ํฉ๋๋ค. ๊ทธ๋ฐ ๋ค์ xgboost.rabit.init()
๋ฅผ ํธ์ถํ๋ ๊ฒ์ผ๋ก ์ถฉ๋ถํฉ๋๋ค.
import os
os.environ.update(envs)
xgboost.rabit.init()
Rabit ์ฝ๋๋ฅผ ๋ณด๋ฉด ํ๊ฒฝ ๋ณ์๊ฐ ์ด ์ ๋ณด๋ฅผ ์ ๊ณตํ๋ ์ ์ผํ ๋ฐฉ๋ฒ์ธ ๊ฒ์ฒ๋ผ ๋ณด์ ๋๋ค. ์ด๊ฒ์ ํ์ธํ ์ ์์ต๋๊น? ํธ๋์ปค ํธ์คํธ/ํฌํธ ์ ๋ณด๋ฅผ ์ง์ ์ ๋ ฅ์ผ๋ก ์ ๊ณตํ๋ ๋ฐฉ๋ฒ์ด ์์ต๋๊น?
๊ทธ๋ฐ ๋ค์ ๋ด numpy ๋ฐฐ์ด/pandas ๋ฐ์ดํฐ ํ๋ ์/scipy ํฌ์ ๋ฐฐ์ด์ DMatrix ๊ฐ์ฒด๋ก ๋ณํํฉ๋๋ค. ์ด๊ฒ์ ๋น๊ต์ ๊ฐ๋จํด ๋ณด์ ๋๋ค. ๊ทธ๋ฌ๋ ์์ ์๋น ์ฌ๋ฌ ๋ฐ์ดํฐ ๋ฐฐ์น๊ฐ ์์ ์ ์์ต๋๋ค. ๋ ๋ง์ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ ์ ์๊ฒ ๋๋ฉด train์ ์ฌ๋ฌ ๋ฒ ํธ์ถํ๋ ๊น๋ํ ๋ฐฉ๋ฒ์ด ์์ต๋๊น? ๋ค์ ์ค์ ๋ํ ์๊ฒฌ์ด ๊ฑฑ์ ๋ฉ๋๋ค.
# Run training, all the features in training API is available.
# Currently, this script only support calling train once for fault recovery purpose.
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
ํ๋ จ์ ์์ํ๊ธฐ ์ ์ ๋ชจ๋ ๋ฐ์ดํฐ๊ฐ ๋์ฐฉํ ๋๊น์ง ๊ธฐ๋ค๋ ค์ผ ํฉ๋๊น?
์์ ๋ชจ๋ ๊ฒ์ด ์ ํํ๋ค๊ณ ๊ฐ์ ํ๋ฉด ์ฌ๋๋ค์ด ๋ฐ๋ชจ์ ์ฌ์ฉํ๋ ํ์ค ๋ถ์ฐ ๊ต์ก ์์ ๊ฐ ์์ต๋๊น?
pstracker๋ฅผ ์์ํ ํ์๊ฐ ์์ต๋๋ค.
์ค๋ ์์นจ์ ๊ฐ์ง๊ณ ๋ ์๊ฐ์ด ์ข ์์์ด์. ๊ฒฐ๊ณผ: https://github.com/mrocklin/dask-xgboost
์ง๊ธ๊น์ง๋ ๋จ์ผ ๋ฉ๋ชจ๋ฆฌ ๋ด ๋ฐ์ดํฐ ์ธํธ์ ๋ถ์ฐ ํ์ต๋ง ์ฒ๋ฆฌํฉ๋๋ค. ๋ช ๊ฐ์ง ์ง๋ฌธ์ด ์๊ฒผ์ต๋๋ค.
rabit.init
์ ์ธ์์ ์ด๋ป๊ฒ ๋งคํ๋ฉ๋๊น? rabit.init
์ ๋ํ ์์ ์
๋ ฅ ํ์์ ์ ํํ ๋ฌด์์
๋๊น? slave_envs()
์ ๊ฒฐ๊ณผ๋ฅผ rabit.init์ ์ ๋ฌํ๋ฉด ๋ชฉ๋ก์ด ํ์ํ๊ธฐ ๋๋ฌธ์ ๋ถ๋ช
ํ ์๋ํ์ง ์์ต๋๋ค. ๊ฐ ํค ์ด๋ฆ์ --key
๋ก ๋ณํํด์ผ ํฉ๋๊น, ์๋ง๋ DMLC
์ ๋์ฌ๋ฅผ ์ญ์ ํ๊ณ ์๋ฌธ์๋ก ๋ณํํด์ผ ํ ๊น์?rabit.init(['DMLC_KEY1=VALUE1', 'DMLC_KEY2=VALUE2']
์ด๊ฒ์ด ์ด๋ป๊ฒ ์ฌ์ฉ๋๋์ง์ ๋ํ ์ผ๋ฐ์ ์ผ๋ก ๋ ๊ฐ์ง ์ง๋ฌธ์ด ๋ ์์ต๋๋ค(์ ๋ XGBoost์ ๋ํ ๊ฒฝํ์ด ์๊ณ ๊ธฐ๊ณ ํ์ต์ ๋ํ ์ฝ๊ฐ์ ๊ฒฝํ๋ง ์์ต๋๋ค. ์ ๋ฌด์ง๋ฅผ ์ฉ์ํด ์ฃผ์ญ์์ค).
์ด๋ค ์ฌ์ฉ ์ฌ๋ก๊ฐ ๋ ์ผ๋ฐ์ ์ ๋๊น?
๊ฐ ์์ ์ ๋ฐ์ดํฐ์ ๋ค๋ฅธ ํํฐ์ (ํ๋ณ)์์ ์๋ํด์ผ ํ๋ฉฐ ๋์ผํ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ๋ณด๋ฉด ์ ๋ฉ๋๋ค.
์ด๊ฒ์ ์ผ๋ฐ์ ์ผ๋ก spark/flink์ ๊ฐ์ ํ๋ ์์ํฌ์ mapPartition ์์ ์ ํด๋นํฉ๋๋ค.
๋ด ๋ฐ์ดํฐ ์ธํธ์ 8๊ฐ์ ํ, 4๊ฐ์ ์ด์ด ์๋ค๊ณ ๊ฐ์ ํด ๋ณด๊ฒ ์ต๋๋ค. ๋ ๋ช ์ ์์ ์๋ฅผ ์์ํ๋ฉด
์ข์, ์ง๊ธ ๊ฑฐ๊ธฐ ์๋ ๊ฒ์ด ์กฐ๊ธ ๋ ๊นจ๋ํด์ก์ต๋๋ค. ๊ฐ ์์ ์์์ ์์ฑ๋ ๊ฒฐ๊ณผ๋ฅผ ์๋นํ ์ ์๋ ๊ธฐ๋ฅ์ด ์์ผ๋ฉด ์ข๊ฒ ์ง๋ง ์ง๊ธ์ ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ์ต๋๋ค. ํ์ฌ ์๋ฃจ์ ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
์ด ์๋ฃจ์ ์ ๊ด๋ฆฌ ๊ฐ๋ฅํ ๊ฒ์ฒ๋ผ ๋ณด์ด์ง๋ง ์ด์์ ์ด์ง๋ ์์ต๋๋ค. xgboost-python์ด ๋์ฐฉํ์ ๋ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ์๋ค์ผ ์ ์๋ค๋ฉด ํธ๋ฆฌํ ๊ฒ์ ๋๋ค. ๊ทธ๋ฌ๋ ๋ค์์ผ๋ก ํด์ผ ํ ์ผ์ ์ค์ ์์ ์๋ํด ๋ณด๋ ๊ฒ์ด๋ผ๊ณ ์๊ฐํฉ๋๋ค.
์๋ฅผ ๋ค์ด ์ธํฐ๋ท์์ ์ฐพ์๋ณด๊ฒ ์ต๋๋ค. ๋๊ตฐ๊ฐ ๋ด๊ฐ numpy ๋๋ pandas API๋ก ์ฝ๊ฒ ์์ฑํ ์ ์๋ ์ธ์์ ์ธ ๋ฌธ์ ๊ฐ ์๋ ๊ฒฝ์ฐ ํ์ํฉ๋๋ค. ๊ทธ๋๊น์ง ๋ฌด์์ ๋ฐ์ดํฐ๊ฐ ์๋ ๋ด ๋ฉํฑ์ ๊ฐ๋จํ ์๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
In [1]: import dask.dataframe as dd
In [2]: df = dd.demo.make_timeseries('2000', '2001', {'x': float, 'y': float, 'z': int}, freq='1s', partition_freq=
...: '1D') # some random time series data
In [3]: df.head()
Out[3]:
x y z
2000-01-01 00:00:00 0.778864 0.824796 977
2000-01-01 00:00:01 -0.019888 -0.173454 1023
2000-01-01 00:00:02 0.552826 0.051995 1083
2000-01-01 00:00:03 -0.761811 0.780124 959
2000-01-01 00:00:04 -0.643525 0.679375 980
In [4]: labels = df.z > 1000
In [5]: del df['z']
In [6]: df.head()
Out[6]:
x y
2000-01-01 00:00:00 0.778864 0.824796
2000-01-01 00:00:01 -0.019888 -0.173454
2000-01-01 00:00:02 0.552826 0.051995
2000-01-01 00:00:03 -0.761811 0.780124
2000-01-01 00:00:04 -0.643525 0.679375
In [7]: labels.head()
Out[7]:
2000-01-01 00:00:00 False
2000-01-01 00:00:01 True
2000-01-01 00:00:02 True
2000-01-01 00:00:03 False
2000-01-01 00:00:04 False
Name: z, dtype: bool
In [8]: from dask.distributed import Client
In [9]: c = Client() # creates a local "cluster" on my laptop
In [10]: from dask_xgboost import train
/home/mrocklin/Software/anaconda/lib/python3.5/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.
"This module will be removed in 0.20.", DeprecationWarning)
In [11]: param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'} # taken from example
In [12]: bst = train(c, param, df, labels)
/home/mrocklin/Software/anaconda/lib/python3.5/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.
"This module will be removed in 0.20.", DeprecationWarning)
/home/mrocklin/Software/anaconda/lib/python3.5/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.
"This module will be removed in 0.20.", DeprecationWarning)
/home/mrocklin/Software/anaconda/lib/python3.5/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.
"This module will be removed in 0.20.", DeprecationWarning)
/home/mrocklin/Software/anaconda/lib/python3.5/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.
"This module will be removed in 0.20.", DeprecationWarning)
[14:46:20] Tree method is automatically selected to be 'approx' for faster speed. to use old behavior(exact greedy algorithm on single machine), set tree_method to 'exact'
[14:46:20] Tree method is automatically selected to be 'approx' for faster speed. to use old behavior(exact greedy algorithm on single machine), set tree_method to 'exact'
[14:46:20] Tree method is automatically selected to be 'approx' for faster speed. to use old behavior(exact greedy algorithm on single machine), set tree_method to 'exact'
[14:46:20] Tree method is automatically selected to be 'approx' for faster speed. to use old behavior(exact greedy algorithm on single machine), set tree_method to 'exact'
In [13]: bst
Out[13]: <xgboost.core.Booster at 0x7fbaacfd17b8>
๋๊ตฌ๋ ์ง ์ดํด๋ณด๊ณ ์ถ๋ค๋ฉด ๊ด๋ จ ์ฝ๋๊ฐ ์์ต๋๋ค: https://github.com/mrocklin/dask-xgboost/blob/master/dask_xgboost/core.py
๋ด๊ฐ ๋งํ๋ฏ์ด, ๋๋ XGBoost๋ฅผ ์ฒ์ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ์๋ง๋ ๋์น ๋ถ๋ถ์ด ์์ ๊ฒ์ ๋๋ค.
์๋ํ ์ผ๋ฐ์ ์ธ ์ฅ๋๊ฐ ์๋ https://github.com/dmlc/xgboost/tree/master/demo/data ์ ์์ต๋๋ค.
libsvm ํ์์ด๋ฉฐ numpy๋ก ๊ฐ์ ธ์ค๋ ค๋ฉด ์ฝ๊ฐ์ ๊ตฌ๋ฌธ ๋ถ์์ด ํ์ํฉ๋๋ค.
๋ ํฐ ๊ฒ์ด ์์ต๋๊น(์ค์ ๋ก ํด๋ฌ์คํฐ๊ฐ ํ์ํ ๊ฒฝ์ฐ)? ์๋๋ฉด ์์์ ํฌ๊ธฐ์ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ์์ฑํ๋ ํ์ค ๋ฐฉ๋ฒ์ด ์์ต๋๊น?
๋๋ ์๋ง๋ ๋ ๋์ ์ง๋ฌธ์ "๋น์ (๋๋ ์ด ํธ๋ฅผ ์ฝ๋ ๋ค๋ฅธ ์ฌ๋)์ด ์ฌ๊ธฐ์ ๋ณด๊ณ ์ถ์ ๊ฒ์ ๋ฌด์์ ๋๊น?"์ ๋๋ค.
์ง๊ธ ๊ฑด๋ฌผ ์์ธก. ๋ชจ๋ธ์ ์์
์๋ก ๋ค์ ์ด๋ํ๊ณ (ํผํด/ํผํด ํด์ ํ๋ก์ธ์ค๋ฅผ ํตํด) ์ผ๋ถ ๋ฐ์ดํฐ์ ๋ํด bst.predict
๋ฅผ ํธ์ถํ๋ฉด ๋ค์ ์ค๋ฅ๊ฐ ๋ฐ์ํฉ๋๋ค.
Doing rabit call after Finalize
๋ด ๊ฐ์ ์ ์ด ์์ ์์ ๋ชจ๋ธ์ด ๋
๋ฆฝ์ ์ด๋ฉฐ ๋ ์ด์ rabit๋ฅผ ์ฌ์ฉํ ํ์๊ฐ ์๋ค๋ ๊ฒ์
๋๋ค. ํด๋ผ์ด์ธํธ ์ปดํจํฐ์์ ์ ์๋ํ๋ ๊ฒ ๊ฐ์ต๋๋ค. predict
๋ฅผ ํธ์ถํ ๋ ์ด ์ค๋ฅ๊ฐ ํ์๋๋ ์ด์ ๋ ๋ฌด์์
๋๊น?
์์ธก์ ์ผ๋ถ๋ ์ฌ์ ํ rabit๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ฃผ๋ก ์์ธก์๊ฐ ํ์ต๊ณผ ๊ณต์ ๋๋ ์ผ๋ถ ์ด๊ธฐํ ๋ฃจํด๊ณผ ํจ๊ป ํ์ต์๋ฅผ ๊ณ์ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ๋๋ค. ๊ฒฐ๊ตญ ์ด๊ฒ์ ์์ ๋์ด์ผ ํ์ง๋ง ํ์ฌ๋ก์๋ ๊ทธ๋ ์ต๋๋ค.
๊ณตํต ๋ฐ์ดํฐ ์ธํธ์ ๋ํด ์ ์๋ํ๋ ํ ํฅ๋ฏธ๋ก์ด ์ถ๋ฐ์ ์ด๋ผ๊ณ ์๊ฐํฉ๋๋ค.
์ด์จ๋ ์ค๊ฐ ๋ฐ์ดํฐ์ ํด๋ฌ์คํฐ๋ฅผ ์ฌ์ฉํด์ผ ํ๋ ์ด์ ๊ฐ ์์ต๋๋ค(ํด๋ฌ์คํฐ ํ๊ฒฝ์์ ์ค์ผ์ค๋ง์ด ์ฉ์ดํจ). ์ผ๋ถ pyspark ์ฌ์ฉ์๋ ์ฐ๋ฆฌ๊ฐ ์กฐ๊ธ ๊ด๊ณ ํ๋ฉด ์๋ํด ๋ณผ ์ ์์ต๋๋ค.
์ค์ ๋ก ์ค์ํ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ํ ์คํธํ๋ ๊ฒ์ ์ด๋ ค์ ์ต๋๋ค(์: 10์ต ๊ฐ์ ํ์ด ์๋ 1๊ฐ์ ๋ฐ์ดํฐ ์ธํธ ์๋). Kaggle์ ๊ด๋ จ์ฑ์ด ์์ ์ ์๋ ์ฝ ์ฒ๋ง ๊ฐ์ ํฐ ๋ฐ์ดํฐ ์ธํธ์ผ ์ ์์ต๋๋ค.
์ด ๋ฆฌํฌ์งํ ๋ฆฌ ๋ ์์ฒ๋ง ๊ฐ์ ํ๊ณผ ์์ฒ๋ง ๊ฐ์ ์ด(์ ํซ ์ธ์ฝ๋ฉ ํ ์์ฒ?)์ ์๋ ํญ๊ณต์ฌ ๋ฐ์ดํฐ ์ธํธ์ ๋ํ ์คํ์ ๋ณด์ฌ์ค๋๋ค. ๋ฒค์น๋งํฌ์ ๊ฒฝ์ฐ 100k ํ์ ์ํ์ ๊ฐ์ ธ ์์ ์ธ์์ ์ผ๋ก ์์ฑํ ๊ฒ์ฒ๋ผ ๋ณด์ ๋๋ค. ์ด ์ํ์ ๋ ํฐ ๋ฐ์ดํฐ ์ธํธ. ํ์ํ ๊ฒฝ์ฐ ์ด๋ฅผ ํ์ฅํ ์ ์์ ๊ฒ์ ๋๋ค.
๋ค์์ ๋จ์ผ ์ฝ์ด์์ pandas ๋ฐ xgboost์ ํจ๊ป ์ด ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ๋ ์์ ๋๋ค. ๋ฐ์ดํฐ ์ค๋น, ๋งค๊ฐ๋ณ์ ๋๋ ์ด๋ฅผ ์ฌ๋ฐ๋ฅด๊ฒ ์ํํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ชจ๋ ๊ถ์ฅ ์ฌํญ์ ํ์ํฉ๋๋ค.
In [1]: import pandas as pd
In [2]: df = pd.read_csv('train-0.1m.csv')
In [3]: df.head()
Out[3]:
Month DayofMonth DayOfWeek DepTime UniqueCarrier Origin Dest Distance \
0 c-8 c-21 c-7 1934 AA ATL DFW 732
1 c-4 c-20 c-3 1548 US PIT MCO 834
2 c-9 c-2 c-5 1422 XE RDU CLE 416
3 c-11 c-25 c-6 1015 OO DEN MEM 872
4 c-10 c-7 c-6 1828 WN MDW OMA 423
dep_delayed_15min
0 N
1 N
2 N
3 N
4 Y
In [4]: labels = df.dep_delayed_15min == 'Y'
In [5]: del df['dep_delayed_15min']
In [6]: df = pd.get_dummies(df)
In [7]: len(df.columns)
Out[7]: 652
In [8]: import xgboost as xgb
/home/mrocklin/Software/anaconda/lib/python3.5/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.
"This module will be removed in 0.20.", DeprecationWarning)
In [9]: dtrain = xgb.DMatrix(df, label=labels)
In [10]: param = {} # Are there better choices for parameters? I could use help here
In [11]: bst = xgb.train(param, dtrain) # or other parameters here?
[17:50:28] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 124 extra nodes, 0 pruned nodes, max_depth=6
[17:50:30] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 120 extra nodes, 0 pruned nodes, max_depth=6
[17:50:32] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 120 extra nodes, 0 pruned nodes, max_depth=6
[17:50:33] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 116 extra nodes, 0 pruned nodes, max_depth=6
[17:50:35] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 112 extra nodes, 0 pruned nodes, max_depth=6
[17:50:36] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 114 extra nodes, 0 pruned nodes, max_depth=6
[17:50:38] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 106 extra nodes, 0 pruned nodes, max_depth=6
[17:50:39] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 116 extra nodes, 0 pruned nodes, max_depth=6
[17:50:41] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 104 extra nodes, 0 pruned nodes, max_depth=6
[17:50:43] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 100 extra nodes, 0 pruned nodes, max_depth=6
In [12]: test = pd.read_csv('test.csv')
In [13]: test.head()
Out[13]:
Month DayofMonth DayOfWeek DepTime UniqueCarrier Origin Dest Distance \
0 c-7 c-25 c-3 615 YV MRY PHX 598
1 c-4 c-17 c-2 739 WN LAS HOU 1235
2 c-12 c-2 c-7 651 MQ GSP ORD 577
3 c-3 c-25 c-7 1614 WN BWI MHT 377
4 c-6 c-6 c-3 1505 UA ORD STL 258
dep_delayed_15min
0 N
1 N
2 N
3 N
4 Y
In [14]: test_labels = test.dep_delayed_15min == 'Y'
In [16]: del test['dep_delayed_15min']
In [17]: test = pd.get_dummies(test)
In [18]: len(test.columns) # oops, looks like the columns don't match up
Out[18]: 670
In [19]: dtest = xgb.DMatrix(test)
In [20]: predictions = bst.predict(dtest) # this fails because of mismatched columns
์ด์จ๋ ์ฌ๊ธฐ์ ์ต์ ์ด ์์ต๋๋ค. ํญ๊ณต์ฌ ๋ฐ์ดํฐ ์ธํธ๋ ์ ์๋ ค์ง ๊ฒ์ฒ๋ผ ๋ณด์ด๋ฉฐ ์ค์ ๋ก๋ ๋ถํธํ ์ ๋๋ก ์ปค์ง ์ ์์ต๋๋ค. ๋ค์ ๋งํ์ง๋ง ๊ธฐ๊ณ ํ์ต์ ๋ด ์ ๋ฌธ์ด ์๋๋ฏ๋ก ์ด๊ฒ์ด ์ ์ ํ์ง ์๋์ง ๋ชจ๋ฅด๊ฒ ์ต๋๋ค.
cc @TomAugspurger , ์ด์ ๋ํด ์๊ฐํ๊ณ ์๋ ์ฌ๋์ฒ๋ผ ๋ณด์ ๋๋ค.
Dask ๋ฐ ์์ธก์ ๊ดํด์๋ ํญ์ ๋๋น์ ๋ค์ ์ค์ ํ ์ ์์ต๋๋ค. ์ด๊ฒ์ ๊ฒ์ผ๋ฅธ ์ํ๋ฅผ ์ ์งํ๋ ๋์ ํ๊ฐ๋ฅผ ๊ฐ์ ํ๊ธฐ ๋๋ฌธ์ ์ฝ๊ฐ ๋ถ์ ํํ๊ฒ ๋๊ปด์ง๋๋ค. ๊ทธ๋ฌ๋ ์ด๊ฒ์ ์ฌ์ฉํ๊ธฐ์ ์ฌ๊ฐํ ์ฐจ๋จ๊ธฐ๊ฐ ์๋๋๋ค.
์์ธก์ ๋ช ๊ฐ์ง ๋ฌธ์ ๊ฐ ์์ต๋๋ค. ๋ ๊ฐ์ง ์ง๋ฌธ:
Booster.predict
๋ฅผ ์ฌ๋ฌ ๋ฒ ํธ์ถํ ์ ์์ต๋๊น?rabit.init
, Booster.predict
๋ฐ rabit.finalize
๋ฅผ ํธ์ถํ ์ ์์ต๋๊น?ํ์ฌ ์ ์ถ์ ๊ธฐ๋ฅผ ๋ง๋ค๊ณ ์์
์์ ๊ธฐ๋ณธ ์ค๋ ๋์์ rabit.init
๋ฅผ ํธ์ถํฉ๋๋ค. ์ด๊ฒ์ ์ ์๋ํฉ๋๋ค. ๊ทธ๋ฌ๋ ์์
์ ์ค๋ ๋์์ Booster.predict
๋ฅผ ํธ์ถํ๋ฉด(๊ฐ dask ์์
์๋ ๊ณ์ฐ์ ์ํด ์ค๋ ๋ ํ์ ์ ์งํฉ๋๋ค) Doing rabit call after Finalize
์ ๊ฐ์ ์ค๋ฅ๊ฐ ๋ฐ์ํฉ๋๋ค. ์ถ์ฒ ์ฌํญ์ด ์์ต๋๊น?
์์ธก์ ์ผ๋ถ๋ ์ฌ์ ํ rabit๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ฃผ๋ก ์์ธก์๊ฐ ํ์ต๊ณผ ๊ณต์ ๋๋ ์ผ๋ถ ์ด๊ธฐํ ๋ฃจํด๊ณผ ํจ๊ป ํ์ต์๋ฅผ ๊ณ์ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ๋๋ค. ๊ฒฐ๊ตญ ์ด๊ฒ์ ์์ ๋์ด์ผ ํ์ง๋ง ํ์ฌ๋ก์๋ ๊ทธ๋ ์ต๋๋ค.
๋๋ ์ด๊ฒ์ด ๊ถ๊ธํ๋ค. ํ๋ จ๋ ๋ชจ๋ธ์ ์์ ์์์ ๋ด ํด๋ผ์ด์ธํธ ์ปดํจํฐ๋ก ์ง๋ ฌํ-์ ์ก-์ญ์ง๋ ฌํํ ํ์๋ ๋๋น ๋คํธ์ํฌ๊ฐ ์๋๋ผ๋ ์ผ๋ฐ ๋ฐ์ดํฐ์์ ์ ๋๋ก ์๋ํ๋ ๊ฒ ๊ฐ์ต๋๋ค. Rabit์ผ๋ก ํ๋ จ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ Rabit ์์ด ๋ฐ์ดํฐ๋ฅผ ์์ธกํ ์ ์๋ ๊ฒ ๊ฐ์ต๋๋ค. ์ด๊ฒ์ ์์ฐ์์๋ ํ์ํ ๊ฒ ๊ฐ์ต๋๋ค. ์ฌ๊ธฐ์์ ๋๋น ํ๋ จ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ๋ ์ ์ฝ ์กฐ๊ฑด์ ๋ํด ๋ ๋งํ ์ ์์ต๋๊น?
์์ ๋ฐ์ดํฐ์ธํธ/๋ฌธ์
์์ ๋ชจ๋ ๊ฒ์ด ์ ํํ๋ค๊ณ ๊ฐ์ ํ๋ฉด ์ฌ๋๋ค์ด ๋ฐ๋ชจ์ ์ฌ์ฉํ๋ ํ์ค ๋ถ์ฐ ๊ต์ก ์์ ๊ฐ ์์ต๋๊น?
์ด ์คํ์ ๊ฒฐ๊ณผ๋ฅผ ์ฌํํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
https://github.com/Microsoft/LightGBM/wiki/Experiments#parallel -์คํ
XGBoost(#1950)์ ์๋ก์ด binning + fast hist ์ต์ ์ ์ฌ์ฉํ๋ฉด ๋น์ทํ ๊ฒฐ๊ณผ๋ฅผ ์ป์ ์ ์์ต๋๋ค.
์๋ํ ์ผ๋ฐ์ ์ธ ์ฅ๋๊ฐ ์๋ https://github.com/dmlc/xgboost/tree/master/demo/data ์ ์์ต๋๋ค.
libsvm ํ์์ด๋ฉฐ numpy๋ก ๊ฐ์ ธ์ค๋ ค๋ฉด ์ฝ๊ฐ์ ๊ตฌ๋ฌธ ๋ถ์์ด ํ์ํฉ๋๋ค.
sklearn์์ ์ด PR์ ๊ด์ฌ์ด ์์ ์ ์์ต๋๋ค. https://github.com/scikit-learn/scikit-learn/pull/935
@mrocklin ๋ชจ๋ธ ์ฌ์ฌ์ฉ์ ์ ์ฝ์ด ์์ต๋๋ค. ๋ฐ๋ผ์ ๋ถ์ฐ ๋ฒ์ ์์ ํ๋ จ๋ ๋ชจ๋ธ์ ์ง๋ ฌ ๋ฒ์ ์์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ํ์ฌ ์์ธก๊ธฐ์ ํ๊ณ(rabit๋ก ์ปดํ์ผํ ๋)๊ฐ ํ๋ จ ํจ์์ ํผํฉ๋ ๊ธฐ๋ฅ์ ๊ฐ์ง๊ณ ์๋ค๋ ๊ฒ์ ๋๋ค(๊ทธ๋์ rabit ํธ์ถ์ด ๋ฐ์ํ์ต๋๋ค).
์ด์ ๋ง์ํ์ ๋๋ก ๋ฌธ์ ์ ๋ํ ํด๊ฒฐ์ฑ
์ด ์์ ์ ์๋ค๊ณ ์๊ฐํฉ๋๋ค. ๋จ์ํ rabit.init
(์๋ฌด๊ฒ๋ ์ ๋ฌํ์ง ์๊ณ ์์ธก์๊ฐ ์ด๊ฒ์ด ์ ์ผํ ์์
์๋ผ๊ณ ์๊ฐํ๊ฒ ํจ)๋ฅผ ์ํํ๋ฉด ์์ธก์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํด์ผ ํฉ๋๋ค.
๋ค. ์ค์ ๋ก ๊ทธ๊ฒ์ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํฉ๋๋ค. dask-xgboost๋ ์ด์ ์์ธก์ ์ง์ํฉ๋๋ค: https://github.com/mrocklin/dask-xgboost/commit/827a03d96977cda8d104899c9f42f52dac446165
@tqchen ํด๊ฒฐ ๋ฐฉ๋ฒ์ ๊ฐ์ฌ๋๋ฆฝ๋๋ค!
๋ค์์ ๋ก์ปฌ ๋ฉํฑ์ ์๋ Airlines ๋ฐ์ดํฐ ์ธํธ์ ์์ ์ํ์์ dask.dataframe ๋ฐ xgboost๋ฅผ ์ฌ์ฉํ๋ ์ํฌํ๋ก์ ๋๋ค. ๋ชจ๋์๊ฒ ๊ด์ฐฎ์ ๋ณด์ด๋์? ์ฌ๊ธฐ์ ๋๋ฝ๋ XGBoost์ API ์์๊ฐ ์์ต๋๊น?
In [1]: import dask.dataframe as dd
In [2]: import dask_xgboost as dxgb
In [3]: df = dd.read_csv('train-0.1m.csv')
In [4]: df.head()
Out[4]:
Month DayofMonth DayOfWeek DepTime UniqueCarrier Origin Dest Distance \
0 c-8 c-21 c-7 1934 AA ATL DFW 732
1 c-4 c-20 c-3 1548 US PIT MCO 834
2 c-9 c-2 c-5 1422 XE RDU CLE 416
3 c-11 c-25 c-6 1015 OO DEN MEM 872
4 c-10 c-7 c-6 1828 WN MDW OMA 423
dep_delayed_15min
0 N
1 N
2 N
3 N
4 Y
In [5]: labels = df.dep_delayed_15min == 'Y'
In [6]: del df['dep_delayed_15min']
In [7]: df = df.categorize()
In [8]: df = dd.get_dummies(df)
In [9]: data_train, data_test = df.random_split([0.9, 0.1], random_state=123)
In [10]: labels_train, labels_test = labels.random_split([0.9, 0.1], random_state=123)
In [11]: from dask.distributed import Client
In [12]: client = Client() # in a large-data situation I probably should have done this before calling categorize above (which requires computation)
In [13]: param = {} # Are there better choices for parameters?
In [14]: bst = dxgb.train(client, {}, data_train, labels_train)
[14:00:46] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 120 extra nodes, 0 pruned nodes, max_depth=6
[14:00:48] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 120 extra nodes, 0 pruned nodes, max_depth=6
[14:00:50] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 122 extra nodes, 0 pruned nodes, max_depth=6
[14:00:53] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 118 extra nodes, 0 pruned nodes, max_depth=6
[14:00:55] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 120 extra nodes, 0 pruned nodes, max_depth=6
[14:00:57] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 114 extra nodes, 0 pruned nodes, max_depth=6
[14:00:59] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 118 extra nodes, 0 pruned nodes, max_depth=6
[14:01:01] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 118 extra nodes, 0 pruned nodes, max_depth=6
[14:01:04] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 94 extra nodes, 0 pruned nodes, max_depth=6
[14:01:06] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 102 extra nodes, 0 pruned nodes, max_depth=6
In [15]: bst
Out[15]: <xgboost.core.Booster at 0x7f689803af60>
In [16]: predictions = dxgb.predict(client, bst, data_test)
In [17]: predictions
Out[17]:
Dask Series Structure:
npartitions=1
None float32
None ...
Name: predictions, dtype: float32
Dask Name: _predict_part, 9 tasks
์ ๋จ๊ธฐ ๋ชฉํ๋ ์ด์ ๋ํ ์งง์ ๋ธ๋ก๊ทธ ๊ฒ์๋ฌผ์ ์์ฑํ์ฌ XGBoost์ ๋ํ ๋ ๋ง์ ๊ฒฝํ๊ณผ ๋ ๋ง์ ์๊ฐ์ ๊ฐ์ง ๋ค๋ฅธ ๋๊ตฐ๊ฐ๊ฐ ์ด ํ๋ก์ ํธ๋ฅผ ์ฑํํ๊ณ ์ถ์งํ ์ ์๋๋ก ํ๋ ๊ฒ์ ๋๋ค. (์ ๋ ์ฌ๊ธฐ ์๋ ๋ค๋ฅธ ๋ชจ๋ ์ฌ๋๋ค๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก ์ด์ ๊ฐ์ ๋ช ๊ฐ์ง ๋ค๋ฅธ ํ๋ก์ ํธ๋ฅผ ๋์์ ์งํํ๊ณ ์์ต๋๋ค.)
๋๋ ์ด๋ฏธ S3 ๋ฒํท์ ์ ์ฅ๋์ด ์๊ธฐ ๋๋ฌธ์ Airlines ๋ฐ์ดํฐ ์ธํธ์ ๋ถ๋ถ์ ์ ๋๋ค. Criteo ๋ฐ์ดํฐ ์ธํธ๊ฐ ๊ท๋ชจ์ ๋ฐ๋ผ ๋ ๋์ ๋ฐ๋ชจ๋ฅผ ์ ๊ณตํ ๊ฒ์ด๋ผ๋ ์ ์๋ ๋์ํฉ๋๋ค.
์ด๋ค ๋งค๊ฐ๋ณ์๋ฅผ ์ฌ์ฉํด์ผ ํ๋์ง ๋๋ ๊ฒฐ๊ณผ๋ฅผ ์ด๋ป๊ฒ ํ๋จํด์ผ ํ๋์ง ์์ง ์ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค. ๋งค๊ฐ๋ณ์์ ๊ฒฝ์ฐ ์ฌ๊ธฐ ์์ @szilard ์ ์คํ์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์์ธก์ ํ๋จํ๋ ์ข์ ๋ฐฉ๋ฒ์ด ์์ต๋๊น? ์๋ฅผ ๋ค์ด labels_test
์ ์ผ์นํ๋ predictions > 0.5
๋ฅผ ์ฐพ๊ณ ์์ต๋๊น?
์๋ง๋ ์ด์ง ๋ถ๋ฅ(ํนํ ์ฐ๊ตฌ ๋๋ ๊ฒฝ์ ์ค์ ์์)์ ๋ํ ์์ธก ์ฑ๋ฅ์ ํ๊ฐํ๋ ๊ฐ์ฅ ์ผ๋ฐ์ ์ธ ๋ฐฉ๋ฒ์ ROC ๊ณก์ ์๋ ์์ญ(AUC)์ ์ฌ์ฉํ๋ ๊ฒ์ด์ง๋ง ์ค์ ์์ฉ ํ๋ก๊ทธ๋จ์์๋ "๋น์ฆ๋์ค" ๊ฐ๊ณผ ์ผ์นํ๋ ๋ฉํธ๋ฆญ์ ์ฌ์ฉํด์ผ ํฉ๋๋ค. ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ ์ํฉ๋๋ค.
์๋ฅผ ๋ค์ด label_test์ ์ผ์นํ๋๋ก 0.5๋ณด๋ค ํฐ ์์ธก์ ์ฐพ๊ณ ์์ต๋๊น?
๋ค. ํ ์คํธ ์ธํธ์์ ๊ทธ ํ๊ท ์ ์ทจํ๋ฉด ์ด๊ฒ์ด ํ ์คํธ ์ ํ๋์ ๋๋ค. ๊ทธ๋ฌ๋ ๋ฐ์ดํฐ์ธํธ๊ฐ ๋ถ๊ท ํํ ๊ฐ๋ฅ์ฑ์ด ์์ต๋๋ค(ํด๋ฆญ๋ณด๋ค ํด๋ฆญ์ด ์๋ ๊ฒฝ์ฐ๊ฐ ํจ์ฌ ๋ง์). ์ด ๊ฒฝ์ฐ ROC AUC ์ ์๊ฐ ๋ ๋์ ์งํ์ ๋๋ค.
from sklearn.metrics import roc_auc_score
print(roc_auc_score(labels_test, predictions))
predictions
๊ฐ ํ
์คํธ ์ธํธ์ ๊ฐ ํ์ ๋ํด ๋ชจ๋ธ์ ์ํด ์ถ์ ๋ ์์ ํ๋ฅ ์ 1D ๋ฐฐ์ด์ด๋ผ๊ณ ๊ฐ์ ํฉ๋๋ค.
@mrocklin ํ ๊ฐ์ง ํ์ ์ง๋ฌธ์ dask๊ฐ ๋ค์ค ์ค๋ ๋ ์์ ์ ์์ ์ ํ์ฉํฉ๋๊น? ๋๋ ์ด๊ฒ์ด GIL๋ก ์ธํด ํ์ด์ฌ๊ณผ ๊ทธ๋ค์ง ๊ด๋ จ์ด ์๋ค๋ ๊ฒ์ ์๊ณ ์์ต๋๋ค. ๊ทธ๋ฌ๋ xgboost๋ ์์ ์๋น ๋ค์ค ์ค๋ ๋ ๊ต์ก์ ํ์ฉํ๋ฉด์ ์ฌ์ ํ ์๋ก ๋ถ์ฐ์ ์ผ๋ก ์กฐ์ ํ ์ ์์ต๋๋ค. ํญ์ xgboost์ nthread ์ธ์๋ฅผ ํด๋น ์์ ์์ ์์ ์ฝ์ด ์๋ก ์ค์ ํด์ผ ํฉ๋๋ค.
์งง์ ๋๋ต์ "์"์ ๋๋ค. Dask์ ๋๋ถ๋ถ์ NumPy, Pandas, SKLearn ๋ฐ Python์ผ๋ก ๋ํ๋ C ๋ฐ Fortran ์ฝ๋์ ๋ถ๊ณผํ ๊ธฐํ ํ๋ก์ ํธ์ ํจ๊ป ์ฌ์ฉ๋ฉ๋๋ค. GIL์ ์ด๋ฌํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์ํฅ์ ๋ฏธ์น์ง ์์ต๋๋ค. ์ด๋ค ์ฌ๋๋ค์ PySpark RDD( dask.bag ์ฐธ์กฐ)์ ์ ์ฌํ ์์ฉ ํ๋ก๊ทธ๋จ์ Dask๋ฅผ ์ฌ์ฉํ๊ณ ์ํฅ์ ๋ฐ์ ๊ฒ์ ๋๋ค. ์ด ๊ทธ๋ฃน์ ์์์ ์ํฉ๋๋ค.
์, Dask๋ ๋ค์ค ์ค๋ ๋ ์์ ์ ํ์ฉํฉ๋๋ค. ๋ค์ค ์ค๋ ๋๋ฅผ ์ฌ์ฉํ๋๋ก XGBoost์ ์ด๋ป๊ฒ ์ง์ํฉ๋๊น? ์ง๊ธ๊น์ง์ ์คํ์์ ๋งค๊ฐ๋ณ์๋ฅผ ๋ณ๊ฒฝํ์ง ์๊ณ ๋ ๋์ CPU ์ฌ์ฉ๋ฅ ์ ํ์ธํ๋๋ฐ ๊ธฐ๋ณธ์ ์ผ๋ก ๋ชจ๋ ๊ฒ์ด ์ ์๋ํ ๊น์?
XGBoost๋ ๊ธฐ๋ณธ์ ์ผ๋ก ๋ค์ค ์ค๋ ๋๋ฅผ ์ฌ์ฉํ๋ฉฐ nthread๊ฐ ์ค์ ๋์ง ์์ ๊ฒฝ์ฐ ๋จธ์ (ํด๋น ์์ ์ ๋์ )์์ ์ฌ์ฉ ๊ฐ๋ฅํ ๋ชจ๋ CPU ์ค๋ ๋๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ด๊ฒ์ ์ฌ๋ฌ ์์ ์๊ฐ ๋์ผํ ์์คํ ์ ํ ๋น๋ ๋ ๊ฒฝ์ ์กฐ๊ฑด์ ์์ฑํ ์ ์์ต๋๋ค.
๋ฐ๋ผ์ ํญ์ nthread ๋งค๊ฐ๋ณ์๋ฅผ ์์ ์๊ฐ ์ฌ์ฉํ ์ ์๋ ์ต๋ ์ฝ์ด ์๋ก ์ค์ ํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ์ผ๋ฐ์ ์ผ๋ก ์์ ์๋น 4๊ฐ์ ์ค๋ ๋๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
๋ฌผ๋ก ,
https://github.com/mrocklin/dask-xgboost/commit/c22d066b67c78710d5ad99b8620edc55182adc8f
2017๋
2์ 20์ผ ์์์ผ ์คํ 6์ 31๋ถ, Tianqi Chen ์๋ฆผ @github.com
์ผ๋ค:
XGBoost๋ ๊ธฐ๋ณธ์ ์ผ๋ก ๋ค์ค ์ค๋ ๋๋ฅผ ์ฌ์ฉํ๋ฉฐ ์ฌ์ฉ ๊ฐ๋ฅํ ๋ชจ๋ CPU๋ฅผ ์ฌ์ฉํฉ๋๋ค.
nthread๊ฐ ์ค์ ๋์ง ์์ ๊ฒฝ์ฐ ๋จธ์ ์ ์ค๋ ๋(ํด๋น ์์ ์ ๋์ ).
์ฌ๋ฌ ์์ ์๊ฐ ๋์ผํ ์์ ์ ํ ๋น๋ ๋ ๊ฒฝ์ ์กฐ๊ฑด์ด ์์ฑ๋ ์ ์์ต๋๋ค.
๊ธฐ๊ณ.๋ฐ๋ผ์ ํญ์ nthread ๋งค๊ฐ๋ณ์๋ฅผ ์ต๋ ๊ฐ์๋ก ์ค์ ํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
์์ ์๊ฐ ์ฌ์ฉํ ์ ์๋ ์ฝ์ด. ์ผ๋ฐ์ ์ผ๋ก ์ข์ ์ต๊ด์ say ์ฃผ์์ ์ฌ์ฉํ๋ ๊ฒ์ ๋๋ค.
์์ ์๋น ์ค๋ ๋ 4๊ฐโ
๋น์ ์ด ์ธ๊ธ๋์๊ธฐ ๋๋ฌธ์ ์ด๊ฒ์ ๋ฐ๋ ๊ฒ์ ๋๋ค.
์ด ์ด๋ฉ์ผ์ ์ง์ ๋ต์ฅํ๊ณ GitHub์์ ํ์ธํ์ธ์.
https://github.com/dmlc/xgboost/issues/2032#issuecomment-281205747 ๋๋ ์์๊ฑฐ
์ค๋ ๋
https://github.com/notifications/unsubscribe-auth/AASszPELRoeIvqEzyJhkKumIs-vd0PHiks5reiJngaJpZM4L_PXa
.
๋
ธํธ๋ถ: https://gist.github.com/19c89d78e34437e061876a9872f4d2df
์งง์ ์คํฌ๋ฆฐ์บ์คํธ(6๋ถ): https://youtu.be/Cc4E-PdDSro
๋นํ์ ํผ๋๋ฐฑ์ ๋งค์ฐ ํ์ํฉ๋๋ค. ๋ค์ ํ ๋ฒ ์ด ๋ถ์ผ์ ๋ํ ์ ์ ๋ฌด์ง๋ฅผ ์ฉ์ํด ์ฃผ์ญ์์ค.
@mrocklin ๋ฉ์ง ๋ฐ๋ชจ! param dict์์ 'tree_method': 'hist', 'grow_policy': 'lossguide'
๋ฅผ ์ฌ์ฉํ๋ฉด ๋ฐํ์ ์ฑ๋ฅ(๋ฐ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋)์ด ํฌ๊ฒ ํฅ์๋ ์ ์๋ค๊ณ ์๊ฐํฉ๋๋ค.
@ogrisel๋ ๊ฐ์ฌํฉ๋๋ค. ์ด๋ฌํ ๋งค๊ฐ๋ณ์๋ฅผ ์ฌ์ฉํ๋ฉด ๊ต์ก ์๊ฐ์ด 6๋ถ์์ 1๋ถ์ผ๋ก ๋์ด๋ฉ๋๋ค. ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ๊ฑฐ์ ๋์ผํ๊ฒ ์ ์ง๋๋ ๊ฒ ๊ฐ์ต๋๋ค.
์ข์, ์ด๊ฒ์ผ๋ก ๋์์ค์. ์ฐ๋ฆฌ๊ฐ ๊ตฌํํด์ผ ํ ๊ธฐ์ฐจ ๋ฐ ์์ธก ์ด์ธ์ XGBoost ์์ ์ด ์์ต๋๊น?
@tqchen ๋๋ @ogrisel ์ค ํ ๋ช ์ด https://github.com/mrocklin/dask-xgboost/blob/master/dask_xgboost/core.py ์์ ๊ตฌํ์ ์ดํด๋ณผ ์๊ฐ์ด ์๋ค๋ฉด ๊ฐ์ฌํ๊ฒ ์ต๋๋ค. ์ธ๊ตญ ์ฝ๋๋ฒ ์ด์ค๋ฅผ ์ดํด๋ณด๋ ๊ฒ์ด ์ฐ์ ์์ ๋ชฉ๋ก์์ ํญ์ ๋์ ๊ฒ์ ์๋๋ผ๋ ๊ฒ์ ์ดํดํฉ๋๋ค.
๋ชจ๋ ๊ฒ์ด ์ ์์ด๋ฉด README์ ์กฐ๊ธ ๋ ์ถ๊ฐํ๊ณ PyPI์ ๊ฒ์ํ๋ฉด ์ด ๋ฌธ์ ๋ฅผ ์ข ๋ฃํ ์ ์์ต๋๋ค.
๋๋ ํ๋ จํ๊ณ ์์ธกํ๋ ๊ฒ๋ง์ด ๋ฐฐํฌ๋์ด์ผ ํ๋ค๊ณ ์๊ฐํฉ๋๋ค. ๋ค๋ฅธ ๊ฒ๋ค์ ๋ฐ์ดํฐ์ ์ ์๋ตํ์ง ์๊ธฐ ๋๋ฌธ์ ๋ฐฐํฌํ ํ์๊ฐ ์์ต๋๋ค.
dask-xgboost๋ฅผ PyPI๋ก ํธ์ํ๊ณ https://github.com/dask/dask-xgboost ๋ก ์ฎ๊ฒผ์ต๋๋ค.
๋์์ ์ฃผ์ @tqchen ๊ณผ @ogrisel ์๊ฒ ๊ฐ์ฌ๋๋ฆฝ๋๋ค. ํ์ ์ ํตํด ์ด๋ฅผ ๋น๊ต์ ์ฝ๊ฒ ์ํํ ์ ์์์ต๋๋ค.
๋ฒค์น๋งํฌ๋ฅผ ์คํํ๋ ค๋ ์ฌ๋๋ค์๊ฒ ๋์์ด ๋์์ผ๋ฉด ํฉ๋๋ค. ๊ทธ ์ ๊น์ง๋ ๋ซ์ต๋๋ค.
๊ฐ์ฅ ์ ์ฉํ ๋๊ธ
๋ ธํธ๋ถ: https://gist.github.com/19c89d78e34437e061876a9872f4d2df
์งง์ ์คํฌ๋ฆฐ์บ์คํธ(6๋ถ): https://youtu.be/Cc4E-PdDSro
๋นํ์ ํผ๋๋ฐฑ์ ๋งค์ฐ ํ์ํฉ๋๋ค. ๋ค์ ํ ๋ฒ ์ด ๋ถ์ผ์ ๋ํ ์ ์ ๋ฌด์ง๋ฅผ ์ฉ์ํด ์ฃผ์ญ์์ค.