Сохранение обученной модели в PyTorch: лучшие практики

Пройдите тест, узнайте какой профессии подходите

Я предпочитаю
0%
Работать самостоятельно и не зависеть от других
Работать в команде и рассчитывать на помощь коллег
Организовывать и контролировать процесс работы

Быстрый ответ

Для сохранения модели в PyTorch используйте функцию torch.save(). Эта функция предназначена для сохранения параметров модели:

Python
Скопировать код
torch.save(model.state_dict(), 'my_model.pth')

Чтобы загрузить модель, примените функцию model.load_state_dict(). Это тоже достаточно просто:

Python
Скопировать код
model = MyModel() # заранее определите класс модели
model.load_state_dict(torch.load('my_model.pth'))

Этот способ подразумевает сохранение исключительно обучаемых параметров, благодаря чему получается компактный и переносимый файл модели.

Кинга Идем в IT: пошаговый план для смены профессии

Сохранение и загрузка: Подробное руководство

Паузы в долгих обучающих сессиях

При длительных сессиях обучения удобно использовать контрольные точки, в которых сохраняется больше данных, чем при обычном сохранении состояния модели:

Python
Скопировать код
checkpoint = {
    'epoch': epoch,  # количество выполненных эпох
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,  # значение функции потерь
    # можно добавлять и другие метрики, нужные для ваших экспериментов
}
torch.save(checkpoint, 'pitstop.pth')

При загрузке контрольной точки сначала следует инициализировать модель и оптимизатор:

Python
Скопировать код
checkpoint = torch.load('pitstop.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Если загрузка происходит в середине эпохи…
epoch = checkpoint['epoch']
loss = checkpoint['loss']

При загрузке модели для использования в инференсе не забывайте переключить её в режим model.eval().

Подробнее об этом расскажет наш спикер на видео
skypro youtube speaker

Сохранение полной модели: когда важен контекст

Если есть необходимость сохранить всю модель, особенно в случае потери исходного кода модели, можно использовать:

Python
Скопировать код
torch.save(model, 'the_whole_shebang.pth')

Для восстановления модели используется:

Python
Скопировать код
model = torch.load('the_whole_shebang.pth')

Совместимость между разными версиями PyTorch

Рекомендуется указывать используемую версию PyTorch при сохранении моделей для избежания проблем совместимости в будущем.

Особенности сохранения при распределённом обучении

Если вы используете модель в контексте DataParallel или DistributedDataParallel, сохраняйте параметры следующим образом:

Python
Скопировать код
torch.save(model.module.state_dict(), 'all_for_one.pth')

Загрузка параметров в модель другой архитектуры

Для загрузки параметров в модель другой архитектуры надо сначала определить класс модели, а потом загрузить состояние.

Визуализация

Сохранение и восстановление модели в PyTorch состоит из следующих этапов:

Markdown
Скопировать код
1. Обучение: модель обучается выполнять некоторые задачи.
2. Сериализация: применяем `torch.save` для упаковки модели.
3. Создание файла: присваиваем модели имя для удобства управления.
4. Хранение: сохраняем модель на любой носитель информации.
5. Развертывание: модель готова к использованию.

Примечание: Сериализация здесь выступает как упаковка чемодана перед путешествием!

Полезные материалы

Сохранять модель напрямую через model.parameters() не стоит, так как это может вызвать ошибки. Используйте state_dict(). Дайте вашим файлам расширения .pt или .pth для следования конвенциям PyTorch. Когда работаете с DataParallel, не забывайте правильно обработать модель перед загрузкой состояния. Полезно сохранять в контрольных точках не только состояние модели и оптимизатора, но и прошедшие эпохи и различные метрики. Периодически проверяйте свои модели на сериализуемость и ищите баланс между объемом хранения и удобством развертывания.

Проверь как ты усвоил материалы статьи
Пройди тест и узнай насколько ты лучше других читателей
Какой метод в PyTorch используется для сохранения параметров модели?
1 / 5