  • Python: loading only specific weights
    Uncategorized 2022. 8. 24. 14:18

    When you want to load a pretrained weight only if its key exists in the student and its size (shape) is the same:

    import torch

    if cfg.MODEL.student_dir is not None:
        pretrained = torch.load(cfg.MODEL.student_dir)
        # if the checkpoint stores the whole module instead: pretrained['teacher'].state_dict()
        teacher_dict = pretrained['teacher']
        student_dict = student.state_dict()
        # keep only keys that exist in the student with an identical shape
        pretrained_dict = {k: v for k, v in teacher_dict.items()
                           if k in student_dict and student_dict[k].shape == v.shape}
        student_dict.update(pretrained_dict)
        student.load_state_dict(student_dict)
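The snippet above depends on the author's `cfg` and `student` objects. A self-contained sketch of the same key-and-shape filter follows, wrapped in a hypothetical helper `load_matching_weights` and exercised on two toy `nn.Sequential` models (the helper name and the models are assumptions for illustration). It also reports which keys were skipped, which helps when checking what actually got loaded.

```python
import torch
import torch.nn as nn


def load_matching_weights(model: nn.Module, source: dict) -> tuple[list[str], list[str]]:
    """Hypothetical helper: copy tensors from `source` into `model`
    only where the key exists in the model and the shape matches.

    Returns (loaded_keys, skipped_keys).
    """
    target = model.state_dict()
    loaded, skipped = [], []
    for k, v in source.items():
        if k in target and target[k].shape == v.shape:
            target[k] = v  # overwrite with the pretrained tensor
            loaded.append(k)
        else:
            skipped.append(k)  # missing in the model or shape differs
    # every key of the model is present in `target`, so strict loading is safe
    model.load_state_dict(target)
    return loaded, skipped


# Toy "teacher" and "student": same first layer, different output width.
teacher = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 10))
student = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

loaded, skipped = load_matching_weights(student, teacher.state_dict())
print("loaded:", loaded)
print("skipped:", skipped)
```

Here the first `Linear` is copied, while the last layer keeps the student's own initialization because its shape differs.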

     

    When the key only has to exist in both pretrained and student, and shapes are not checked:

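    # state_dicts of the source (pretrained) and target (new) models
    pretrained_dict = pretrained_model.state_dict()
    new_model_dict = new_model.state_dict()
    # keep only keys the new model also has (shapes are not checked)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in new_model_dict}
    # overwrite the matching entries, keep the new model's own values for the rest
    new_model_dict.update(pretrained_dict)
    new_model.load_state_dict(new_model_dict)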
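A small check of the key-only filter, assuming toy models where `new_model` reuses the first layers of `pretrained_model` under the same names and adds a new head (both models are hypothetical stand-ins). Because this filter does not compare shapes, it relies on every shared key having the same shape; otherwise `load_state_dict` raises a size-mismatch `RuntimeError`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy models: layers "0" and "2" are shared (same names, same
# shapes); layer "4" exists only in new_model.
pretrained_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 10))
new_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 10),
                          nn.ReLU(), nn.Linear(10, 2))
head_before = new_model[4].weight.detach().clone()  # snapshot of the new head

pretrained_dict = pretrained_model.state_dict()
new_model_dict = new_model.state_dict()
# keep only keys the new model also has (shapes are not checked)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in new_model_dict}
new_model_dict.update(pretrained_dict)
new_model.load_state_dict(new_model_dict)

print("copied keys:", sorted(pretrained_dict))
```

The shared layers now hold the pretrained values, and the new head keeps its fresh initialization.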

     

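    If shapes already match, none of that is needed: passing strict=False makes load_state_dict load only the keys present in both the model and the checkpoint and ignore the rest. It does not skip keys whose shapes differ, though; those still raise a size-mismatch RuntimeError, so the first snippet is still needed in that case.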

        g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
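`load_state_dict` returns a named tuple with `missing_keys` and `unexpected_keys`, which is handy for confirming what `strict=False` actually ignored. A minimal sketch, assuming toy `nn.Sequential` models as hypothetical stand-ins for `g_ema` and `checkpoint["g_ema"]`, prints those lists and then shows that a shared key with a different shape still raises even with `strict=False`:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: the checkpoint comes from a smaller model and
# carries one extra key ("extra.weight") that g_ema does not have.
ckpt_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 10))
checkpoint = {"g_ema": ckpt_model.state_dict()}
checkpoint["g_ema"]["extra.weight"] = torch.zeros(1)

g_ema = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 10),
                      nn.ReLU(), nn.Linear(10, 2))

# strict=False loads the shared keys and reports the rest instead of raising.
result = g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
print("missing:", result.missing_keys)        # keys only in the model
print("unexpected:", result.unexpected_keys)  # keys only in the checkpoint

# A shared key with a different shape is still an error, even with strict=False.
small = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
try:
    small.load_state_dict(ckpt_model.state_dict(), strict=False)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("error:", str(err).splitlines()[0])
```

Checking `missing_keys` after a non-strict load is a cheap way to catch a typo in key names that would otherwise leave layers silently uninitialized.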