'''Copyright The Microsoft DeepSpeed Team'''

import copy

from .constants import *
from ..runtime.config_utils import get_scalar_param


def get_compression_config(param_dict):
    # Parse every compression sub-section; sections absent from the user
    # config fall back to their defaults.
    output = {}

    if COMPRESSION_TRAINING not in param_dict.keys():
        param_dict[COMPRESSION_TRAINING] = {}
    sub_param_dict = param_dict[COMPRESSION_TRAINING]

    output[WEIGHT_QUANTIZATION] = get_weight_quantization(sub_param_dict)
    output[ACTIVATION_QUANTIZATION] = get_activation_quantization(sub_param_dict)
    output[SPARSE_PRUNING] = get_sparse_pruning(sub_param_dict)
    output[ROW_PRUNING] = get_row_pruning(sub_param_dict)
    output[HEAD_PRUNING] = get_head_pruning(sub_param_dict)
    output[CHANNEL_PRUNING] = get_channel_pruning(sub_param_dict)
    output[LAYER_REDUCTION] = get_layer_reduction(sub_param_dict)

    return output
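

# Illustrative sketch (not part of the original file): the nesting this parser
# expects, written with the module's constant names; in an actual JSON config
# their literal string values would appear instead. With weight quantization
# left disabled, no group entries are required:
#
#     ds_config = {
#         COMPRESSION_TRAINING: {
#             WEIGHT_QUANTIZATION: {
#                 SHARED_PARAMETERS: {WEIGHT_QUANTIZE_ENABLED: False},
#                 DIFFERENT_GROUPS: {},
#             },
#         },
#     }
#     compression_config = get_compression_config(ds_config)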


def get_layer_reduction(param_dict):
    output = {}
    output[LAYER_REDUCTION_ENABLED] = LAYER_REDUCTION_ENABLED_DEFAULT
    if get_layer_reduction_enabled(param_dict):
        output[LAYER_REDUCTION_ENABLED] = get_layer_reduction_enabled(param_dict)
        for key, val in get_layer_reduction_params(param_dict).items():
            output[key] = val
    return output


def get_layer_reduction_enabled(param_dict):
    if LAYER_REDUCTION in param_dict.keys():
        return get_scalar_param(param_dict[LAYER_REDUCTION], LAYER_REDUCTION_ENABLED, LAYER_REDUCTION_ENABLED_DEFAULT)
    else:
        return False


def get_layer_reduction_params(param_dict):
    if LAYER_REDUCTION in param_dict.keys():
        layer_reduction_params = copy.copy(param_dict[LAYER_REDUCTION])
        # Drop the enabled flag so only the reduction hyperparameters remain.
        layer_reduction_params.pop(LAYER_REDUCTION_ENABLED, None)
        return layer_reduction_params
    else:
        # Return an empty dict so callers can always iterate over the result.
        return {}


def get_quantize_enabled(param_dict):
    if COMPRESSION_TRAINING not in param_dict.keys():
        return False

    sub_param_dict = param_dict[COMPRESSION_TRAINING]
    output = get_weight_quantization_shared_parameters(sub_param_dict)
    return output[WEIGHT_QUANTIZE_ENABLED]


def get_weight_quantization(param_dict):
    output = {}
    if WEIGHT_QUANTIZATION not in param_dict.keys():
        param_dict[WEIGHT_QUANTIZATION] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
    sub_param_dict = param_dict[WEIGHT_QUANTIZATION]
    # shared parameters
    output[SHARED_PARAMETERS] = get_weight_quantization_shared_parameters(sub_param_dict)
    # each sub-group
    if output[SHARED_PARAMETERS][WEIGHT_QUANTIZE_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Weight Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
    output[DIFFERENT_GROUPS] = get_weight_quantization_different_groups(sub_param_dict)
    return output


def get_weight_quantization_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[WEIGHT_QUANTIZE_ENABLED] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_ENABLED, WEIGHT_QUANTIZE_ENABLED_DEFAULT)
        output[WEIGHT_QUANTIZE_KERNEL] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_KERNEL, WEIGHT_QUANTIZE_KERNEL_DEFAULT)
        output[WEIGHT_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_SCHEDULE_OFFSET, WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
        output[WEIGHT_QUANTIZE_GROUPS] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_GROUPS, WEIGHT_QUANTIZE_GROUPS_DEFAULT)
        output[WEIGHT_QUANTIZE_VERBOSE] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_VERBOSE, WEIGHT_QUANTIZE_VERBOSE_DEFAULT)
        output[WEIGHT_QUANTIZE_TYPE] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_TYPE, WEIGHT_QUANTIZE_TYPE_DEFAULT)
        output[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_IN_FORWARD_ENABLED, WEIGHT_QUANTIZE_IN_FORWARD_ENABLED_DEFAULT)
        assert output[WEIGHT_QUANTIZE_TYPE] in [WEIGHT_QUANTIZE_SYMMETRIC, WEIGHT_QUANTIZE_ASYMMETRIC], f"Invalid weight quantize type. Supported types: [{WEIGHT_QUANTIZE_SYMMETRIC}, {WEIGHT_QUANTIZE_ASYMMETRIC}]"
        output[WEIGHT_QUANTIZE_ROUNDING] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_ROUNDING, WEIGHT_QUANTIZE_ROUNDING_DEFAULT)
        assert output[WEIGHT_QUANTIZE_ROUNDING] in [WEIGHT_QUANTIZE_NEAREST_ROUNDING, WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING], f"Invalid weight quantize rounding. Supported types: [{WEIGHT_QUANTIZE_NEAREST_ROUNDING}, {WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING}]"
        # WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE is a nested sub-dict, so its
        # fields are read one level down.
        if WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE in sub_param_dict.keys():
            output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = get_scalar_param(sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED, WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT)
            output[WEIGHT_QUANTIZE_CHANGE_RATIO] = get_scalar_param(sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], WEIGHT_QUANTIZE_CHANGE_RATIO, WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT)
        else:
            output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
            output[WEIGHT_QUANTIZE_CHANGE_RATIO] = WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT
    else:
        output[WEIGHT_QUANTIZE_ENABLED] = WEIGHT_QUANTIZE_ENABLED_DEFAULT
        output[WEIGHT_QUANTIZE_KERNEL] = WEIGHT_QUANTIZE_KERNEL_DEFAULT
        output[WEIGHT_QUANTIZE_SCHEDULE_OFFSET] = WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT
        output[WEIGHT_QUANTIZE_GROUPS] = WEIGHT_QUANTIZE_GROUPS_DEFAULT
        output[WEIGHT_QUANTIZE_VERBOSE] = WEIGHT_QUANTIZE_VERBOSE_DEFAULT
        output[WEIGHT_QUANTIZE_TYPE] = WEIGHT_QUANTIZE_TYPE_DEFAULT
        output[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] = WEIGHT_QUANTIZE_IN_FORWARD_ENABLED_DEFAULT
        output[WEIGHT_QUANTIZE_ROUNDING] = WEIGHT_QUANTIZE_ROUNDING_DEFAULT
        output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
        output[WEIGHT_QUANTIZE_CHANGE_RATIO] = WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT
    return output
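

# Illustrative sketch (not part of the original file): unlike the other shared
# parameters, WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE is a nested dict, so its
# enabled flag and change ratio live one level down. The values here are
# hypothetical:
#
#     shared_parameters = {
#         WEIGHT_QUANTIZE_ENABLED: True,
#         WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE: {
#             WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED: True,
#             WEIGHT_QUANTIZE_CHANGE_RATIO: 0.001,
#         },
#     }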


def get_weight_quantization_different_groups(param_dict):
    output = {}
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert WEIGHT_QUANTIZE_START_BITS in group_dict.keys(), f"{WEIGHT_QUANTIZE_START_BITS} must be specified for weight quantization group {name}"
        assert WEIGHT_QUANTIZE_TARGET_BITS in group_dict.keys(), f"{WEIGHT_QUANTIZE_TARGET_BITS} must be specified for weight quantization group {name}"
        group_dict[WEIGHT_QUANTIZATION_PERIOD] = get_scalar_param(group_dict, WEIGHT_QUANTIZATION_PERIOD, WEIGHT_QUANTIZATION_PERIOD_DEFAULT)
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE, DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)

    return output
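

# Illustrative sketch (not part of the original file): each DIFFERENT_GROUPS
# entry pairs per-group parameters with module-scope filters. The group name
# "wq1" and the bit widths below are hypothetical:
#
#     different_groups = {
#         "wq1": {
#             DIFFERENT_GROUPS_PARAMETERS: {
#                 WEIGHT_QUANTIZE_START_BITS: 12,
#                 WEIGHT_QUANTIZE_TARGET_BITS: 8,
#             },
#         },
#     }
#
# DIFFERENT_GROUPS_MODULE_SCOPE and DIFFERENT_GROUPS_RELATED_MODULE_SCOPE fall
# back to their defaults when omitted, as the get_scalar_param calls above show.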


def get_activation_quantization(param_dict):
    output = {}
    if ACTIVATION_QUANTIZATION not in param_dict.keys():
        param_dict[ACTIVATION_QUANTIZATION] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
    sub_param_dict = param_dict[ACTIVATION_QUANTIZATION]
    # shared parameters
    output[SHARED_PARAMETERS] = get_activation_quantization_shared_parameters(sub_param_dict)
    # each sub-group
    if output[SHARED_PARAMETERS][ACTIVATION_QUANTIZATION_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Activation Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
    output[DIFFERENT_GROUPS] = get_activation_quantization_different_groups(sub_param_dict)
    return output


def get_activation_quantization_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[ACTIVATION_QUANTIZATION_ENABLED] = get_scalar_param(sub_param_dict, ACTIVATION_QUANTIZATION_ENABLED, ACTIVATION_QUANTIZATION_ENABLED_DEFAULT)
        output[ACTIVATION_QUANTIZE_TYPE] = get_scalar_param(sub_param_dict, ACTIVATION_QUANTIZE_TYPE, ACTIVATION_QUANTIZE_TYPE_DEFAULT)
        assert output[ACTIVATION_QUANTIZE_TYPE] in [ACTIVATION_QUANTIZE_SYMMETRIC, ACTIVATION_QUANTIZE_ASYMMETRIC], f"Invalid activation quantize type. Supported types: [{ACTIVATION_QUANTIZE_SYMMETRIC}, {ACTIVATION_QUANTIZE_ASYMMETRIC}]"
        output[ACTIVATION_QUANTIZE_RANGE] = get_scalar_param(sub_param_dict, ACTIVATION_QUANTIZE_RANGE, ACTIVATION_QUANTIZE_RANGE_DEFAULT)
        assert output[ACTIVATION_QUANTIZE_RANGE] in [ACTIVATION_QUANTIZE_RANGE_DYNAMIC, ACTIVATION_QUANTIZE_RANGE_STATIC], f"Invalid activation quantize range calibration. Supported types: [{ACTIVATION_QUANTIZE_RANGE_DYNAMIC}, {ACTIVATION_QUANTIZE_RANGE_STATIC}]"
        output[ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, ACTIVATION_QUANTIZE_SCHEDULE_OFFSET, ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
    else:
        output[ACTIVATION_QUANTIZATION_ENABLED] = ACTIVATION_QUANTIZATION_ENABLED_DEFAULT
        output[ACTIVATION_QUANTIZE_TYPE] = ACTIVATION_QUANTIZE_TYPE_DEFAULT
        output[ACTIVATION_QUANTIZE_RANGE] = ACTIVATION_QUANTIZE_RANGE_DEFAULT
        output[ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT
    return output


def get_activation_quantization_different_groups(param_dict):
    output = {}
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert ACTIVATION_QUANTIZE_BITS in group_dict.keys(), f"{ACTIVATION_QUANTIZE_BITS} must be specified for activation quantization group {name}"
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE, DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)

    return output


def get_sparse_pruning(param_dict):
    output = {}
    if SPARSE_PRUNING not in param_dict.keys():
        param_dict[SPARSE_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
    sub_param_dict = param_dict[SPARSE_PRUNING]
    # shared parameters
    output[SHARED_PARAMETERS] = get_sparse_pruning_shared_parameters(sub_param_dict)
    # each sub-group
    if output[SHARED_PARAMETERS][SPARSE_PRUNING_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
    output[DIFFERENT_GROUPS] = get_sparse_pruning_different_groups(sub_param_dict)
    return output


def get_sparse_pruning_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[SPARSE_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_ENABLED, SPARSE_PRUNING_ENABLED_DEFAULT)
        output[SPARSE_PRUNING_METHOD] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_METHOD, SPARSE_PRUNING_METHOD_DEFAULT)
        assert output[SPARSE_PRUNING_METHOD] in [SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}]"
        output[SPARSE_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_SCHEDULE_OFFSET, SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT)
    else:
        output[SPARSE_PRUNING_ENABLED] = SPARSE_PRUNING_ENABLED_DEFAULT
        output[SPARSE_PRUNING_METHOD] = SPARSE_PRUNING_METHOD_DEFAULT
        output[SPARSE_PRUNING_SCHEDULE_OFFSET] = SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT
    return output


def get_sparse_pruning_different_groups(param_dict):
    output = {}
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert SPARSE_PRUNING_DENSE_RATIO in group_dict.keys(), f"{SPARSE_PRUNING_DENSE_RATIO} must be specified for sparse pruning group {name}"
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE, DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)

    return output


def get_row_pruning(param_dict):
    output = {}
    if ROW_PRUNING not in param_dict.keys():
        param_dict[ROW_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
    sub_param_dict = param_dict[ROW_PRUNING]
    # shared parameters
    output[SHARED_PARAMETERS] = get_row_pruning_shared_parameters(sub_param_dict)
    # each sub-group
    if output[SHARED_PARAMETERS][ROW_PRUNING_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Row Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
    output[DIFFERENT_GROUPS] = get_row_pruning_different_groups(sub_param_dict)
    return output


def get_row_pruning_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[ROW_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, ROW_PRUNING_ENABLED, ROW_PRUNING_ENABLED_DEFAULT)
        output[ROW_PRUNING_METHOD] = get_scalar_param(sub_param_dict, ROW_PRUNING_METHOD, ROW_PRUNING_METHOD_DEFAULT)
        assert output[ROW_PRUNING_METHOD] in [ROW_PRUNING_METHOD_L1, ROW_PRUNING_METHOD_TOPK], f"Invalid row pruning method. Supported types: [{ROW_PRUNING_METHOD_L1}, {ROW_PRUNING_METHOD_TOPK}]"
        output[ROW_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, ROW_PRUNING_SCHEDULE_OFFSET, ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT)
    else:
        output[ROW_PRUNING_ENABLED] = ROW_PRUNING_ENABLED_DEFAULT
        output[ROW_PRUNING_METHOD] = ROW_PRUNING_METHOD_DEFAULT
        output[ROW_PRUNING_SCHEDULE_OFFSET] = ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT
    return output


def get_row_pruning_different_groups(param_dict):
    output = {}
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert ROW_PRUNING_DENSE_RATIO in group_dict.keys(), f"{ROW_PRUNING_DENSE_RATIO} must be specified for row pruning group {name}"
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE, DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)

    return output


def get_head_pruning(param_dict):
    output = {}
    if HEAD_PRUNING not in param_dict.keys():
        param_dict[HEAD_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
    sub_param_dict = param_dict[HEAD_PRUNING]
    # shared parameters
    output[SHARED_PARAMETERS] = get_head_pruning_shared_parameters(sub_param_dict)
    # each sub-group
    if output[SHARED_PARAMETERS][HEAD_PRUNING_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Head Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
    output[DIFFERENT_GROUPS] = get_head_pruning_different_groups(sub_param_dict)
    return output


def get_head_pruning_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[HEAD_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, HEAD_PRUNING_ENABLED, HEAD_PRUNING_ENABLED_DEFAULT)
        output[HEAD_PRUNING_METHOD] = get_scalar_param(sub_param_dict, HEAD_PRUNING_METHOD, HEAD_PRUNING_METHOD_DEFAULT)
        assert output[HEAD_PRUNING_METHOD] in [HEAD_PRUNING_METHOD_L1, HEAD_PRUNING_METHOD_TOPK], f"Invalid head pruning method. Supported types: [{HEAD_PRUNING_METHOD_L1}, {HEAD_PRUNING_METHOD_TOPK}]"
        output[HEAD_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, HEAD_PRUNING_SCHEDULE_OFFSET, HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT)
        if output[HEAD_PRUNING_ENABLED]:
            assert HEAD_PRUNING_NUM_HEADS in sub_param_dict.keys(), f"{HEAD_PRUNING_NUM_HEADS} must be specified for head pruning"
            output[HEAD_PRUNING_NUM_HEADS] = sub_param_dict[HEAD_PRUNING_NUM_HEADS]
    else:
        output[HEAD_PRUNING_ENABLED] = HEAD_PRUNING_ENABLED_DEFAULT
        output[HEAD_PRUNING_METHOD] = HEAD_PRUNING_METHOD_DEFAULT
        output[HEAD_PRUNING_SCHEDULE_OFFSET] = HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT
    return output
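

# Illustrative sketch (not part of the original file): head pruning is the one
# section whose shared parameters must also name the attention head count when
# enabled; the count below is hypothetical:
#
#     shared_parameters = {
#         HEAD_PRUNING_ENABLED: True,
#         HEAD_PRUNING_NUM_HEADS: 12,
#     }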


def get_head_pruning_different_groups(param_dict):
    output = {}
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert HEAD_PRUNING_DENSE_RATIO in group_dict.keys(), f"{HEAD_PRUNING_DENSE_RATIO} must be specified for head pruning group {name}"
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE, DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)

    return output


def get_channel_pruning(param_dict):
    output = {}
    if CHANNEL_PRUNING not in param_dict.keys():
        param_dict[CHANNEL_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
    sub_param_dict = param_dict[CHANNEL_PRUNING]
    # shared parameters
    output[SHARED_PARAMETERS] = get_channel_pruning_shared_parameters(sub_param_dict)
    # each sub-group
    if output[SHARED_PARAMETERS][CHANNEL_PRUNING_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Channel Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
    output[DIFFERENT_GROUPS] = get_channel_pruning_different_groups(sub_param_dict)
    return output


def get_channel_pruning_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[CHANNEL_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, CHANNEL_PRUNING_ENABLED, CHANNEL_PRUNING_ENABLED_DEFAULT)
        output[CHANNEL_PRUNING_METHOD] = get_scalar_param(sub_param_dict, CHANNEL_PRUNING_METHOD, CHANNEL_PRUNING_METHOD_DEFAULT)
        assert output[CHANNEL_PRUNING_METHOD] in [CHANNEL_PRUNING_METHOD_L1, CHANNEL_PRUNING_METHOD_TOPK], f"Invalid channel pruning method. Supported types: [{CHANNEL_PRUNING_METHOD_L1}, {CHANNEL_PRUNING_METHOD_TOPK}]"
        output[CHANNEL_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, CHANNEL_PRUNING_SCHEDULE_OFFSET, CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT)
    else:
        output[CHANNEL_PRUNING_ENABLED] = CHANNEL_PRUNING_ENABLED_DEFAULT
        output[CHANNEL_PRUNING_METHOD] = CHANNEL_PRUNING_METHOD_DEFAULT
        output[CHANNEL_PRUNING_SCHEDULE_OFFSET] = CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT
    return output


def get_channel_pruning_different_groups(param_dict):
    output = {}
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert CHANNEL_PRUNING_DENSE_RATIO in group_dict.keys(), f"{CHANNEL_PRUNING_DENSE_RATIO} must be specified for channel pruning group {name}"
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE, DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)

    return output
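

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not part of the original file). This
    # module uses relative imports, so run it through its package (e.g.
    # `python -m <package>.config`) rather than as a standalone script. An
    # empty DeepSpeed config should resolve every section to its defaults.
    print(get_compression_config({}))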