Сохранение обученной модели в PyTorch: лучшие практики
Пройдите тест, узнайте какой профессии подходите
Быстрый ответ
Для сохранения модели в PyTorch используйте функцию torch.save()
. Эта функция предназначена для сохранения параметров модели:
torch.save(model.state_dict(), 'my_model.pth')
Чтобы загрузить модель, примените функцию model.load_state_dict()
. Это тоже достаточно просто:
model = MyModel() # заранее определите класс модели
model.load_state_dict(torch.load('my_model.pth'))
Этот способ подразумевает сохранение исключительно обучаемых параметров, благодаря чему получается компактный и переносимый файл модели.
Сохранение и загрузка: Подробное руководство
Паузы в долгих обучающих сессиях
При длительных сессиях обучения удобно использовать контрольные точки, в которых сохраняется больше данных, чем при обычном сохранении состояния модели:
checkpoint = {
'epoch': epoch, # количество выполненных эпох
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss, # значение функции потерь
# можно добавлять и другие метрики, нужные для ваших экспериментов
}
torch.save(checkpoint, 'pitstop.pth')
При загрузке контрольной точки сначала следует инициализировать модель и оптимизатор:
checkpoint = torch.load('pitstop.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Если загрузка происходит в середине эпохи…
epoch = checkpoint['epoch']
loss = checkpoint['loss']
При загрузке модели для использования в инференсе не забывайте переключить её в режим model.eval()
.
Сохранение полной модели: когда важен контекст
Если есть необходимость сохранить всю модель, особенно в случае потери исходного кода модели, можно использовать:
torch.save(model, 'the_whole_shebang.pth')
Для восстановления модели используется:
model = torch.load('the_whole_shebang.pth')
Совместимость между разными версиями PyTorch
Рекомендуется указывать используемую версию PyTorch при сохранении моделей для избежания проблем совместимости в будущем.
Особенности сохранения при распределённом обучении
Если вы используете модель в контексте DataParallel
или DistributedDataParallel
, сохраняйте параметры следующим образом:
torch.save(model.module.state_dict(), 'all_for_one.pth')
Загрузка параметров в модель другой архитектуры
Для загрузки параметров в модель другой архитектуры надо сначала определить класс модели, а потом загрузить состояние.
Визуализация
Сохранение и восстановление модели в PyTorch состоит из следующих этапов:
1. Обучение: модель обучается выполнять некоторые задачи.
2. Сериализация: применяем `torch.save` для упаковки модели.
3. Создание файла: присваиваем модели имя для удобства управления.
4. Хранение: сохраняем модель на любой носитель информации.
5. Развертывание: модель готова к использованию.
Примечание: Сериализация здесь выступает как упаковка чемодана перед путешествием!
Полезные материалы
Сохранять модель напрямую через model.parameters()
не стоит, так как это может вызвать ошибки. Используйте state_dict()
. Дайте вашим файлам расширения .pt
или .pth
для следования конвенциям PyTorch. Когда работаете с DataParallel
, не забывайте правильно обработать модель перед загрузкой состояния. Полезно сохранять в контрольных точках не только состояние модели и оптимизатора, но и прошедшие эпохи и различные метрики. Периодически проверяйте свои модели на сериализуемость и ищите баланс между объемом хранения и удобством развертывания.