config.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .constants import *
import copy
from ..runtime.config_utils import get_scalar_param, get_list_param


def get_compression_config(param_dict):
    # Parse the COMPRESSION_TRAINING section of the DeepSpeed config and return a
    # normalized dict with one entry per supported compression technique.
    output = {}
    if COMPRESSION_TRAINING not in param_dict.keys():
        param_dict[COMPRESSION_TRAINING] = {}
    sub_param_dict = param_dict[COMPRESSION_TRAINING]
    output[WEIGHT_QUANTIZATION] = get_weight_quantization(sub_param_dict)
    output[ACTIVATION_QUANTIZATION] = get_activation_quantization(sub_param_dict)
    output[SPARSE_PRUNING] = get_sparse_pruning(sub_param_dict)
    output[ROW_PRUNING] = get_row_pruning(sub_param_dict)
    output[HEAD_PRUNING] = get_head_pruning(sub_param_dict)
    output[CHANNEL_PRUNING] = get_channel_pruning(sub_param_dict)
    output[LAYER_REDUCTION] = get_layer_reduction(sub_param_dict)

    return output


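# Illustrative layout sketch: every technique section parsed below (except LAYER_REDUCTION,
# which is a flat dict of options) follows the same two-level structure. The constant names
# stand for the literal JSON keys defined in .constants, and "group_1" is a hypothetical
# group name:
#
#   {
#       SHARED_PARAMETERS: {...technique-wide options...},
#       DIFFERENT_GROUPS: {
#           "group_1": {
#               DIFFERENT_GROUPS_PARAMETERS: {...per-group options...},
#               DIFFERENT_GROUPS_MODULE_SCOPE: ...,          # defaults to DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT
#               DIFFERENT_GROUPS_RELATED_MODULE_SCOPE: ...,  # defaults to DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT
#           },
#       },
#   }

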
def get_layer_reduction(param_dict):
    output = {}
    output[LAYER_REDUCTION_ENABLED] = LAYER_REDUCTION_ENABLED_DEFAULT
    if get_layer_reduction_enabled(param_dict):
        output[LAYER_REDUCTION_ENABLED] = get_layer_reduction_enabled(param_dict)
        for key, val in get_layer_reduction_params(param_dict).items():
            output[key] = val
    return output


def get_layer_reduction_enabled(param_dict):
    if LAYER_REDUCTION in param_dict.keys():
        return get_scalar_param(param_dict[LAYER_REDUCTION], LAYER_REDUCTION_ENABLED, LAYER_REDUCTION_ENABLED_DEFAULT)
    else:
        return False


def get_layer_reduction_params(param_dict):
    if LAYER_REDUCTION in param_dict.keys():
        layer_reduction_params = copy.copy(param_dict[LAYER_REDUCTION])
        layer_reduction_params.pop(LAYER_REDUCTION_ENABLED)
        return layer_reduction_params
    else:
        return False


def get_quantize_enabled(param_dict):
    if COMPRESSION_TRAINING not in param_dict.keys():
        return False

    sub_param_dict = param_dict[COMPRESSION_TRAINING]
    output = get_weight_quantization_shared_parameters(sub_param_dict)
    return output[WEIGHT_QUANTIZE_ENABLED]


def get_weight_quantization(param_dict):
    output = {}
    if WEIGHT_QUANTIZATION not in param_dict.keys():
        param_dict[WEIGHT_QUANTIZATION] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
    sub_param_dict = param_dict[WEIGHT_QUANTIZATION]
    # shared parameters
    output[SHARED_PARAMETERS] = get_weight_quantization_shared_parameters(sub_param_dict)
    # each sub-group
    if output[SHARED_PARAMETERS][WEIGHT_QUANTIZE_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(), \
            f"Weight Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
        output[DIFFERENT_GROUPS] = get_weight_quantization_different_groups(sub_param_dict)
    return output


def get_weight_quantization_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[WEIGHT_QUANTIZE_ENABLED] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_ENABLED,
                                                           WEIGHT_QUANTIZE_ENABLED_DEFAULT)
        output[WEIGHT_QUANTIZE_KERNEL] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_KERNEL,
                                                          WEIGHT_QUANTIZE_KERNEL_DEFAULT)
        output[WEIGHT_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_SCHEDULE_OFFSET,
                                                                   WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
        output[WEIGHT_QUANTIZE_GROUPS] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_GROUPS,
                                                          WEIGHT_QUANTIZE_GROUPS_DEFAULT)
        output[WEIGHT_QUANTIZE_VERBOSE] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_VERBOSE,
                                                           WEIGHT_QUANTIZE_VERBOSE_DEFAULT)
        output[WEIGHT_QUANTIZE_TYPE] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_TYPE,
                                                        WEIGHT_QUANTIZE_TYPE_DEFAULT)
        output[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] = get_scalar_param(sub_param_dict,
                                                                      WEIGHT_QUANTIZE_IN_FORWARD_ENABLED,
                                                                      WEIGHT_QUANTIZE_IN_FORWARD_ENABLED_DEFAULT)
        assert output[WEIGHT_QUANTIZE_TYPE] in [
            WEIGHT_QUANTIZE_SYMMETRIC, WEIGHT_QUANTIZE_ASYMMETRIC
        ], f"Invalid weight quantize type. Supported types: [{WEIGHT_QUANTIZE_SYMMETRIC}, {WEIGHT_QUANTIZE_ASYMMETRIC}]"
        output[WEIGHT_QUANTIZE_ROUNDING] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_ROUNDING,
                                                            WEIGHT_QUANTIZE_ROUNDING_DEFAULT)
        assert output[WEIGHT_QUANTIZE_ROUNDING] in [
            WEIGHT_QUANTIZE_NEAREST_ROUNDING, WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING
        ], f"Invalid weight quantize rounding. Supported types: [{WEIGHT_QUANTIZE_NEAREST_ROUNDING}, {WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING}]"
        if WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE in sub_param_dict.keys():
            output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = get_scalar_param(
                sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED,
                WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT)
            output[WEIGHT_QUANTIZE_CHANGE_RATIO] = get_scalar_param(
                sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], WEIGHT_QUANTIZE_CHANGE_RATIO,
                WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT)
        else:
            output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
            output[WEIGHT_QUANTIZE_CHANGE_RATIO] = WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT
    else:
        output[WEIGHT_QUANTIZE_ENABLED] = WEIGHT_QUANTIZE_ENABLED_DEFAULT
        output[WEIGHT_QUANTIZE_KERNEL] = WEIGHT_QUANTIZE_KERNEL_DEFAULT
        output[WEIGHT_QUANTIZE_SCHEDULE_OFFSET] = WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT
        output[WEIGHT_QUANTIZE_GROUPS] = WEIGHT_QUANTIZE_GROUPS_DEFAULT
        output[WEIGHT_QUANTIZE_VERBOSE] = WEIGHT_QUANTIZE_VERBOSE_DEFAULT
        output[WEIGHT_QUANTIZE_TYPE] = WEIGHT_QUANTIZE_TYPE_DEFAULT
        output[WEIGHT_QUANTIZE_ROUNDING] = WEIGHT_QUANTIZE_ROUNDING_DEFAULT
        output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
        output[WEIGHT_QUANTIZE_CHANGE_RATIO] = WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT
    return output


def get_weight_quantization_different_groups(param_dict):
    output = {}
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert WEIGHT_QUANTIZE_START_BITS in group_dict.keys(), \
            f"{WEIGHT_QUANTIZE_START_BITS} must be specified for weight quantization group {name}"
        assert WEIGHT_QUANTIZE_TARGET_BITS in group_dict.keys(), \
            f"{WEIGHT_QUANTIZE_TARGET_BITS} must be specified for weight quantization group {name}"
        group_dict[WEIGHT_QUANTIZATION_PERIOD] = get_scalar_param(group_dict, WEIGHT_QUANTIZATION_PERIOD,
                                                                  WEIGHT_QUANTIZATION_PERIOD_DEFAULT)
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
                                                                    DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
            sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
    return output


def get_activation_quantization(param_dict):
    output = {}
    if ACTIVATION_QUANTIZATION not in param_dict.keys():
        param_dict[ACTIVATION_QUANTIZATION] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
    sub_param_dict = param_dict[ACTIVATION_QUANTIZATION]
    # shared parameters
    output[SHARED_PARAMETERS] = get_activation_quantization_shared_parameters(sub_param_dict)
    # each sub-group
    if output[SHARED_PARAMETERS][ACTIVATION_QUANTIZATION_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(), \
            f"Activation Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
        output[DIFFERENT_GROUPS] = get_activation_quantization_different_groups(sub_param_dict)
    return output


def get_activation_quantization_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[ACTIVATION_QUANTIZATION_ENABLED] = get_scalar_param(sub_param_dict, ACTIVATION_QUANTIZATION_ENABLED,
                                                                   ACTIVATION_QUANTIZATION_ENABLED_DEFAULT)
        output[ACTIVATION_QUANTIZE_TYPE] = get_scalar_param(sub_param_dict, ACTIVATION_QUANTIZE_TYPE,
                                                            ACTIVATION_QUANTIZE_TYPE_DEFAULT)
        assert output[ACTIVATION_QUANTIZE_TYPE] in [
            ACTIVATION_QUANTIZE_SYMMETRIC, ACTIVATION_QUANTIZE_ASYMMETRIC
        ], f"Invalid activation quantize type. Supported types: [{ACTIVATION_QUANTIZE_SYMMETRIC}, {ACTIVATION_QUANTIZE_ASYMMETRIC}]"
        output[ACTIVATION_QUANTIZE_RANGE] = get_scalar_param(sub_param_dict, ACTIVATION_QUANTIZE_RANGE,
                                                             ACTIVATION_QUANTIZE_RANGE_DEFAULT)
        assert output[ACTIVATION_QUANTIZE_RANGE] in [
            ACTIVATION_QUANTIZE_RANGE_DYNAMIC, ACTIVATION_QUANTIZE_RANGE_STATIC
        ], f"Invalid activation quantize range calibration. Supported types: [{ACTIVATION_QUANTIZE_RANGE_DYNAMIC}, {ACTIVATION_QUANTIZE_RANGE_STATIC}]"
        output[ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict,
                                                                       ACTIVATION_QUANTIZE_SCHEDULE_OFFSET,
                                                                       ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
    else:
        output[ACTIVATION_QUANTIZATION_ENABLED] = ACTIVATION_QUANTIZATION_ENABLED_DEFAULT
        output[ACTIVATION_QUANTIZE_TYPE] = ACTIVATION_QUANTIZE_TYPE_DEFAULT
        output[ACTIVATION_QUANTIZE_RANGE] = ACTIVATION_QUANTIZE_RANGE_DEFAULT
        output[ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT
    return output


def get_activation_quantization_different_groups(param_dict):
    output = {}
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert ACTIVATION_QUANTIZE_BITS in group_dict.keys(), \
            f"{ACTIVATION_QUANTIZE_BITS} must be specified for activation quantization group {name}"
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
                                                                    DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
            sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
    return output


def get_sparse_pruning(param_dict):
    output = {}
    if SPARSE_PRUNING not in param_dict.keys():
        param_dict[SPARSE_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
    sub_param_dict = param_dict[SPARSE_PRUNING]
    # shared parameters
    output[SHARED_PARAMETERS] = get_sparse_pruning_shared_parameters(sub_param_dict)
    # each sub-group
    if output[SHARED_PARAMETERS][SPARSE_PRUNING_ENABLED] and output[SHARED_PARAMETERS][
            SPARSE_PRUNING_METHOD] != SPARSE_PRUNING_METHOD_SNIP_MOMENTUM:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(), \
            f"Sparse Pruning is enabled with a method other than snip_momentum, {DIFFERENT_GROUPS} must be specified"
        output[DIFFERENT_GROUPS] = get_sparse_pruning_different_groups(sub_param_dict)
    return output


def get_sparse_pruning_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[SPARSE_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_ENABLED,
                                                          SPARSE_PRUNING_ENABLED_DEFAULT)
        output[SPARSE_PRUNING_METHOD] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_METHOD,
                                                         SPARSE_PRUNING_METHOD_DEFAULT)
        assert output[SPARSE_PRUNING_METHOD] in [
            SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK, SPARSE_PRUNING_METHOD_SNIP_MOMENTUM
        ], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}, {SPARSE_PRUNING_METHOD_SNIP_MOMENTUM}]"
        output[SPARSE_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_SCHEDULE_OFFSET,
                                                                  SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT)
        if output[SPARSE_PRUNING_METHOD] == SPARSE_PRUNING_METHOD_SNIP_MOMENTUM:
            output[SPARSE_PRUNING_BLOCK_PATTERN] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_BLOCK_PATTERN,
                                                                    SPARSE_PRUNING_BLOCK_PATTERN_DEFAULT)
            output[SPARSE_PRUNING_DENSE_RATIO] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_DENSE_RATIO,
                                                                  SPARSE_PRUNING_DENSE_RATIO_DEFAULT)
            assert 0 < output[SPARSE_PRUNING_DENSE_RATIO] < 1, \
                f"Invalid dense_ratio value {output[SPARSE_PRUNING_DENSE_RATIO]}. Must be between 0 and 1 (exclusive)"
            output[SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE] = get_scalar_param(
                sub_param_dict, SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE, SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE_DEFAULT)
            output[SPARSE_PRUNING_EXCLUDED_MODULES] = get_list_param(sub_param_dict, SPARSE_PRUNING_EXCLUDED_MODULES,
                                                                     SPARSE_PRUNING_EXCLUDED_MODULES_DEFAULT)
            output[SPARSE_PRUNING_SCHEDULE_OFFSET_END] = get_scalar_param(sub_param_dict,
                                                                          SPARSE_PRUNING_SCHEDULE_OFFSET_END,
                                                                          output[SPARSE_PRUNING_SCHEDULE_OFFSET])
            assert output[SPARSE_PRUNING_SCHEDULE_OFFSET] <= output[SPARSE_PRUNING_SCHEDULE_OFFSET_END], \
                "Invalid schedule offsets: schedule_offset must not be greater than schedule_offset_end"
    else:
        output[SPARSE_PRUNING_ENABLED] = SPARSE_PRUNING_ENABLED_DEFAULT
        output[SPARSE_PRUNING_METHOD] = SPARSE_PRUNING_METHOD_DEFAULT
        output[SPARSE_PRUNING_SCHEDULE_OFFSET] = SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT
    return output


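# Illustrative sketch of the snip_momentum-specific shared parameters read above. The
# concrete values are hypothetical and the constant names stand for the literal JSON keys
# defined in .constants:
#
#   {
#       SPARSE_PRUNING_ENABLED: True,
#       SPARSE_PRUNING_METHOD: SPARSE_PRUNING_METHOD_SNIP_MOMENTUM,
#       SPARSE_PRUNING_BLOCK_PATTERN: SPARSE_PRUNING_BLOCK_PATTERN_DEFAULT,
#       SPARSE_PRUNING_DENSE_RATIO: 0.1,              # must be strictly between 0 and 1
#       SPARSE_PRUNING_SCHEDULE_OFFSET: 1000,
#       SPARSE_PRUNING_SCHEDULE_OFFSET_END: 10000,    # defaults to schedule_offset when omitted
#       SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE: SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE_DEFAULT,
#       SPARSE_PRUNING_EXCLUDED_MODULES: [],
#   }
#
# With this method, get_sparse_pruning() does not require a DIFFERENT_GROUPS section.

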
def get_sparse_pruning_different_groups(param_dict):
    output = {}
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert SPARSE_PRUNING_DENSE_RATIO in group_dict.keys(), \
            f"{SPARSE_PRUNING_DENSE_RATIO} must be specified for sparse pruning group {name}"
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
                                                                    DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
            sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
    return output


def get_row_pruning(param_dict):
    output = {}
    if ROW_PRUNING not in param_dict.keys():
        param_dict[ROW_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
    sub_param_dict = param_dict[ROW_PRUNING]
    # shared parameters
    output[SHARED_PARAMETERS] = get_row_pruning_shared_parameters(sub_param_dict)
    # each sub-group
    if output[SHARED_PARAMETERS][ROW_PRUNING_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(), \
            f"Row Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
        output[DIFFERENT_GROUPS] = get_row_pruning_different_groups(sub_param_dict)
    return output


def get_row_pruning_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[ROW_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, ROW_PRUNING_ENABLED,
                                                       ROW_PRUNING_ENABLED_DEFAULT)
        output[ROW_PRUNING_METHOD] = get_scalar_param(sub_param_dict, ROW_PRUNING_METHOD, ROW_PRUNING_METHOD_DEFAULT)
        assert output[ROW_PRUNING_METHOD] in [
            ROW_PRUNING_METHOD_L1, ROW_PRUNING_METHOD_TOPK
        ], f"Invalid row pruning method. Supported types: [{ROW_PRUNING_METHOD_L1}, {ROW_PRUNING_METHOD_TOPK}]"
        output[ROW_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, ROW_PRUNING_SCHEDULE_OFFSET,
                                                               ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT)
    else:
        output[ROW_PRUNING_ENABLED] = ROW_PRUNING_ENABLED_DEFAULT
        output[ROW_PRUNING_METHOD] = ROW_PRUNING_METHOD_DEFAULT
        output[ROW_PRUNING_SCHEDULE_OFFSET] = ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT
    return output


def get_row_pruning_different_groups(param_dict):
    output = {}
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert ROW_PRUNING_DENSE_RATIO in group_dict.keys(), \
            f"{ROW_PRUNING_DENSE_RATIO} must be specified for row pruning group {name}"
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
                                                                    DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
            sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
    return output


def get_head_pruning(param_dict):
    output = {}
    if HEAD_PRUNING not in param_dict.keys():
        param_dict[HEAD_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
    sub_param_dict = param_dict[HEAD_PRUNING]
    # shared parameters
    output[SHARED_PARAMETERS] = get_head_pruning_shared_parameters(sub_param_dict)
    # each sub-group
    if output[SHARED_PARAMETERS][HEAD_PRUNING_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(), \
            f"Head Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
        output[DIFFERENT_GROUPS] = get_head_pruning_different_groups(sub_param_dict)
    return output


def get_head_pruning_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[HEAD_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, HEAD_PRUNING_ENABLED,
                                                        HEAD_PRUNING_ENABLED_DEFAULT)
        output[HEAD_PRUNING_METHOD] = get_scalar_param(sub_param_dict, HEAD_PRUNING_METHOD,
                                                       HEAD_PRUNING_METHOD_DEFAULT)
        assert output[HEAD_PRUNING_METHOD] in [
            HEAD_PRUNING_METHOD_L1, HEAD_PRUNING_METHOD_TOPK
        ], f"Invalid head pruning method. Supported types: [{HEAD_PRUNING_METHOD_L1}, {HEAD_PRUNING_METHOD_TOPK}]"
        output[HEAD_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, HEAD_PRUNING_SCHEDULE_OFFSET,
                                                                HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT)
        if output[HEAD_PRUNING_ENABLED]:
            assert HEAD_PRUNING_NUM_HEADS in sub_param_dict.keys(), \
                f"{HEAD_PRUNING_NUM_HEADS} must be specified for head pruning"
            output[HEAD_PRUNING_NUM_HEADS] = sub_param_dict[HEAD_PRUNING_NUM_HEADS]
    else:
        output[HEAD_PRUNING_ENABLED] = HEAD_PRUNING_ENABLED_DEFAULT
        output[HEAD_PRUNING_METHOD] = HEAD_PRUNING_METHOD_DEFAULT
        output[HEAD_PRUNING_SCHEDULE_OFFSET] = HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT
    return output


def get_head_pruning_different_groups(param_dict):
    output = {}
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert HEAD_PRUNING_DENSE_RATIO in group_dict.keys(), \
            f"{HEAD_PRUNING_DENSE_RATIO} must be specified for head pruning group {name}"
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
                                                                    DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
            sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
    return output


def get_channel_pruning(param_dict):
    output = {}
    if CHANNEL_PRUNING not in param_dict.keys():
        param_dict[CHANNEL_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
    sub_param_dict = param_dict[CHANNEL_PRUNING]
    # shared parameters
    output[SHARED_PARAMETERS] = get_channel_pruning_shared_parameters(sub_param_dict)
    # each sub-group
    if output[SHARED_PARAMETERS][CHANNEL_PRUNING_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(), \
            f"Channel Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
        output[DIFFERENT_GROUPS] = get_channel_pruning_different_groups(sub_param_dict)
    return output


def get_channel_pruning_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[CHANNEL_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, CHANNEL_PRUNING_ENABLED,
                                                           CHANNEL_PRUNING_ENABLED_DEFAULT)
        output[CHANNEL_PRUNING_METHOD] = get_scalar_param(sub_param_dict, CHANNEL_PRUNING_METHOD,
                                                          CHANNEL_PRUNING_METHOD_DEFAULT)
        assert output[CHANNEL_PRUNING_METHOD] in [
            CHANNEL_PRUNING_METHOD_L1, CHANNEL_PRUNING_METHOD_TOPK
        ], f"Invalid channel pruning method. Supported types: [{CHANNEL_PRUNING_METHOD_L1}, {CHANNEL_PRUNING_METHOD_TOPK}]"
        output[CHANNEL_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, CHANNEL_PRUNING_SCHEDULE_OFFSET,
                                                                   CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT)
    else:
        output[CHANNEL_PRUNING_ENABLED] = CHANNEL_PRUNING_ENABLED_DEFAULT
        output[CHANNEL_PRUNING_METHOD] = CHANNEL_PRUNING_METHOD_DEFAULT
        output[CHANNEL_PRUNING_SCHEDULE_OFFSET] = CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT
    return output


def get_channel_pruning_different_groups(param_dict):
    output = {}
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert CHANNEL_PRUNING_DENSE_RATIO in group_dict.keys(), \
            f"{CHANNEL_PRUNING_DENSE_RATIO} must be specified for channel pruning group {name}"
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
                                                                    DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
            sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
    return output
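

# The sketch below is illustrative only: a hypothetical helper (not referenced elsewhere)
# showing how a caller might build a minimal weight-quantization section and parse it with
# get_compression_config(). The group name "wq1", the bit widths, and the module scope value
# are made-up examples; the constant names stand for the literal JSON keys defined in
# .constants, and every technique left out falls back to its *_DEFAULT values.
def _example_compression_config():
    ds_config = {
        COMPRESSION_TRAINING: {
            WEIGHT_QUANTIZATION: {
                SHARED_PARAMETERS: {
                    WEIGHT_QUANTIZE_ENABLED: True
                },
                DIFFERENT_GROUPS: {
                    "wq1": {
                        DIFFERENT_GROUPS_PARAMETERS: {
                            WEIGHT_QUANTIZE_START_BITS: 12,
                            WEIGHT_QUANTIZE_TARGET_BITS: 8,
                        },
                        DIFFERENT_GROUPS_MODULE_SCOPE: ["attention"],
                    },
                },
            },
        },
    }
    # Returns the normalized compression dict with defaults filled in for all other methods.
    return get_compression_config(ds_config)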