模型训练开销

好久没写技博了,来一篇!我导提醒我们需要对模型训练的开销有一个基本的sense,不然枉为AI人。因此整理了本文。


深度学习模型训练的开销主要可以从两个维度来衡量:

  • 显存占用:决定能不能训练
  • 算力消耗:决定要训练多久

显存占用

在一个经典的反向传播随机梯度下降算法中,我们需要保存以下变量:

假设参数量为$\Psi$,参数量的计算相对简单;一个热知识:Transformer中的FFN不论是参数量还是计算量都占据整个模型很大的比重(超过一半,甚至90%+),是开销最大的地方。

  • 模型参数
    • 每个参数使用单精度浮点数fp16来存储
    • 显存占用:$2\Psi$ …