【読書記録】Deep Learning With PyTorch
本の概要
PyTorchの基礎~医用画像解析プロジェクトまでをカバーしている。
以前無料でPDFが配布されていて、そこで入手していた。(もちろん公式のモノ)
コード・実行環境
(著者のコードを写経したり少しいじったり、モジュール化して使えるようにしただけなので掲載しない)
実行環境は簡易的に書くと
Ubuntu18.04 (on WSL2)
Geforce RTX2070(8GB)
Docker 20.10.3
CUDA11.0
Python3.8.0
Pytorch1.7.1等
GPUはもともとゲーミング用途で購入した。そのためWindows環境しかなく、デュアルブートするのも怖いのでWSL2を使ってUbuntuでの開発環境を整えている。WSL2の進化は本当にすごくて、CUDAに対応したため、ローカルの計算資源もUbuntuから使うことができる。Dockerも使えるのでホスト環境を汚さずに済む。ただしWSL2の仕様上やたらにメモリを食うので、しばしば学習が落ちてしまい、本の内容を全ては実行できなかった。設定すればメモリ使用量制限できたりするらしいのでやらねば。。一部計算をColabで補ったりした。
[余談]
正直Windowsの有難みはゲームとOfficeくらいなので、Steam(ゲームのプラットフォーム)がLinuxに対応してくれたらもうWindow要らないんだよなぁ...。WSL凄いけど当然しばしば問題も起こるので、Linux単機を構築したいところ...。
読んだ目的
KaggleなどでPyTorchは使ったことがあるが、しっかり分かっているとはお世辞にも言えなかった。また、込み入ったプロジェクトにおける綺麗な実装例を学びたかった。そしてなにより無料であったため。
読んだ後の実感
PyTorchのTensorがメモリ上でどう管理されているかといった部分など、普段表面的に使っているだけでは意識しないようなことも学べた。総じてPyTorch力が上がった気がする。
また、CT画像を用いて肺がん検出を行うHandsOnがあり、これも非常に勉強になった。CT画像の扱い方や座標変換方法といった要素も学べたが、それ以上にプロジェクト全体をどう切り分けてコードを書いていけばよいのかが示されていた。
今回であれば、
- DataLoad
- Segmentationによる結節位置検出
- 結節のGrouping
- 分類モデルによる結節候補の分類
- 悪性/良性の診断
という段階に分けることができ、それぞれにおいて再利用性の高いコードを書く経験ができた。実装が非常に綺麗で、とても参考になった。ログをいつ出力するべきか、全体のボトルネックになる場所はどこか、高価な計算資源を最大限活用できているか、などのPracticeも学んだ。
評価指標の考え方なども参考になった。医療現場では陽性を見落とす(False-Negative)と大変なことになるため、偽陰性を減らすように方針を決め、TensorBoardなどで学習を可視化しながら、色々いじっていく、という流れが体験できた。TensorBoardはチュートリアルレベルしかやったことがなかったので、便利だなぁと思った。使いこなすのは大変そう。
最終章ではおまけ程度だがモデルのデプロイなども学べた。Flaskを用いてWebサーバーにデプロイしたり、C++でPyTorchを使って高速化したり、JIT化による恩恵を得たり、などである。最後のAndroidへのデプロイは、Javaの知識が無いため飛ばしてしまった。エッジ向けのモデルの軽量化や、DNN向けハードウェアに渡す際にONNXというフォーマットが用いられることが多いことを学んだ。ここら辺のエッジコンピューティングやハードウェアと協調してアルゴリズムを考えていくあたりがとても興味があるため、今後もこのあたりを学びたいと思った。
本の内容メモ(自分用)
第1部
Pytorchの環境構築
Pytorchの基礎
Pre-trainedモデルの使い方
Tensorの内部実装
- Tensorは初期化時にStorageに要素が格納され、メモリアドレスのOffsetと次元ごとのindexの幅(Stride)の情報を持つ。これによりコピーし直すことなくShapeに応じたアクセスが可能になる。これをviewといって、転置などの演算もStorageをいじることなく可能になる。
- Inplaceな操作はメソッド名に_が付く
Tensorで色々なデータを表現してみる
線形回帰を敢えて勾配法でやってみる
線形回帰で十分なデータに対して、敢えてニューラルネットワークを適用してみる
画像データに対して基本的なCNNを構築
3次元画像データの扱い方
第2部
CT画像から肺がんを早期検知するプロジェクトの構成
- DataLoad
- Segmentationによる結節位置検出
- 結節のGrouping
- 分類モデルによる結節候補の分類
- 悪性/良性の診断
に分けて実装
CT画像は患者座標系からindex-row-column座標系へ変換して扱う
AnnotationDataも盛り込む
分類モデルは簡単なCNN
SegmentationはU-net(
[1505.04597] U-Net: Convolutional Networks for Biomedical Image Segmentation
)を使用。
分類は3次元画像のまま扱い、Segmentationは3次元画像として構成せずに2次元画像を多Channelにしたものとして扱う。空間的な情報は抜け落ちるので、モデルにそこも学習させてあげる必要があるが、3次元でSegmentationを行うよりはコストが小さい。
[性能向上手段]
有効なData AugmentationとしてElastic Deformation(画像にDistortionを掛けるやつ?)や、Mixup(入力として、複数の画像が混合されたものを入れて予測を安定させる)
マルチタスク学習(評価する出力以外の出力も追加で学習することで、目的のタスクに対する性能が向上する可能性がある)
第3部
Pytorch JIT (要チェック)
Pythonの並列化にはGIL(Global Interpreter Lock)がネックになる。Python環境でモデルを実行しなれければこの影響を受けることはない。
PytorchはC++で書かれた内部ライブラリに演算を投げている。
Python, C++の両方からJIT関数にアクセスでき、そこかC++ LibTorchが呼ばれ、ATen(Tensor計算ライブラリ)とバックエンド達が呼ばれる。
Pytorch C++でJIT化したモデルを呼び出せば、Pythonを介さずにできる。もちろんC++で書くのもok.
C++のPytorchはちゃんと勉強したい
モバイルへのデプロイ(TODO)
モデルの軽量化
- 枝狩り
- 量子化(特に int8がよく使われる)
量子化は、ビット数減少による丸め誤差をランダムなものと考え、畳み込みと線形層を重み付き平均のようにとらえれば、丸め誤差は相殺される。中心極限定理。
floatからintへ変換することで、固定精度になる。これは重みの微小な変化を無視することが期待できる。ある種正則化のような感じ。