Early stopping with Optuna
The following two methods are mainly used to terminate an Optuna optimization:
- stop after a given number of trials (n_trials)
- stop after a given amount of time has elapsed (timeout)
Apart from these, the following shows how to stop the optimization once a certain condition is satisfied.
# Import modules
import sys
import optuna
# Python version
print(sys.version)
3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00)
[Clang 13.0.1 ]
# Module version
optuna.__version__
'3.0.5'
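Ending the optimization when a condition is met
Pattern 1: using prune
This pattern repurposes the pruning feature.
The trial is pruned when a certain condition is met, and a callback stops the optimization as soon as a trial has been pruned.
A callback is defined as a function that takes an optuna.study.Study as its first argument and an optuna.trial.FrozenTrial as its second.
Reference: Callback for Study.optimize - optuna documentation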
Define a callback that stops the optimization when a trial is pruned, as follows.
def earlystopping_callback(
    study: optuna.study.Study, trial: optuna.trial.FrozenTrial
) -> None:
    if trial.state == optuna.trial.TrialState.PRUNED:
        study.stop()
Let us apply it to the minimization of a simple quadratic function.
def objective(trial: optuna.trial.Trial) -> float:
    """Objective function to minimize."""
    x = trial.suggest_float("x", -10, 10)
    # Compute the value of the quadratic function
    value = (x - 2) ** 2
    # Early-stopping condition:
    # stop once the error falls below 1e-2
    if abs(value) < 1e-2:
        raise optuna.TrialPruned()
    return value


study = optuna.create_study(
    sampler=optuna.samplers.TPESampler(seed=334), direction="minimize"
)
study.optimize(objective, n_trials=100, callbacks=[earlystopping_callback])
[I 2023-05-21 23:05:01,196] A new study created in memory with name: no-name-a2e23824-7626-421e-a505-bacb7ecd89d5
[I 2023-05-21 23:05:01,197] Trial 0 finished with value: 52.37805771890191 and parameters: {'x': -5.23726866427535}. Best is trial 0 with value: 52.37805771890191.
[I 2023-05-21 23:05:01,198] Trial 1 finished with value: 128.82917906550037 and parameters: {'x': -9.350294228146703}. Best is trial 0 with value: 52.37805771890191.
[I 2023-05-21 23:05:01,198] Trial 2 finished with value: 90.52852818387217 and parameters: {'x': -7.51464808513022}. Best is trial 0 with value: 52.37805771890191.
[I 2023-05-21 23:05:01,199] Trial 3 finished with value: 73.77157162105132 and parameters: {'x': -6.589037875166888}. Best is trial 0 with value: 52.37805771890191.
[I 2023-05-21 23:05:01,199] Trial 4 finished with value: 38.390210975765406 and parameters: {'x': -4.195983455091323}. Best is trial 4 with value: 38.390210975765406.
[I 2023-05-21 23:05:01,200] Trial 5 finished with value: 19.89993207283355 and parameters: {'x': 6.460933991086794}. Best is trial 5 with value: 19.89993207283355.
[I 2023-05-21 23:05:01,200] Trial 6 finished with value: 15.894152965494595 and parameters: {'x': 5.986747165985648}. Best is trial 6 with value: 15.894152965494595.
[I 2023-05-21 23:05:01,201] Trial 7 finished with value: 97.40470532348252 and parameters: {'x': -7.869382215897939}. Best is trial 6 with value: 15.894152965494595.
[I 2023-05-21 23:05:01,201] Trial 8 finished with value: 23.030596813196325 and parameters: {'x': 6.799020401414889}. Best is trial 6 with value: 15.894152965494595.
[I 2023-05-21 23:05:01,202] Trial 9 pruned.
# Get the last trial (the one that was pruned)
best_trial = study.get_trials()[-1]
print(best_trial.params)
{'x': 1.9269908569454213}
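A drawback of this implementation is that the objective value of the trial that met the early-stopping condition is not recorded; only its parameters are kept.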
print(f"best value: {best_trial.value}")
best value: None
Pattern 2: without pruning
Therefore, it is simpler to call study.stop() directly from the callback when a completed trial meets the condition.
def earlystopping_callback(
    study: optuna.study.Study, trial: optuna.trial.FrozenTrial
) -> None:
    if (
        trial.state == optuna.trial.TrialState.COMPLETE
        and abs(trial.value) < 1e-2
    ):
        study.stop()
def objective(trial: optuna.trial.Trial) -> float:
    """Objective function to minimize."""
    x = trial.suggest_float("x", -10, 10)
    # Compute the value of the quadratic function
    value = (x - 2) ** 2
    return value


study = optuna.create_study(
    sampler=optuna.samplers.TPESampler(seed=334), direction="minimize"
)
study.optimize(objective, n_trials=100, callbacks=[earlystopping_callback])
[I 2023-05-21 23:05:01,340] A new study created in memory with name: no-name-402f0e16-d918-4cca-9bce-7e693f38341f
[I 2023-05-21 23:05:01,341] Trial 0 finished with value: 52.37805771890191 and parameters: {'x': -5.23726866427535}. Best is trial 0 with value: 52.37805771890191.
[I 2023-05-21 23:05:01,342] Trial 1 finished with value: 128.82917906550037 and parameters: {'x': -9.350294228146703}. Best is trial 0 with value: 52.37805771890191.
[I 2023-05-21 23:05:01,342] Trial 2 finished with value: 90.52852818387217 and parameters: {'x': -7.51464808513022}. Best is trial 0 with value: 52.37805771890191.
[I 2023-05-21 23:05:01,343] Trial 3 finished with value: 73.77157162105132 and parameters: {'x': -6.589037875166888}. Best is trial 0 with value: 52.37805771890191.
[I 2023-05-21 23:05:01,343] Trial 4 finished with value: 38.390210975765406 and parameters: {'x': -4.195983455091323}. Best is trial 4 with value: 38.390210975765406.
[I 2023-05-21 23:05:01,344] Trial 5 finished with value: 19.89993207283355 and parameters: {'x': 6.460933991086794}. Best is trial 5 with value: 19.89993207283355.
[I 2023-05-21 23:05:01,344] Trial 6 finished with value: 15.894152965494595 and parameters: {'x': 5.986747165985648}. Best is trial 6 with value: 15.894152965494595.
[I 2023-05-21 23:05:01,344] Trial 7 finished with value: 97.40470532348252 and parameters: {'x': -7.869382215897939}. Best is trial 6 with value: 15.894152965494595.
[I 2023-05-21 23:05:01,345] Trial 8 finished with value: 23.030596813196325 and parameters: {'x': 6.799020401414889}. Best is trial 6 with value: 15.894152965494595.
[I 2023-05-21 23:05:01,345] Trial 9 finished with value: 0.005330334969563941 and parameters: {'x': 1.9269908569454213}. Best is trial 9 with value: 0.005330334969563941.
print(f"best params: {study.best_params}")
print(f"best value: {study.best_value}")
best params: {'x': 1.9269908569454213}
best value: 0.005330334969563941
With these methods, the optimization can be stopped as soon as parameters that satisfy the condition are found.