目标

将一个已定义/训练好的网络的部分层权重拆分出来,扔进新的网络中

分析

一个神经网络可以不断细分子网络,最后能拆分到一个个层。

模型实际上是一个字典,key为子网络.子网络的子网络.层.权重,value就是权重具体数值

如果想将模型的部分层从网络中拆分出来,最好复现需要提取权重部分的网络定义,这样就可以通过

source_dict = torch.load('path to pth')

target_model = model()
target_model_dict = target_model.state_dict()

overlap = {k:v for k, v in source_dict.items() if k in target_model_dict}

target_model_dict.update(overlap)
target_model.load_state_dict(target_model_dict)

直接将同key网络权重搞到手

如果觉得我的文章对你有用,请随意赞赏