tkherox blog

データサイエンスおよびソフトウェア開発、たまに育児についての話を書いています

Pytorchにおけるモデル保存の使い分け

はじめに

TorchServeを利用してサービングを実施する際にモデルの保存方法についていくつかパターンがあり,TorchServeで保存したモデルを読み込む際にうまく動作しないといった事があったのでしっかり違いを把握しようと思ってこの記事を書いています.この記事を読んでくださっている人の中にもよく分からずに何となくPytorchにおけるモデル保存を実施している人もいるかと思いますのでそう言った方の参考になればと思います.ちなみにPytorchのバージョンは1.7.0を前提として話を進めます.

モデル保存パターン

まず保存パターンについて説明していきます.
公式では大きく2つのパターンでのモデル保存を解説しています.また,そのほかにもTorchScriptを利用したモデルの保存方法もあります.方法別で記載すると以下のパターンがモデル保存方法として実現する方法としてあります.

  1. state_dictを利用したモデルの保存/読み込み
  2. entire のモデル保存/読み込み
  3. TorchScriptによるモデル保存/読み込み

pytorch.org

次から各パターンの特徴を実装例を交えて解説していきたいと思います.

各パターンの解説

ここからは先ほど列挙した各パターンのモデル保存方法を説明していきたいと思います.

前提情報としてモデル保存と読み込みを実装を交えて説明するためのモデルを以下に記載します.モデル自体はなんでも良いので適当に5層のMLPを定義しました.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 16)
        self.fc4 = nn.Linear(16, 8)
        self.fc5 = nn.Linear(8, 2)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5)
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.fc5(x)
        return x

model = MyModel()

state_dictのモデル保存

1番目の state_dict を用いたモデル保存について説明します.

state_dict はモデルで定義された各レイヤーにTensor形式のパラメーターをマッピングするための単純な辞書オブジェクトを返します.これによって,簡単にモデルを保存、更新、変更、復元できるようになります.

# モデル保存
torch.save(model.state_dict(), "models/model_state_dict.pth")

# モデル読み込み
model_state_dict = MyModel()
model_state_dict.load_state_dict(torch.load("models/model_state_dict.pth"), strict=False)

model_state_dict.state_dict() を実行すると学習ずみのパラメータが正しく読み込まれているのが分かります.

また, state_dictを用いることによってGPUで学習したモデルをCPUで推論したり,CPUで学習したモデルをGPUで推論したりすることができます.そのため,Pytorchの公式ドキュメントでは state_dict が推奨されています.

# GPU to CPU
# モデル保存
torch.save(model.state_dict(), "models/model_state_dict.pth")

# モデル読み込み
device = torch.device('cpu')
model = MyModel()
model.load_state_dict(torch.load("models/model_state_dict.pth", map_location=device))

# CPU to GPU
# モデル保存
torch.save(model.state_dict(), "models/model_state_dict.pth")

# モデル読み込み
device = torch.device("cuda")
model = MyModel()
model.load_state_dict(torch.load("models/model_state_dict.pth", map_location="cuda:0"))
model.to(device)

学習時と推論時で異なるプロセッサを利用する場合などはload関数を実行する際にmap_locationの引数にデバイスを指定してあげる必要があります.これによって,GPUで学習したモデルをCPUで読み込む際にはTensorを扱うストレージがCPUに動的再配置されるようになります.
逆に,CPUで学習したモデルをGPUde読み込む際も同様でload関数を実行する際にmap_locationの引数でTensorマッピングするGPUバイスを指定してあげる必要があります.ここで,GPUのどのデバイスに配置するのかを明示するためcuda:idの形式で指定してあげます.そして,CPUからGPUの場合では必ず model.to(torch.device('cuda')) を呼び出してモデルのパラメータであるTensorをCudaTensorに変換してあげてください.

entireのモデル保存

続いて,モデルとパラメータをセットで管理するモデル保存方法の説明です.

entire モデルではモデルに含まれるモジュール全体を保存することになります.メリットは少ないコード量で記述できることとその直感的な文法です.

# モデル保存
torch.save(model, "models/model_state_dict.pth")

# モデル読み込み
torch.load("models/model_state_dict.pth")

この方法では pickle モジュールを用いて保存されているおり,シリアル化されたデータが特定のクラスとディレクトリ構造にバインドされることになります.これによって,推論時のロードうやリファクタリングした場合に参照するパスが異なっていたりして致命的なエラーとなり利用できなくなります.このデメリットからも公式では非推奨となっています.
直感的で分かりやすいのですが,デメリットが大きすぎで実際に活用するシーンというのは少なそうです.

TorchScriptのモデル保存

まずTorchScriptについて簡単に説明します.

TorchScriptはPytorchで学習させたモデルをPython非依存な形でモデルを最適化してパラメータを保存することができます.この形式で保存したモデルはPython以外のC++iOSAndroidといった様々な環境で読み込んで利用することが可能になります.

こちらの記事で分かりやすくまとめられてますので,TorchScriptについて詳しく知りたいという方は是非参照してみてください.

pytorch.org

では,TorchScriptでモデルを保存する方法を以下に記載します.

# モデル保存
input_tensor = torch.rand(1, 128)
model_trace = torch.jit.trace(model, input_tensor)
model_trace.save('models/model_trace.pth')

# モデル読み込み
model_trace = torch.jit.trace('models/model_trace.pth')

TorchScript のモデルを作成するためにはサンプルデータを流して処理をトレースして変換を行うことで生成します.非常に簡単ですね.また,TorchScript のモデルを作成する方法にはトレースして生成する方法以外に直接的にモデルを記述する方法もあります.

べストな保存方法

さて,このパターンを踏まえてベストな保存方法について検討してみます.
まず,パターン2の entire のモデル保存ですがこちらは公式でも非推奨とされているので利用しない方が良いと思います.

次にパターン1の state_dictとパターン3の TorchScript のモデル保存ですが,こちらの2つは状況に合わせて使い分けを実施するのが良いです.
モデル作成途中やプロトタイプレベルで利用する場合にはパターン1のstate_dictによるモデル保存を利用します.TorchScriptでは最新の処理には対応していない場合があり,変換する際にエラー対応が必要になることがあります,そのため,試作レベルの時にはデバックに用する時間をかけるよりもスピード感を重要視して state_dict を利用する方が良いと思います.
そして,プロダクションフェーズの利用の際には TorchScript でのモデル保存を実施して最適化されたモデルを活用するといった使い分けが良いと言えます.TorchScriptも最近では改善や最適化が進んでいるので今後は全ての状況でTorchScriptでモデル保存する方が良いといった変化はあるかと思いますが,現状は上記のような使い方をするのがベストかと思います.

まとめ

今回はPytorchのモデル保存についてまとめてみました.
普段何気なく使っている人やTorchScriptを使ったことのない人にとって参考になればと思います.