# ckpt_utils.py — checkpoint discovery and loading utilities.
  1. import glob
  2. import os
  3. import re
  4. import torch
  5. def get_last_checkpoint(work_dir, steps=None):
  6. checkpoint = None
  7. last_ckpt_path = None
  8. if work_dir.endswith(".ckpt"):
  9. ckpt_paths = [work_dir]
  10. else:
  11. ckpt_paths = get_all_ckpts(work_dir, steps)
  12. if len(ckpt_paths) > 0:
  13. last_ckpt_path = ckpt_paths[0]
  14. checkpoint = torch.load(last_ckpt_path, map_location='cpu')
  15. return checkpoint, last_ckpt_path
  16. def get_all_ckpts(work_dir, steps=None):
  17. if steps is None:
  18. ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
  19. else:
  20. ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
  21. return sorted(glob.glob(ckpt_path_pattern),
  22. key=lambda x: -int(re.findall('.*steps\_(\d+)\.ckpt', x)[0]))
def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True, steps=None, verbose=True):
    """Load weights for *model_name* from a checkpoint into *cur_model*.

    Args:
        cur_model: module to load into; anything with ``load_state_dict``,
            otherwise its ``.data`` is assigned directly (nn.Parameter case).
        ckpt_base_dir: a checkpoint file path, or a directory searched via
            ``get_last_checkpoint``.
        model_name: sub-model prefix (dotted names select nested entries).
        force: if True, fail hard when no checkpoint is found.
        strict: passed through to ``load_state_dict``; when False, keys with
            mismatched shapes are dropped first.
        steps: optional step number to pick a specific checkpoint.
        verbose: if True, print each key deleted for shape mismatch.
    """
    if os.path.isfile(ckpt_base_dir):
        # Direct file path: load it and remember its directory for messages.
        base_dir = os.path.dirname(ckpt_base_dir)
        ckpt_path = ckpt_base_dir
        checkpoint = torch.load(ckpt_base_dir, map_location='cpu')
    else:
        # Directory: pick the newest matching checkpoint.
        base_dir = ckpt_base_dir
        checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
    if checkpoint is not None:
        state_dict = checkpoint["state_dict"]
        if len([k for k in state_dict.keys() if '.' in k]) > 0:
            # Flat layout: keys look like '<model_name>.<param>'; keep only
            # this model's keys and strip the '<model_name>.' prefix.
            state_dict = {k[len(model_name) + 1:]: v for k, v in state_dict.items()
                          if k.startswith(f'{model_name}.')}
        else:
            # Nested layout: top-level keys are sub-model names.
            if '.' not in model_name:
                state_dict = state_dict[model_name]
            else:
                # Dotted model_name, e.g. 'gen.decoder': index the first
                # component, then strip the remaining prefix from its keys.
                base_model_name = model_name.split('.')[0]
                rest_model_name = model_name[len(base_model_name) + 1:]
                state_dict = {
                    k[len(rest_model_name) + 1:]: v for k, v in state_dict[base_model_name].items()
                    if k.startswith(f'{rest_model_name}.')}
        if not strict:
            # Drop checkpoint entries whose shapes don't match the current
            # model, so the non-strict load below cannot fail on them.
            cur_model_state_dict = cur_model.state_dict()
            unmatched_keys = []
            for key, param in state_dict.items():
                if key in cur_model_state_dict:
                    new_param = cur_model_state_dict[key]
                    if new_param.shape != param.shape:
                        unmatched_keys.append(key)
                        print("| Unmatched keys (shape mismatch): ", key, new_param.shape, param.shape)
                else:
                    # NOTE(review): keys absent from cur_model are only
                    # reported, not deleted — load_state_dict(strict=False)
                    # tolerates them, so this is report-only by design.
                    print(f"Skipping unmatched keys (in state_dict but not in cur_model): {key}")
            # Delete after iteration: mutating state_dict inside the loop
            # above would invalidate the iterator.
            for key in unmatched_keys:
                if verbose:
                    print(f"Del unmatched keys {key}")
                del state_dict[key]
        if hasattr(cur_model, 'load_state_dict'):
            cur_model.load_state_dict(state_dict, strict=strict)
        else:  # when cur_model is nn.Parameter
            cur_model.data = state_dict
        print(f"| load '{model_name}' from '{ckpt_path}', strict={strict}")
    else:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            # NOTE(review): assert is stripped under `python -O`; consider
            # raising FileNotFoundError instead — kept as-is for callers.
            assert False, e_msg
        else:
            print(e_msg)
  71. def restore_weights(task_ref, checkpoint):
  72. # load model state
  73. for k, v in checkpoint['state_dict'].items():
  74. if hasattr(task_ref, k):
  75. getattr(task_ref, k).load_state_dict(v, strict=True)
  76. print(f"| resotred {k} from pretrained checkpoints")
  77. else:
  78. print(f"| the checkpoint has unmatched keys {k}")
  79. def restore_opt_state(optimizers, checkpoint):
  80. # restore the optimizers
  81. optimizer_states = checkpoint['optimizer_states']
  82. for optimizer, opt_state in zip(optimizers, optimizer_states):
  83. if optimizer is None:
  84. return
  85. try:
  86. optimizer.load_state_dict(opt_state)
  87. # move optimizer to GPU 1 weight at a time
  88. # if self.on_gpu:
  89. # for state in optimizer.state.values():
  90. # for k, v in state.items():
  91. # if isinstance(v, torch.Tensor):
  92. # state[k] = v.cuda(self.root_gpu)
  93. except ValueError:
  94. print("| WARMING: optimizer parameters not match !!!")