optunaの最適化を終了する方法として主に以下の2つが用いられる。

  1. ある回数試行を繰り返したら終了する (n_trials)
  2. ある時間経過したら終了する (timeout)

上記の方法とは別に、ある一定の条件を満たしたときに終了するやり方を以下に示す。


# モジュールのインポート
import sys
import optuna

# Pythonのバージョン
print(sys.version)

3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00)
[Clang 13.0.1 ]
# モジュールのバージョン
optuna.__version__

'3.0.5'

条件を満たしたときに最適化を終える

パターン1: pruneを使用する

枝刈り (prune) 機能を応用して実装する。

ある条件を満たしたときにpruneして、pruneをしたら最適化を終了するcallbackを適用する。

callbackは、optuna第1引数にoptuna.study.Study、第2引数にoptuna.trial.FrozenTrialを取る関数を定義することで定義できる。

参考: Callback for Study.optimize - optuna documentation

以下のように、pruneしたときに最適化を終了するcallbackを定義する。

def earlystopping_callback(
    study: optuna.study.Study, trial: optuna.trial.FrozenTrial
) -> None:
    if trial.state == optuna.trial.TrialState.PRUNED:
        study.stop()

適当な2次関数の最小化問題で適用してみる。

def objective(trial: optuna.trial.Trial) -> float:
    """最小化する目的関数"""
    x = trial.suggest_float("x", -10, 10)
    # 2次関数の値を計算する
    value = (x - 2) ** 2

    # early stoppingする条件
    # 誤差が1e-2以下になったらearly stoppingする
    if abs(value) < 1e-2:
        raise optuna.TrialPruned()
    return value

study = optuna.create_study(
    sampler=optuna.samplers.TPESampler(seed=334), direction="minimize"
)
study.optimize(objective, n_trials=100, callbacks=[earlystopping_callback])

[I 2023-05-21 23:05:01,196] A new study created in memory with name: no-name-a2e23824-7626-421e-a505-bacb7ecd89d5
[I 2023-05-21 23:05:01,197] Trial 0 finished with value: 52.37805771890191 and parameters: {'x': -5.23726866427535}. Best is trial 0 with value: 52.37805771890191.
[I 2023-05-21 23:05:01,198] Trial 1 finished with value: 128.82917906550037 and parameters: {'x': -9.350294228146703}. Best is trial 0 with value: 52.37805771890191.
[I 2023-05-21 23:05:01,198] Trial 2 finished with value: 90.52852818387217 and parameters: {'x': -7.51464808513022}. Best is trial 0 with value: 52.37805771890191.
[I 2023-05-21 23:05:01,199] Trial 3 finished with value: 73.77157162105132 and parameters: {'x': -6.589037875166888}. Best is trial 0 with value: 52.37805771890191.
[I 2023-05-21 23:05:01,199] Trial 4 finished with value: 38.390210975765406 and parameters: {'x': -4.195983455091323}. Best is trial 4 with value: 38.390210975765406.
[I 2023-05-21 23:05:01,200] Trial 5 finished with value: 19.89993207283355 and parameters: {'x': 6.460933991086794}. Best is trial 5 with value: 19.89993207283355.
[I 2023-05-21 23:05:01,200] Trial 6 finished with value: 15.894152965494595 and parameters: {'x': 5.986747165985648}. Best is trial 6 with value: 15.894152965494595.
[I 2023-05-21 23:05:01,201] Trial 7 finished with value: 97.40470532348252 and parameters: {'x': -7.869382215897939}. Best is trial 6 with value: 15.894152965494595.
[I 2023-05-21 23:05:01,201] Trial 8 finished with value: 23.030596813196325 and parameters: {'x': 6.799020401414889}. Best is trial 6 with value: 15.894152965494595.
[I 2023-05-21 23:05:01,202] Trial 9 pruned. 
# 最後のtrialを取得
best_trial = study.get_trials()[-1]
print(best_trial.params)

{'x': 1.9269908569454213}

上記の実装の欠点として、early stoppingの条件を満たした場合そのtrialの結果が記録されない。

print(f"best value: {best_trial.value}")

best value: None

パターン2: pruneしない

したがって、条件を満たすときに直接study.stop()してあげれば良い。

def earlystopping_callback(
    study: optuna.study.Study, trial: optuna.trial.FrozenTrial
) -> None:
    if (
        trial.state == optuna.trial.TrialState.COMPLETE
        and abs(trial.value) < 1e-2
    ):
        study.stop()

def objective(trial: optuna.trial.Trial) -> float:
    """最小化する目的関数"""
    x = trial.suggest_float("x", -10, 10)
    # 2次関数の値を計算する
    value = (x - 2) ** 2
    return value

study = optuna.create_study(
    sampler=optuna.samplers.TPESampler(seed=334), direction="minimize"
)
study.optimize(objective, n_trials=100, callbacks=[earlystopping_callback])

[I 2023-05-21 23:05:01,340] A new study created in memory with name: no-name-402f0e16-d918-4cca-9bce-7e693f38341f
[I 2023-05-21 23:05:01,341] Trial 0 finished with value: 52.37805771890191 and parameters: {'x': -5.23726866427535}. Best is trial 0 with value: 52.37805771890191.
[I 2023-05-21 23:05:01,342] Trial 1 finished with value: 128.82917906550037 and parameters: {'x': -9.350294228146703}. Best is trial 0 with value: 52.37805771890191.
[I 2023-05-21 23:05:01,342] Trial 2 finished with value: 90.52852818387217 and parameters: {'x': -7.51464808513022}. Best is trial 0 with value: 52.37805771890191.
[I 2023-05-21 23:05:01,343] Trial 3 finished with value: 73.77157162105132 and parameters: {'x': -6.589037875166888}. Best is trial 0 with value: 52.37805771890191.
[I 2023-05-21 23:05:01,343] Trial 4 finished with value: 38.390210975765406 and parameters: {'x': -4.195983455091323}. Best is trial 4 with value: 38.390210975765406.
[I 2023-05-21 23:05:01,344] Trial 5 finished with value: 19.89993207283355 and parameters: {'x': 6.460933991086794}. Best is trial 5 with value: 19.89993207283355.
[I 2023-05-21 23:05:01,344] Trial 6 finished with value: 15.894152965494595 and parameters: {'x': 5.986747165985648}. Best is trial 6 with value: 15.894152965494595.
[I 2023-05-21 23:05:01,344] Trial 7 finished with value: 97.40470532348252 and parameters: {'x': -7.869382215897939}. Best is trial 6 with value: 15.894152965494595.
[I 2023-05-21 23:05:01,345] Trial 8 finished with value: 23.030596813196325 and parameters: {'x': 6.799020401414889}. Best is trial 6 with value: 15.894152965494595.
[I 2023-05-21 23:05:01,345] Trial 9 finished with value: 0.005330334969563941 and parameters: {'x': 1.9269908569454213}. Best is trial 9 with value: 0.005330334969563941.
print(f"best params: {study.best_params}")
print(f"best value: {study.best_value}")

best params: {'x': 1.9269908569454213}
best value: 0.005330334969563941

これらの方法を利用することで、条件を満たすようなパラメータが見つかったときに最適化を終了させることができる。

参考