class Char3SeqModel(nn.Module):
    """Score the next character given the three preceding characters.

    Each context character is embedded (``em``), projected to the hidden
    size (``fc1``), and folded into a running hidden state through the
    shared hidden-to-hidden layer (``fc2``). ``fc3`` maps the final hidden
    state to one score per character.

    Args:
        char_sz: Number of distinct characters (embedding rows and output
            scores).
        n_fac: Embedding dimension.
        n_h: Hidden-state dimension.
    """

    def __init__(self, char_sz, n_fac, n_h):
        super().__init__()
        self.em = nn.Embedding(char_sz, n_fac)
        self.fc1 = nn.Linear(n_fac, n_h)     # embedding -> hidden
        self.fc2 = nn.Linear(n_h, n_h)       # hidden -> hidden, reused each step
        self.fc3 = nn.Linear(n_h, char_sz)   # hidden -> per-character scores

    def forward(self, ch1, ch2, ch3):
        """Return unnormalized next-character scores.

        NOTE(review): the original forward was an unfinished placeholder
        (``out = #....``, which does not parse). This body is reconstructed
        from the layer shapes in ``__init__``; confirm it matches intent.

        Args:
            ch1, ch2, ch3: Integer index tensors of identical shape (e.g.
                ``(batch,)``) holding the three context characters in order.

        Returns:
            Tensor of shape ``ch1.shape + (char_sz,)`` with raw logits
            (pair with ``nn.CrossEntropyLoss``).
        """
        # Project every context character's embedding into the hidden space.
        in1, in2, in3 = (self.fc1(self.em(c)).relu() for c in (ch1, ch2, ch3))
        # Same as starting from a zero hidden state: tanh(fc2(0 + in1)).
        h = self.fc2(in1).tanh()
        h = self.fc2(h + in2).tanh()
        h = self.fc2(h + in3).tanh()
        return self.fc3(h)
model = Char3SeqModel(10000, 50, 25)

# Freeze: stop computing gradients for fc1's weight (fc1.bias stays trainable).
model.fc1.weight.requires_grad = False
# Hand the optimizer only the parameters that are currently trainable.
optimizer = optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=0.1
)

# Training step (omitted here):
#     loss = ...  # compute loss
#     loss.backward()
#     optimizer.step()

# Unfreeze: re-enable gradients for fc1's weight and register it with the
# optimizer as a new parameter group (it uses the optimizer's defaults,
# i.e. lr=0.1). Only the weight is added: fc1.bias was never frozen, so it is
# already being optimized, and add_param_group raises ValueError if a
# parameter appears in more than one group.
model.fc1.weight.requires_grad = True
optimizer.add_param_group({'params': [model.fc1.weight]})