-
파이썬 특정 weight만 Load카테고리 없음 2022. 8. 24. 14:18
student에 키가 존재하고 size 같을 때만 pretrained weight load하고 싶을 때
if cfg.MODEL.student_dir is not None: pretrained = torch.load(cfg.MODEL.student_dir) #pretrained=pretrained['teacher'].state_dict() student_dict = student.state_dict() pretrained_dict = {k: v for k, v in pretrained['teacher'].items() if (k in student_dict) and (student_dict[k].shape==pretrained['teacher'][k].shape)} student_dict.update(pretrained_dict) student.load_state_dict(student_dict)
키값 pretrained, student모두에 존재하고 size는 상관 없을 때
pretrained_dict = pretrained_model.state_dict() new_model_dict = new_model.state_dict() pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in new_model_dict} new_model_dict.update(pretrained_dict) new_model.load_state_dict(new_model_dict)
다 필요없고 strict = False로 주면 모델,ckpt에 모두 있는 키만 load함.
g_ema.load_state_dict(checkpoint["g_ema"], strict=False)