Scikit-learn: Stratified GroupKFold

Created on 11 Apr 2019  ·  48 comments  ·  Source: scikit-learn/scikit-learn

Description

Currently sklearn does not have a stratified group k-fold feature. We can either use stratification or use group k-fold, but it would be good to have both.

I would like to implement it, if we decide to have it.

Most helpful comment

"It would be great if the people who are interested could describe their use case and what they really want out of this."

This is a very common use case in medicine and biology when you have repeated measures.
An example: assume you want to classify a disease, e.g. Alzheimer's disease (AD) vs. healthy controls, from MR images. For the same subject you may have several scans (from follow-up sessions or longitudinal data). Let's assume you have a total of 1000 subjects, 200 of them diagnosed with AD (imbalanced classes). Most subjects have one scan, but for some of them 2 or 3 images are available. When training/testing the classifier, you want to make sure that images from the same subject are always in the same fold to avoid data leakage.
It is best to use StratifiedGroupKFold for this: stratify to account for the class imbalance, but with the group constraint that a subject must not appear in different folds.
NB: it would be nice to make it repeatable.

Below is an example implementation, inspired by kaggle-kernel.

import numpy as np
from collections import Counter, defaultdict
from sklearn.utils import check_random_state

class RepeatedStratifiedGroupKFold():

    def __init__(self, n_splits=5, n_repeats=1, random_state=None):
        self.n_splits = n_splits
        self.n_repeats = n_repeats
        self.random_state = random_state

    # Implementation based on this kaggle kernel:
    #    https://www.kaggle.com/jakubwasikowski/stratified-group-k-fold-cross-validation
    def split(self, X, y=None, groups=None):
        k = self.n_splits
        def eval_y_counts_per_fold(y_counts, fold):
            y_counts_per_fold[fold] += y_counts
            std_per_label = []
            for label in range(labels_num):
                label_std = np.std(
                    [y_counts_per_fold[i][label] / y_distr[label] for i in range(k)]
                )
                std_per_label.append(label_std)
            y_counts_per_fold[fold] -= y_counts
            return np.mean(std_per_label)

        rnd = check_random_state(self.random_state)
        for repeat in range(self.n_repeats):
            labels_num = np.max(y) + 1
            y_counts_per_group = defaultdict(lambda: np.zeros(labels_num))
            y_distr = Counter()
            for label, g in zip(y, groups):
                y_counts_per_group[g][label] += 1
                y_distr[label] += 1

            y_counts_per_fold = defaultdict(lambda: np.zeros(labels_num))
            groups_per_fold = defaultdict(set)

            groups_and_y_counts = list(y_counts_per_group.items())
            rnd.shuffle(groups_and_y_counts)

            for g, y_counts in sorted(groups_and_y_counts, key=lambda x: -np.std(x[1])):
                best_fold = None
                min_eval = None
                for i in range(k):
                    fold_eval = eval_y_counts_per_fold(y_counts, i)
                    if min_eval is None or fold_eval < min_eval:
                        min_eval = fold_eval
                        best_fold = i
                y_counts_per_fold[best_fold] += y_counts
                groups_per_fold[best_fold].add(g)

            all_groups = set(groups)
            for i in range(k):
                train_groups = all_groups - groups_per_fold[i]
                test_groups = groups_per_fold[i]

                # Use a separate name for the sample index so it does not shadow the fold index `i`.
                train_indices = [idx for idx, g in enumerate(groups) if g in train_groups]
                test_indices = [idx for idx, g in enumerate(groups) if g in test_groups]

                yield train_indices, test_indices

Comparison between RepeatedStratifiedKFold (samples of the same group may appear in both folds) and RepeatedStratifiedGroupKFold:

import matplotlib.pyplot as plt
from sklearn import model_selection

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        ax.scatter(range(len(indices)), [ii + .5] * len(indices),
                   c=indices, marker='_', lw=lw, cmap=plt.cm.coolwarm,
                   vmin=-.2, vmax=1.2)

    ax.scatter(range(len(X)), [ii + 1.5] * len(X), c=y, marker='_',
               lw=lw, cmap=plt.cm.Paired)
    ax.scatter(range(len(X)), [ii + 2.5] * len(X), c=group, marker='_',
               lw=lw, cmap=plt.cm.tab20c)

    yticklabels = list(range(n_splits)) + ['class', 'group']
    ax.set(yticks=np.arange(n_splits+2) + .5, yticklabels=yticklabels,
           xlabel='Sample index', ylabel="CV iteration",
           ylim=[n_splits+2.2, -.2], xlim=[0, 100])
    ax.set_title('{}'.format(type(cv).__name__), fontsize=15)


# demonstration
np.random.seed(1338)
n_splits = 4
n_repeats = 5


# Generate the class/group data
n_points = 100
X = np.random.randn(100, 10)

percentiles_classes = [.4, .6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])

# Evenly spaced groups
g = np.hstack([[ii] * 5 for ii in range(20)])


fig, ax = plt.subplots(1,2, figsize=(14,4))

cv_nogrp = model_selection.RepeatedStratifiedKFold(n_splits=n_splits,
                                                   n_repeats=n_repeats,
                                                   random_state=1338)
cv_grp = RepeatedStratifiedGroupKFold(n_splits=n_splits,
                                      n_repeats=n_repeats,
                                      random_state=1338)

plot_cv_indices(cv_nogrp, X, y, g, ax[0], n_splits * n_repeats)
plot_cv_indices(cv_grp, X, y, g, ax[1], n_splits * n_repeats)

plt.show()

[Figure: RepeatedStratifiedGroupKFold_demo (output of the plotting code above)]

All 48 comments

@TomDLT @NicolasHug What do you think?

It might be interesting in theory, but I'm not sure how useful it would be in practice. We can certainly keep the issue open and see how many people request this feature.

Do you assume that each group is in a single class?

See also #9413

@jnothman Yes, I had something similar in mind. However, I see that the pull request is still open. I meant that a group will not be repeated across folds. If we have IDs as groups, then the same ID will not occur in multiple folds.

I understand this is relevant to the use of RFECV.
Currently it defaults to a StratifiedKFold cv. Its fit() also takes groups=.
However, it seems that the groups are not respected when fit() is executed, and there is no warning (this might be considered a bug).

Grouping and stratification are useful for quite imbalanced datasets with inter-record dependency
(in my case, the same individual has multiple records, but there is still a large number of groups=people relative to the number of splits).

So: +1!

This would definitely be useful. For instance, when working with highly imbalanced time-series medical data, you want to keep patients separate but (approximately) balance the imbalanced class in each fold.

I have also found that StratifiedKFold takes groups as a parameter but does not group according to them; this should probably be flagged.
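
The behavior described in the last few comments can be checked with a small sketch. The toy arrays below are made up for illustration; the only library call is `StratifiedKFold.split`, which accepts a `groups` argument but does not use it.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy data: 12 samples, 2 classes, 4 groups of 3 consecutive samples each.
X = np.zeros((12, 1))
y = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
groups = np.repeat([0, 1, 2, 3], 3)

cv = StratifiedKFold(n_splits=3)
with_groups = list(cv.split(X, y, groups=groups))
without_groups = list(cv.split(X, y))

# The two sets of splits are identical: passing `groups` changes nothing.
same = all(
    np.array_equal(tr_a, tr_b) and np.array_equal(te_a, te_b)
    for (tr_a, te_a), (tr_b, te_b) in zip(with_groups, without_groups)
)

# Collect the groups that end up on both sides of at least one split.
leaking = set()
for train_idx, test_idx in with_groups:
    leaking |= set(groups[train_idx]) & set(groups[test_idx])

print("identical splits:", same)
print("groups split across train and test:", sorted(leaking))
```

No warning is raised here, which matches the observation above that the argument is silently ignored.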

Another good use of this feature would be financial data, which is usually very imbalanced. In my case, I have a highly imbalanced dataset with several records for the same entity (just at different points in time). We want to use GroupKFold to avoid leakage, but we also want to stratify because, given the high imbalance, we could end up with folds that have very few or no positives.
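
To see how uneven the positives can get, one could count them per test fold under plain GroupKFold. This is a sketch on synthetic data: the number of entities, the records per entity and the roughly 3% positive rate are all assumptions made up for the illustration.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

rng = np.random.RandomState(0)

# Hypothetical data: 200 entities with 5 records each; about 3% of the
# entities are positive, and every record of an entity shares its label.
n_entities, records_per_entity = 200, 5
entity_label = (rng.rand(n_entities) < 0.03).astype(int)
groups = np.repeat(np.arange(n_entities), records_per_entity)
y = np.repeat(entity_label, records_per_entity)
X = np.zeros((len(y), 1))

# GroupKFold never looks at y, so the number of positives per fold is
# whatever the group assignment happens to produce.
positives_per_fold = [
    int(y[test_idx].sum())
    for _, test_idx in GroupKFold(n_splits=5).split(X, y, groups)
]
print(positives_per_fold)
```

With so few positive entities, nothing prevents a fold from receiving none of them, which is the situation stratification is meant to avoid.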

See also #14524.

Another use case for a stratified GroupShuffleSplit and GroupKFold is biological "repeated measures" designs, where you have multiple samples per subject or other parent biological unit. Many real-world datasets in biology also have class imbalance. Each group of samples has the same class, so it is important to stratify while keeping groups together.

Hi, I think this would be quite useful for medical ML. Is it already implemented?

@amueller Do you think we should implement this, given that people are interested in it?

I am very interested too... It would be really useful in spectroscopy, where you have several replicate measurements for each sample; they really need to stay in the same fold during cross-validation. And if you have several unbalanced classes that you are trying to classify, you would really want to use the stratification feature as well. So I vote for it too! Sorry that I am not good enough to participate in the development, but those who take part can be sure that it will be used :-)
Thumbs up for the whole team. Thank you!

Please look at the issues and PRs referenced in this thread, as work has been attempted on StratifiedGroupKFold. I have already done a StratifiedGroupShuffleSplit (#15239), which just needs tests, but I have already used it quite a bit in my own work.

I think we should implement it, but I think I still don't know what we actually want. @hermidalc has the restriction that members of the same group must be of the same class. That's not the general case, right?

It would be great if the people who are interested could describe their use case and what they really want out of this.

There are #15239, #14524 and #9413, which I remember all having different semantics.

@amueller I totally agree with you. I spent a few hours today looking through the different versions available (#15239, #14524 and #9413) but couldn't really work out whether any of them would fit my need. So here is my use case, if it helps:
I have 1000 samples. Each sample has been measured 3 times with a NIR spectrometer, so each sample has 3 replicates that I want to keep together throughout...
These 1000 samples belong to 6 different classes with very different numbers of samples in each:
class 1: 400 samples
class 2: 300 samples
class 3: 100 samples
class 4: 100 samples
class 5: 70 samples
class 6: 30 samples
I want to build a classifier for each class: class 1 vs. all the other classes, then class 2 vs. all the other classes, etc.
To maximize the accuracy of each of my classifiers, it is important that samples from all 6 classes are represented in each fold. My classes are not that different from each other, so having all 6 classes in every fold really helps to get an accurate decision boundary.

This is why I believe a stratified (always my 6 classes represented in each fold) group (always keep the 3 replicate measurements of each sample together) k-fold is very much what I am looking for here.
Any opinion?
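
For concreteness, the arrays for this use case might be laid out as in the sketch below. The class sizes and the 3 replicates come from the comment above; the variable names and the one-vs-rest encoding are illustrative assumptions.

```python
import numpy as np

# Class sizes from the use case above; each sample has 3 replicate spectra.
samples_per_class = {1: 400, 2: 300, 3: 100, 4: 100, 5: 70, 6: 30}
n_replicates = 3

sample_class = np.concatenate(
    [np.full(n, label) for label, n in samples_per_class.items()]
)

# One row per measurement; `groups` holds the id of the parent sample, so
# the 3 replicates of a sample share a group id and a class.
groups = np.repeat(np.arange(len(sample_class)), n_replicates)
y = np.repeat(sample_class, n_replicates)

# One-vs-rest targets: one binary vector per class.
y_ovr = {label: (y == label).astype(int) for label in samples_per_class}
positive_rate = {label: float(t.mean()) for label, t in y_ovr.items()}
print(positive_rate)
```

The positive rate ranges from 40% for class 1 down to 3% for class 6, which is why both the grouping (`groups`) and the stratification (`y`) matter for the one-vs-rest classifiers.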

My use case, and the reason I wrote StratifiedGroupShuffleSplit, is to support repeated measures designs (https://en.wikipedia.org/wiki/Repeated_measures_design). In my use cases, members of the same group must be of the same class.

@fcoppey For you, the samples within a group always have the same class, right?

@hermidalc I'm not very familiar with the terminology, but from Wikipedia it sounds like a "repeated measures design" does not require members of the same group to share a class, as it says: "A crossover trial has a repeated measures design in which each patient is assigned to a sequence of two or more treatments, of which one may be a standard treatment or a placebo."
Relating this to an ML setting, you could either try to predict from the measurements whether an individual just received the treatment or the placebo, or you could try to predict an outcome given the treatment.
For either of those, the class for the same individual could change, right?

Regardless of the name, it sounds to me like you both have the same use case, whereas I was thinking of a case similar to the one described for the crossover study. Or maybe something a bit simpler: a patient could get sick (or get better) over time, so the outcome for a patient could change.

Actually, the Wikipedia article you link to explicitly says "Longitudinal analysis: Repeated measure designs allow researchers to monitor how participants change over time, both long- and short-term situations.", so I think that means changing the class is included.
If there is another word that means the measurements are taken under the same conditions, could we use that word?

@amueller Yes, you are right. I realize I miswrote above: I meant to say "in my use cases of this design", not "in this use case in general".

There can be many quite elaborate types of repeated measures designs, but in the two types for which I have needed StratifiedGroupShuffleSplit, the within-group same-class restriction holds (longitudinal sampling before and after treatment when predicting treatment response; multiple pre-treatment samples per subject at different body locations when predicting treatment response).

I needed something that worked immediately, so I wanted to put it up for others to use and to get something started in sklearn. Besides, if I'm not mistaken, it is more complicated to design the stratification logic when class labels can differ within a group.

@amueller Yes, always. They are replicates of the same measurement, included so that the prediction accounts for the intra-variability of the device.

@hermidalc Yes, this case is much easier. If it is a common need, I'm happy for us to add it. We should just make sure that the name makes it reasonably clear what it does, and we should think about whether these two versions should live in the same class.

It should be quite easy to make StratifiedKFold do this. There are two options: make sure that each fold contains a similar number of samples, or make sure that each fold contains a similar number of groups.
The second one is trivial to do (just pretend each group is a single point and pass that to StratifiedKFold). That is what you do in your PR, it looks like.
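
The "pretend each group is a single point" option might look roughly like the sketch below. It is not the code from the PR; `group_level_stratified_kfold` is a hypothetical helper name, and the sketch assumes that every group contains samples of a single class.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold


def group_level_stratified_kfold(y, groups, n_splits=5, shuffle=False,
                                 random_state=None):
    """Yield (train_idx, test_idx), stratifying over groups, not samples.

    Assumes every group contains samples of a single class.
    """
    y = np.asarray(y)
    groups = np.asarray(groups)
    # One entry per group: its id and the class of its first sample.
    unique_groups, first_idx = np.unique(groups, return_index=True)
    group_y = y[first_idx]

    skf = StratifiedKFold(n_splits=n_splits, shuffle=shuffle,
                          random_state=random_state)
    placeholder = np.zeros((len(unique_groups), 1))
    for train_g, test_g in skf.split(placeholder, group_y):
        # Map the selected groups back to sample indices.
        train_mask = np.isin(groups, unique_groups[train_g])
        test_mask = np.isin(groups, unique_groups[test_g])
        yield np.flatnonzero(train_mask), np.flatnonzero(test_mask)


# Small usage example: 6 groups of unequal size, 3 groups per class.
groups = np.array([0, 0, 1, 2, 2, 2, 3, 4, 4, 5])
y = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1])
splits = list(group_level_stratified_kfold(y, groups, n_splits=3))
for train_idx, test_idx in splits:
    print("test groups:", groups[test_idx], "test classes:", y[test_idx])
```

Each fold receives a similar number of groups per class, but because groups differ in size, the number of samples per fold is not controlled.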

I think GroupKFold heuristically trades off the two by adding to the smallest fold first. I'm not sure how that would translate to the stratified case, so I'm happy with using your approach.
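
The "add to the smallest fold first" heuristic can be sketched as follows. This illustrates the idea as described in the comment and is not a copy of the scikit-learn source; `greedy_group_folds` is a hypothetical name, and visiting groups from largest to smallest is an assumption of this sketch.

```python
import numpy as np


def greedy_group_folds(groups, n_splits):
    """Assign each group to a fold: largest group first, lightest fold first."""
    unique_groups, counts = np.unique(groups, return_counts=True)
    fold_sizes = np.zeros(n_splits, dtype=int)
    fold_of_group = {}
    # Visit groups from largest to smallest (stable order for ties).
    for i in np.argsort(-counts, kind="stable"):
        lightest = int(np.argmin(fold_sizes))
        fold_sizes[lightest] += counts[i]
        fold_of_group[unique_groups[i]] = lightest
    return fold_of_group, fold_sizes


# Groups of sizes 5, 4, 3, 3, 2 and 1 spread over 3 folds.
groups = np.repeat(["a", "b", "c", "d", "e", "f"], [5, 4, 3, 3, 2, 1])
fold_of_group, fold_sizes = greedy_group_folds(groups, n_splits=3)
print(fold_sizes)  # [6 6 6]
```

The heuristic balances the number of samples per fold while keeping groups intact, but it never looks at class labels, which is why its translation to the stratified case is not obvious.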

Should we also add GroupStratifiedKFold in the same PR? Or leave that for later?
The other PRs have slightly different goals. It would be good if someone could write down what the different use cases are (I probably don't have time right now).

+1 for separately handling the group constraint where all samples of a group have the same class.

I'm not completely following this. A StratifiedGroupShuffleSplit and a StratifiedGroupKFold in which members of a group can belong to different classes should have exactly the same split behavior when the user specifies all group members to be of the same class. So couldn't we improve the internals later while the existing behavior stays the same?

I will add StratifiedGroupKFold using the "each group as a single sample" approach I used.

+1 for StratifiedGroupKFold. I am trying to detect falls of seniors using sensors from a smart watch. Since we don't have much fall data, we run simulations with different watches, which get different classes. I also augment the data before training: from each data point I create 9 points, and these form a group. It is important that a group does not end up in both train and test, as explained.
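
For the augmentation case, the group ids can be derived directly from the index of the source recording. A minimal sketch, assuming 50 original recordings with made-up labels and features; `GroupShuffleSplit` keeps the 9 augmented copies of a recording together but does not stratify.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

n_original, n_augmented = 50, 9

# Hypothetical labels for the original recordings (1 = fall, 0 = no fall).
rng = np.random.RandomState(0)
y_original = (rng.rand(n_original) < 0.2).astype(int)

# Every augmented copy inherits the label and the id of its source recording.
y = np.repeat(y_original, n_augmented)
groups = np.repeat(np.arange(n_original), n_augmented)
X = rng.randn(len(y), 3)

splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
train_idx, test_idx = next(splitter.split(X, y, groups))

# No source recording contributes copies to both sides of the split.
shared = set(groups[train_idx]) & set(groups[test_idx])
print("shared groups:", shared)
```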

I would like to be able to use StratifiedGroupKFold as well. I am looking at a dataset for predicting financial crises, where the time before, during and after each crisis is its own group. During training and cross-validation, the members of each group should not leak between folds.

Is there any way to generalise this to the multilabel scenario (Multilabel_StratifiedGroupKfold)?

+1 for this. We are analyzing user accounts for spam, so we want to group by user, but also stratify because spam has a relatively low incidence. In our use case, any user who spams once is flagged as a spammer in all data, so group members will always have the same label.
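
Whether a dataset satisfies this "one label per group" property can be checked before choosing a splitter. A small sketch; `groups_with_mixed_labels` is a hypothetical helper and the user ids and flags are made up.

```python
import numpy as np


def groups_with_mixed_labels(y, groups):
    """Return the ids of groups whose samples carry more than one label."""
    y = np.asarray(y)
    groups = np.asarray(groups)
    return [g for g in np.unique(groups) if len(np.unique(y[groups == g])) > 1]


# Hypothetical user ids and spam flags, one row per message.
users = np.array(["u1", "u1", "u2", "u3", "u3", "u3"])
is_spam = np.array([0, 0, 1, 0, 0, 0])
mixed = groups_with_mixed_labels(is_spam, users)
print(mixed)  # []
```

An empty result means the simpler group-level stratification discussed in this thread applies; a non-empty one means labels vary within some groups.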

Thank you for providing a classic use case to frame the documentation, @philip-iv!

I added a StratifiedGroupKFold implementation to the same PR (#15239) as StratifiedGroupShuffleSplit.

As you can see in the PR, though, the logic for both is much simpler than https://github.com/scikit-learn/scikit-learn/issues/13621#issuecomment-557802602, because mine only attempts to preserve the percentage of groups for each class (not the percentage of samples), so that I can leverage the existing StratifiedKFold and StratifiedShuffleSplit code by passing it the unique group information. Both implementations nevertheless produce folds in which the samples of each group stay together in the same fold.

That said, I would vote for a more sophisticated method based on https://github.com/scikit-learn/scikit-learn/issues/13621#issuecomment-557802602.
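
The difference between preserving the percentage of groups and the percentage of samples shows up as soon as group sizes differ. A toy illustration with made-up folds; `positive_share` is a hypothetical helper.

```python
import numpy as np

# Hypothetical folds: group id -> (class, number of samples in the group).
fold_a = {"g1": (1, 10), "g2": (0, 1), "g3": (0, 1)}
fold_b = {"g4": (1, 1), "g5": (0, 10), "g6": (0, 10)}


def positive_share(fold):
    """Share of class 1 among the groups and among the samples of one fold."""
    classes = np.array([c for c, _ in fold.values()])
    sizes = np.array([n for _, n in fold.values()])
    by_groups = float((classes == 1).mean())
    by_samples = float(sizes[classes == 1].sum() / sizes.sum())
    return by_groups, by_samples


share_a = positive_share(fold_a)
share_b = positive_share(fold_b)
print(share_a, share_b)
```

Both folds hold one positive group out of three, so they are perfectly stratified at the group level, yet the share of positive samples is 10/12 in one fold and 1/21 in the other. A method that balances sample counts per class would avoid this.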

Here are full-fledged versions of StratifiedGroupKFold and RepeatedStratifiedGroupKFold, using the code @mrunibe provided, which I simplified further and changed in a couple of places. These classes also follow the design of other sklearn CV classes of the same type.

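# Imports needed to run these classes outside of sklearn/model_selection/_split.py.
# _BaseKFold and _RepeatedSplits are private scikit-learn helpers and may change
# between versions.
from collections import Counter, defaultdict

import numpy as np
from sklearn.model_selection._split import _BaseKFold, _RepeatedSplits
from sklearn.utils import check_random_state


class StratifiedGroupKFold(_BaseKFold):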
    """Stratified K-Folds iterator variant with non-overlapping groups.

    This cross-validation object is a variation of StratifiedKFold that returns
    stratified folds with non-overlapping groups. The folds are made by
    preserving the percentage of samples for each class.

    The same group will not appear in two different folds (the number of
    distinct groups has to be at least equal to the number of folds).

    The difference between GroupKFold and StratifiedGroupKFold is that
    the former attempts to create balanced folds such that the number of
    distinct groups is approximately the same in each fold, whereas
    StratifiedGroupKFold attempts to create folds which preserve the
    percentage of samples for each class.

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    shuffle : bool, default=False
        Whether to shuffle each class's samples before splitting into batches.
        Note that the samples within each split will not be shuffled.

    random_state : int or RandomState instance, default=None
        When `shuffle` is True, `random_state` affects the ordering of the
        indices, which controls the randomness of each fold for each class.
        Otherwise, leave `random_state` as `None`.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import StratifiedGroupKFold
    >>> X = np.ones((17, 2))
    >>> y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    >>> groups = np.array([1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8])
    >>> cv = StratifiedGroupKFold(n_splits=3)
    >>> for train_idxs, test_idxs in cv.split(X, y, groups):
    ...     print("TRAIN:", groups[train_idxs])
    ...     print("      ", y[train_idxs])
    ...     print(" TEST:", groups[test_idxs])
    ...     print("      ", y[test_idxs])
    TRAIN: [2 2 4 5 5 5 5 6 6 7]
           [1 1 1 0 0 0 0 0 0 0]
     TEST: [1 1 3 3 3 8 8]
           [0 0 1 1 1 0 0]
    TRAIN: [1 1 3 3 3 4 5 5 5 5 8 8]
           [0 0 1 1 1 1 0 0 0 0 0 0]
     TEST: [2 2 6 6 7]
           [1 1 0 0 0]
    TRAIN: [1 1 2 2 3 3 3 6 6 7 8 8]
           [0 0 1 1 1 1 1 0 0 0 0 0]
     TEST: [4 5 5 5 5]
           [1 0 0 0 0]

    See also
    --------
    StratifiedKFold: Takes class information into account to build folds which
        retain class distributions (for binary or multiclass classification
        tasks).

    GroupKFold: K-fold iterator variant with non-overlapping groups.
    """

    def __init__(self, n_splits=5, shuffle=False, random_state=None):
        super().__init__(n_splits=n_splits, shuffle=shuffle,
                         random_state=random_state)

    # Implementation based on this kaggle kernel:
    # https://www.kaggle.com/jakubwasikowski/stratified-group-k-fold-cross-validation
    def _iter_test_indices(self, X, y, groups):
        labels_num = np.max(y) + 1
        y_counts_per_group = defaultdict(lambda: np.zeros(labels_num))
        y_distr = Counter()
        for label, group in zip(y, groups):
            y_counts_per_group[group][label] += 1
            y_distr[label] += 1

        y_counts_per_fold = defaultdict(lambda: np.zeros(labels_num))
        groups_per_fold = defaultdict(set)

        groups_and_y_counts = list(y_counts_per_group.items())
        rng = check_random_state(self.random_state)
        if self.shuffle:
            rng.shuffle(groups_and_y_counts)

        for group, y_counts in sorted(groups_and_y_counts,
                                      key=lambda x: -np.std(x[1])):
            best_fold = None
            min_eval = None
            for i in range(self.n_splits):
                y_counts_per_fold[i] += y_counts
                std_per_label = []
                for label in range(labels_num):
                    std_per_label.append(np.std(
                        [y_counts_per_fold[j][label] / y_distr[label]
                         for j in range(self.n_splits)]))
                y_counts_per_fold[i] -= y_counts
                fold_eval = np.mean(std_per_label)
                if min_eval is None or fold_eval < min_eval:
                    min_eval = fold_eval
                    best_fold = i
            y_counts_per_fold[best_fold] += y_counts
            groups_per_fold[best_fold].add(group)

        for i in range(self.n_splits):
            test_indices = [idx for idx, group in enumerate(groups)
                            if group in groups_per_fold[i]]
            yield test_indices


class RepeatedStratifiedGroupKFold(_RepeatedSplits):
    """Repeated Stratified K-Fold cross validator.

    Repeats Stratified K-Fold with non-overlapping groups n times with
    different randomization in each repetition.

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    n_repeats : int, default=10
        Number of times cross-validator needs to be repeated.

    random_state : int or RandomState instance, default=None
        Controls the generation of the random states for each repetition.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import RepeatedStratifiedGroupKFold
    >>> X = np.ones((17, 2))
    >>> y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    >>> groups = np.array([1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8])
    >>> cv = RepeatedStratifiedGroupKFold(n_splits=2, n_repeats=2,
    ...                                   random_state=36851234)
    >>> for train_idxs, test_idxs in cv.split(X, y, groups):
    ...     print("TRAIN:", groups[train_idxs])
    ...     print("      ", y[train_idxs])
    ...     print(" TEST:", groups[test_idxs])
    ...     print("      ", y[test_idxs])
    TRAIN: [2 2 4 5 5 5 5 8 8]
           [1 1 1 0 0 0 0 0 0]
     TEST: [1 1 3 3 3 6 6 7]
           [0 0 1 1 1 0 0 0]
    TRAIN: [1 1 3 3 3 6 6 7]
           [0 0 1 1 1 0 0 0]
     TEST: [2 2 4 5 5 5 5 8 8]
           [1 1 1 0 0 0 0 0 0]
    TRAIN: [3 3 3 4 7 8 8]
           [1 1 1 1 0 0 0]
     TEST: [1 1 2 2 5 5 5 5 6 6]
           [0 0 1 1 0 0 0 0 0 0]
    TRAIN: [1 1 2 2 5 5 5 5 6 6]
           [0 0 1 1 0 0 0 0 0 0]
     TEST: [3 3 3 4 7 8 8]
           [1 1 1 1 0 0 0]

    Notes
    -----
    Randomized CV splitters may return different results for each call of
    split. You can make the results identical by setting `random_state`
    to an integer.

    See also
    --------
    RepeatedStratifiedKFold: Repeats Stratified K-Fold n times.
    """

    def __init__(self, n_splits=5, n_repeats=10, random_state=None):
        super().__init__(StratifiedGroupKFold, n_splits=n_splits,
                         n_repeats=n_repeats, random_state=random_state)

@hermidalc I'm quite confused about what we've resolved when I look back at this from time to time. (Unfortunately my time is not what it used to be!) Could you give me an idea of what you would recommend be included in scikit-learn?

I wanted to do a better implementation than the one in #15239. The implementation in that PR works, but it stratifies on the groups to keep the logic simple, which is not ideal.

So what I did above (thanks to @mrunibe and the kaggle kernel from jakubwasikowski) is a better implementation of StratifiedGroupKFold that stratifies on the samples. I want to port the same logic to a better StratifiedGroupShuffleSplit, and then it will be ready. I will put the new code into #15239 to replace the older implementation.

I apologize for the unfinished PR. I'm working on my PhD, so time is short!

Thank you @hermidalc and @mrunibe for providing the implementation. I have also been looking for a StratifiedGroupKFold method to handle medical data with strong class imbalance and a widely varying number of samples per subject. GroupKFold on its own creates training subsets that contain only one class.

"I want to port the same logic to a better StratifiedGroupShuffleSplit, and then it will be ready."

We could certainly consider merging StratifiedGroupKFold before StratifiedGroupShuffleSplit is ready.

"I apologize for the unfinished PR. I'm working on my PhD, so time is short!"

Let us know if you would like support completing it!

And good luck with your PhD work.

"Here are full-fledged versions of StratifiedGroupKFold and RepeatedStratifiedGroupKFold using the code @mrunibe provided, which I simplified further and changed in a couple of places. These classes also follow the design of other sklearn CV classes of the same type."

Is it possible to try this out? I tried cutting and pasting some of the various dependencies, but it was never-ending. I would love to give this class a try in my project, and I am trying to see whether there is a way to do that right now.

@hermidalc I hope your PhD goes well!
I'm looking forward to seeing this implementation finished, because my PhD work in Geosciences needs this stratification feature together with group control. I spent a few hours implementing the idea of splitting manually in my project, but gave up for the same reason... PhD progress. So I completely understand how a PhD can eat up a person's time. LOL, no pressure. For now I use GroupShuffleSplit as an alternative.

Cheers

@bfeeny @dispink it's very easy to use the two classes I wrote above. Create a file, e.g. split.py, with the following contents. Then, if your script is in the same directory as split.py, simply add from split import StratifiedGroupKFold, RepeatedStratifiedGroupKFold to your code.

from collections import Counter, defaultdict

import numpy as np

from sklearn.model_selection._split import _BaseKFold, _RepeatedSplits
from sklearn.utils.validation import check_random_state


class StratifiedGroupKFold(_BaseKFold):
    """Stratified K-Folds iterator variant with non-overlapping groups.

    This cross-validation object is a variation of StratifiedKFold that returns
    stratified folds with non-overlapping groups. The folds are made by
    preserving the percentage of samples for each class.

    The same group will not appear in two different folds (the number of
    distinct groups has to be at least equal to the number of folds).

    The difference between GroupKFold and StratifiedGroupKFold is that
    the former attempts to create balanced folds such that the number of
    distinct groups is approximately the same in each fold, whereas
    StratifiedGroupKFold attempts to create folds which preserve the
    percentage of samples for each class.

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    shuffle : bool, default=False
        Whether to shuffle the groups before assigning them to folds.
        Note that the samples within each split will not be shuffled.

    random_state : int or RandomState instance, default=None
        When `shuffle` is True, `random_state` affects the order in which
        groups are considered, which controls the randomness of each fold.
        Otherwise, leave `random_state` as `None`.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import StratifiedGroupKFold
    >>> X = np.ones((17, 2))
    >>> y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    >>> groups = np.array([1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8])
    >>> cv = StratifiedGroupKFold(n_splits=3)
    >>> for train_idxs, test_idxs in cv.split(X, y, groups):
    ...     print("TRAIN:", groups[train_idxs])
    ...     print("      ", y[train_idxs])
    ...     print(" TEST:", groups[test_idxs])
    ...     print("      ", y[test_idxs])
    TRAIN: [2 2 4 5 5 5 5 6 6 7]
           [1 1 1 0 0 0 0 0 0 0]
     TEST: [1 1 3 3 3 8 8]
           [0 0 1 1 1 0 0]
    TRAIN: [1 1 3 3 3 4 5 5 5 5 8 8]
           [0 0 1 1 1 1 0 0 0 0 0 0]
     TEST: [2 2 6 6 7]
           [1 1 0 0 0]
    TRAIN: [1 1 2 2 3 3 3 6 6 7 8 8]
           [0 0 1 1 1 1 1 0 0 0 0 0]
     TEST: [4 5 5 5 5]
           [1 0 0 0 0]

    See also
    --------
    StratifiedKFold: Takes class information into account to build folds which
        retain class distributions (for binary or multiclass classification
        tasks).

    GroupKFold: K-fold iterator variant with non-overlapping groups.
    """

    def __init__(self, n_splits=5, shuffle=False, random_state=None):
        super().__init__(n_splits=n_splits, shuffle=shuffle,
                         random_state=random_state)

    # Implementation based on this kaggle kernel:
    # https://www.kaggle.com/jakubwasikowski/stratified-group-k-fold-cross-validation
    def _iter_test_indices(self, X, y, groups):
        # Assumes y holds integer class labels 0..K-1, each present at least
        # once (labels index into arrays and divide by the class totals).
        labels_num = np.max(y) + 1
        # Class counts per group and over the whole dataset.
        y_counts_per_group = defaultdict(lambda: np.zeros(labels_num))
        y_distr = Counter()
        for label, group in zip(y, groups):
            y_counts_per_group[group][label] += 1
            y_distr[label] += 1

        y_counts_per_fold = defaultdict(lambda: np.zeros(labels_num))
        groups_per_fold = defaultdict(set)

        groups_and_y_counts = list(y_counts_per_group.items())
        rng = check_random_state(self.random_state)
        if self.shuffle:
            rng.shuffle(groups_and_y_counts)

        # Greedy assignment: visit the groups with the most uneven class
        # counts first and put each one into the fold that keeps the
        # per-class shares most even across folds.
        for group, y_counts in sorted(groups_and_y_counts,
                                      key=lambda x: -np.std(x[1])):
            best_fold = None
            min_eval = None
            for i in range(self.n_splits):
                # Tentatively add the group to fold i and score the result.
                y_counts_per_fold[i] += y_counts
                std_per_label = []
                for label in range(labels_num):
                    std_per_label.append(np.std(
                        [y_counts_per_fold[j][label] / y_distr[label]
                         for j in range(self.n_splits)]))
                y_counts_per_fold[i] -= y_counts
                fold_eval = np.mean(std_per_label)
                if min_eval is None or fold_eval < min_eval:
                    min_eval = fold_eval
                    best_fold = i
            y_counts_per_fold[best_fold] += y_counts
            groups_per_fold[best_fold].add(group)

        # Each fold's test set is every sample whose group was assigned to it.
        for i in range(self.n_splits):
            test_indices = [idx for idx, group in enumerate(groups)
                            if group in groups_per_fold[i]]
            yield test_indices


class RepeatedStratifiedGroupKFold(_RepeatedSplits):
    """Repeated Stratified K-Fold cross validator.

    Repeats Stratified K-Fold with non-overlapping groups n times with
    different randomization in each repetition.

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    n_repeats : int, default=10
        Number of times cross-validator needs to be repeated.

    random_state : int or RandomState instance, default=None
        Controls the generation of the random states for each repetition.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import RepeatedStratifiedGroupKFold
    >>> X = np.ones((17, 2))
    >>> y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    >>> groups = np.array([1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8])
    >>> cv = RepeatedStratifiedGroupKFold(n_splits=2, n_repeats=2,
    ...                                   random_state=36851234)
    >>> for train_idxs, test_idxs in cv.split(X, y, groups):
    ...     print("TRAIN:", groups[train_idxs])
    ...     print("      ", y[train_idxs])
    ...     print(" TEST:", groups[test_idxs])
    ...     print("      ", y[test_idxs])
    TRAIN: [2 2 4 5 5 5 5 8 8]
           [1 1 1 0 0 0 0 0 0]
     TEST: [1 1 3 3 3 6 6 7]
           [0 0 1 1 1 0 0 0]
    TRAIN: [1 1 3 3 3 6 6 7]
           [0 0 1 1 1 0 0 0]
     TEST: [2 2 4 5 5 5 5 8 8]
           [1 1 1 0 0 0 0 0 0]
    TRAIN: [3 3 3 4 7 8 8]
           [1 1 1 1 0 0 0]
     TEST: [1 1 2 2 5 5 5 5 6 6]
           [0 0 1 1 0 0 0 0 0 0]
    TRAIN: [1 1 2 2 5 5 5 5 6 6]
           [0 0 1 1 0 0 0 0 0 0]
     TEST: [3 3 3 4 7 8 8]
           [1 1 1 1 0 0 0]

    Notes
    -----
    Randomized CV splitters may return different results for each call of
    split. You can make the results identical by setting `random_state`
    to an integer.

    See also
    --------
    RepeatedStratifiedKFold: Repeats Stratified K-Fold n times.
    """

    def __init__(self, n_splits=5, n_repeats=10, random_state=None):
        super().__init__(StratifiedGroupKFold, n_splits=n_splits,
                         n_repeats=n_repeats, random_state=random_state)
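
For readers who want to see the fold assignment in isolation, here is a minimal pure-Python sketch of the greedy step in `_iter_test_indices` above. It is an illustration under stated assumptions, not the code from the PR: the function name `assign_groups_to_folds` is hypothetical, it uses `statistics.pstdev` in place of `np.std`, and ties between folds may be broken differently than in the NumPy version, so the folds it produces need not match the docstring output.

```python
from collections import Counter, defaultdict
from statistics import pstdev


def assign_groups_to_folds(y, groups, n_splits):
    """Greedily assign whole groups to folds (hypothetical helper)."""
    labels = sorted(set(y))
    totals = Counter(y)
    # Class counts for each group.
    counts_per_group = defaultdict(Counter)
    for label, group in zip(y, groups):
        counts_per_group[group][label] += 1

    fold_counts = [Counter() for _ in range(n_splits)]
    fold_groups = [set() for _ in range(n_splits)]

    def imbalance(candidate, group_counts):
        # Mean over classes of the spread of that class's share across
        # folds, if the group were added to fold `candidate`.
        spreads = []
        for label in labels:
            shares = [
                (fold_counts[j][label]
                 + (group_counts[label] if j == candidate else 0))
                / totals[label]
                for j in range(n_splits)
            ]
            spreads.append(pstdev(shares))
        return sum(spreads) / len(spreads)

    # Visit the groups with the most uneven class counts first.
    ordered = sorted(
        counts_per_group.items(),
        key=lambda item: -pstdev([item[1][label] for label in labels]),
    )
    for group, group_counts in ordered:
        best = min(range(n_splits), key=lambda i: imbalance(i, group_counts))
        fold_counts[best].update(group_counts)
        fold_groups[best].add(group)
    return fold_groups, fold_counts


y = [0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
groups = [1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8]
fold_groups, fold_counts = assign_groups_to_folds(y, groups, n_splits=3)
for i, (members, counts) in enumerate(zip(fold_groups, fold_counts)):
    print(i, sorted(members), dict(counts))
```

Each group lands in exactly one fold, so a group can never be split between train and test.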

@hermidalc Thanks for the positive reply!
I quickly adopted it as you described. However, I can only get splits that have data in either the training set or the test set alone. As far as I understand the code description, there is no parameter to specify the ratio between the training and test sets, is there?
I know this is a conflict among stratification, group control, and the dataset ratio, which is why I gave up on continuing... but perhaps we can still find a compromise to work out.
[screenshot]

Sincerely

To test, I created split.py and ran this example in ipython, and it works. I have been using these custom CV iterators in my own work for a long time without any problems. BTW, I'm using scikit-learn 0.22.2, not 0.23.x, so I'm not sure whether that is the cause of the issue. Could you please run the example below and see whether you can reproduce it? If you can, then the problem may lie in the y and groups from your work.

In [6]: import numpy as np 
   ...: from split import StratifiedGroupKFold 
   ...:  
   ...: X = np.ones((17, 2)) 
   ...: y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]) 
   ...: groups = np.array([1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8]) 
   ...: cv = StratifiedGroupKFold(n_splits=3, shuffle=True, random_state=777) 
   ...: for train_idxs, test_idxs in cv.split(X, y, groups): 
   ...:     print("TRAIN:", groups[train_idxs]) 
   ...:     print("      ", y[train_idxs]) 
   ...:     print(" TEST:", groups[test_idxs]) 
   ...:     print("      ", y[test_idxs]) 
   ...:
TRAIN: [2 2 4 5 5 5 5 6 6 7]
       [1 1 1 0 0 0 0 0 0 0]
 TEST: [1 1 3 3 3 8 8]
       [0 0 1 1 1 0 0]
TRAIN: [1 1 3 3 3 4 5 5 5 5 8 8]
       [0 0 1 1 1 1 0 0 0 0 0 0]
 TEST: [2 2 6 6 7]
       [1 1 0 0 0]
TRAIN: [1 1 2 2 3 3 3 6 6 7 8 8]
       [0 0 1 1 1 1 1 0 0 0 0 0]
 TEST: [4 5 5 5 5]
       [1 0 0 0 0]
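
On the question of a train/test ratio: a K-fold splitter has no ratio parameter, since every fold serves once as the test set and the test share is therefore roughly 1/n_splits. Because whole groups are kept together, the actual share can drift from that. A small sketch, using the test folds printed above (the variable names are illustrative):

```python
# Group labels of the test samples in each fold, copied from the output above.
test_folds = [
    [1, 1, 3, 3, 3, 8, 8],
    [2, 2, 6, 6, 7],
    [4, 5, 5, 5, 5],
]

n_samples = sum(len(fold) for fold in test_folds)
for i, fold in enumerate(test_folds):
    share = len(fold) / n_samples
    print(f"fold {i}: {len(fold)} of {n_samples} samples in test ({share:.0%})")
```

If a fixed ratio is required, a shuffle-split style splitter such as the existing GroupShuffleSplit, which takes a test_size applied to groups, is the closer fit, at the cost of stratification.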

@hermidalc there seems to be fairly regular interest in this feature, and we could probably find someone to finish it off if you don't mind.

@hermidalc 'You have to make sure that every sample in the same group has the same class label.' Obviously that is the problem: samples in the same group of my data do not share the same class. Hmm... that seems to be another branch of development.
Thank you very much anyway.

Yes, this has been discussed in various threads here. It is another, more complex use case that would be useful, but many people like me don't currently need it and instead need something that keeps groups together while stratifying on samples. The requirement of the code above is that all samples in each group belong to the same class.

Actually @dispink, I was wrong. This algorithm does not require that all members of a group belong to the same class. For example:

In [2]: X = np.ones((17, 2)) 
   ...: y =      np.array([0, 2, 1, 1, 2, 0, 0, 1, 2, 1, 1, 1, 0, 2, 0, 1, 0]) 
   ...: groups = np.array([1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8]) 
   ...: cv = StratifiedGroupKFold(n_splits=3) 
   ...: for train_idxs, test_idxs in cv.split(X, y, groups): 
   ...:     print("TRAIN:", groups[train_idxs]) 
   ...:     print("      ", y[train_idxs]) 
   ...:     print(" TEST:", groups[test_idxs]) 
   ...:     print("      ", y[test_idxs]) 
   ...:
TRAIN: [1 1 2 2 3 3 3 4 8 8]
       [0 2 1 1 2 0 0 1 1 0]
 TEST: [5 5 5 5 6 6 7]
       [2 1 1 1 0 2 0]
TRAIN: [1 1 4 5 5 5 5 6 6 7 8 8]
       [0 2 1 2 1 1 1 0 2 0 1 0]
 TEST: [2 2 3 3 3]
       [1 1 2 0 0]
TRAIN: [2 2 3 3 3 5 5 5 5 6 6 7]
       [1 1 2 0 0 2 1 1 1 0 2 0]
 TEST: [1 1 4 8 8]
       [0 2 1 1 0]
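
The output above can be checked mechanically. This sketch, with the test groups of each fold read off the printed output, confirms that no group is shared between train and test, lists the groups that contain more than one class, and counts the classes in each test fold. The helper name `summarize_fold` is made up for illustration.

```python
from collections import Counter

y = [0, 2, 1, 1, 2, 0, 0, 1, 2, 1, 1, 1, 0, 2, 0, 1, 0]
groups = [1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8]
# Test groups per fold, read off the TEST lines printed above.
test_groups_per_fold = [{5, 6, 7}, {2, 3}, {1, 4, 8}]

# Groups whose samples do not all share one class label.
mixed_groups = sorted(
    g for g in set(groups)
    if len({label for label, gg in zip(y, groups) if gg == g}) > 1
)


def summarize_fold(test_groups):
    """Check for group leakage and return class counts of the test fold."""
    test_idx = [i for i, g in enumerate(groups) if g in test_groups]
    train_idx = [i for i, g in enumerate(groups) if g not in test_groups]
    train_groups = {groups[i] for i in train_idx}
    # No group may appear on both sides of the split.
    assert train_groups.isdisjoint(test_groups)
    return Counter(y[i] for i in test_idx)


print("groups with mixed classes:", mixed_groups)
for fold, test_groups in enumerate(test_groups_per_fold):
    print(fold, sorted(test_groups), dict(summarize_fold(test_groups)))
```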

So I'm not quite sure what is happening with your data, since even with the screenshot I cannot really see what your data layout is or what could be going on. I would suggest first reproducing the examples I showed here, to make sure it isn't a scikit-learn version issue (I'm using 0.22.2). If you can reproduce them, I would then suggest starting from small parts of your data and testing those. Troubleshooting with ~104k samples is hard.
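
One way to cut the data down for troubleshooting while keeping every group intact is to sample groups rather than rows. This is only a sketch of that idea; `subsample_by_group` is a hypothetical helper, not part of scikit-learn.

```python
import random


def subsample_by_group(groups, n_groups, seed=0):
    """Return the row indices of n_groups randomly chosen whole groups."""
    rng = random.Random(seed)
    unique_groups = sorted(set(groups))
    keep = set(rng.sample(unique_groups, n_groups))
    # Keep every row of a chosen group so no group is truncated.
    return [i for i, g in enumerate(groups) if g in keep]


groups = [1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8]
idx = subsample_by_group(groups, n_groups=3, seed=0)
print(idx, [groups[i] for i in idx])
```

The returned indices can be used to slice X, y and groups before passing them to the splitter.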

@hermidalc Thanks for the reply!
I can indeed reproduce the results above, so I am now troubleshooting with smaller data.

+1

Would anyone mind if I pick up this issue?
It looks like #15239, together with https://github.com/scikit-learn/scikit-learn/issues/13621#issuecomment-600894432, already contains an implementation, and only unit tests remain.
