Сохранение промежуточных результатов обучения

В процессе обучения моделей сложилась практика сохранять промежуточные результаты, или контрольные точки (checkpoints) обучения. Это информация о конфигурации модели, ее переменных, градиентах, весах и иные дополнительные сведения. Сохранение промежуточных результатов позволит пользователям:

  • Не потерять результаты обучения в случае остановки задачи по какой-то причине (баланс, падение задачи). Пользователь сможет возобновить обучение модели из последнего сохраненного состояния.

  • Делиться обученными моделями с коллегами, чтобы те могли восстановить объект с моделью без повторного обучения.

Рекомендуется сохранять промежуточные результаты с некоторой разумной периодичностью, например, в конце каждой эпохи или после окончания итерации по небольшим блокам обучения (batches). В процессе обучения модели на кластере Christofari промежуточные результаты сохраняются в рабочей директории пользователя /home/jovyan/. Их можно скачать через веб-интерфейс Jupyter Notebook/JupyterLab или скопировать из локально доступной файловой системы в хранилище S3. Подробнее о выгрузке промежуточных результатов обучения на S3 см. в примере.

Рассмотрим, как сохранять промежуточные результаты обучения для наиболее распространенных фреймворков:

Keras

Для сохранения промежуточных результатов и фиксации состояния модели в определенный момент времени используется механизм обратных вызовов (callbacks), как экземпляр класса ModelCheckpoint.

# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
    if hvd.rank() == 0:
       callbacks.append(tf.keras.callbacks.ModelCheckpoint(os.path.join('checkpoints','checkpoint-{epoch}.h5'))

# Train the model.
# Horovod: adjust number of steps based on number of GPUs.
mnist_model.fit(dataset, steps_per_epoch=500 // hvd.size(), callbacks=callbacks, epochs=24, verbose=verbose)

TensorFlow

Для восстановления сеанса Tensorflow можно использовать конструктор MonitoredTrainingSession() с аргументом checkpoint_dir.

#Horovod: Save checkpoints only on worker 0 to prevent other workers from corrupting them.
checkpoint_dir = '/tmp/train_logs' if hvd.rank() == 0 else None

# The MonitoredTrainingSession takes care of session initialization,
# restoring from a checkpoint, saving to a checkpoint, and closing when done
# or an error occurs.
with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                       <some_other_variables>) as sess:

PyTorch

Сохранение параметров модели с помощью torch.save().

#Save checkpoints only on worker 0 to prevent other workers from corrupting them.
if hvd.rank() == 0
    torch.save(the_model.state_dict(), PATH)

Ниже показано, как возобновить обучение модели из последнего сохраненного состояния. За основу взят следующий пример.

num_epochs = 3
print(f'Start train {num_epochs} epochs total')

# Loading from checkpoint
# https://pytorch.org/tutorials/beginner/saving_loading_models.html
last_epoch = 0
import os
for root, dirs, files in os.walk(BASE_DIR.joinpath('logs')):
    saved_models = [model_filename for model_filename in files if ".bin" in model_filename]

if saved_models:
    checkpoint = torch.load(os.path.join(root, saved_models[-1]))
    clf.load_state_dict(checkpoint['model_state_dict']) #loading model weights and other training parameters
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    last_epoch = checkpoint['epoch']
    print(f"Continue training from {last_epoch} epoch")

# Start training
mlflow.set_tracking_uri('file:/home/jovyan/mlruns')
mlflow.set_experiment("pytorch_tensorboard_mlflow.ipynb")
with mlflow.start_run(nested=True) as run:
    for epoch in range(num_epochs):
        if last_epoch:
            epoch += last_epoch + 1

        print("Epoch %d" % epoch)
        train(epoch, clf, optimizer, writer)
        test(epoch, clf, writer)
        # Save checkpoint every epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': clf.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, BASE_DIR.joinpath('logs/log_' + current_time + f"/model_epoch_{epoch}.bin"))
        writer.close()