# reshape_utils.py
  1. '''Copyright The Microsoft DeepSpeed Team'''
  2. import os
  3. import torch
  4. from collections import OrderedDict
  5. from .constants import (ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX)
  6. def basic_folder_validation(dir):
  7. assert os.path.exists(dir), f'{dir} path does not exist'
  8. assert os.path.isdir(dir), f'{dir} is not a folder'
  9. def get_files_with_prefix(all_files, prefix):
  10. file_list = []
  11. for file_path in all_files:
  12. _, fname = os.path.split(file_path)
  13. if fname.startswith(prefix):
  14. file_list.append(file_path)
  15. return sorted(file_list)
  16. def validate_files(file_list):
  17. for file in file_list:
  18. if not os.path.isfile(file):
  19. print(f'Error: {file} is not existent')
  20. def get_files(dir):
  21. file_list = []
  22. for root, _, files in os.walk(dir):
  23. for file in files:
  24. file_list.append(os.path.join(root, file))
  25. return file_list
  26. def get_zero_files(dir):
  27. file_list = get_files(dir)
  28. for prefix in [ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX]:
  29. zero_files = get_files_with_prefix(file_list, prefix)
  30. if len(zero_files) > 0:
  31. return zero_files
  32. return []
  33. def partition_data(data_list, num_partitions):
  34. num_elems = len(data_list)
  35. assert num_elems % num_partitions == 0
  36. partition_size = num_elems // num_partitions
  37. partitions_list = [
  38. data_list[i:i + partition_size] for i in range(0,
  39. num_elems,
  40. partition_size)
  41. ]
  42. return partitions_list
  43. def _key_list_to_string(key_list):
  44. return '.'.join(key_list)
  45. def merge_state_dict(dict_a, dict_b, key_list):
  46. merged_dict = type(dict_a)({})
  47. for key, value in dict_b.items():
  48. if key in dict_a.keys():
  49. merged_dict[key] = merge_state(dict_a[key], dict_b[key], [str(key)])
  50. else:
  51. merged_dict[key] = value
  52. return merged_dict
  53. def merge_state_list(list_a, list_b, key_list):
  54. if len(list_a) != len(list_b):
  55. print(f'{_key_list_to_string(key_list)}')
  56. raise ValueError(
  57. f'Cannot merge lists of different lengths, a = {len(list_a)} b = {len(list_b)}'
  58. )
  59. return [merge_state(a, b, key_list) for a, b in zip(list_a, list_b)]
  60. def merge_state(state_a, state_b, key_list=[]):
  61. if type(state_a) != type(state_b):
  62. key_list_string = _key_list_to_string(key_list)
  63. print(f'key_list = {key_list_string}')
  64. raise ValueError(
  65. f'Cannot merge two states of types {type(state_a)} and type {type(state_b)}')
  66. if type(state_a) in (dict, OrderedDict):
  67. return merge_state_dict(state_a, state_b, key_list)
  68. elif type(state_a) in (list, tuple):
  69. return type(state_a)(merge_state_list(state_a, state_b, key_list))
  70. elif torch.is_tensor(state_a):
  71. return torch.cat([state_a, state_b], 0)
  72. else:
  73. return state_a