123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- from .compress import get_module_name
- from .constants import *
- from .helper import recursive_getattr
- from deepspeed.utils import logger
class compression_scheduler:
    '''
    Used to schedule different compression methods.

    Counts training steps and, on every ``step()``, sets the per-module
    enable flags (``weight_quantization_enabled``,
    ``activation_quantization_enabled``, ``sparse_pruning_enabled``,
    ``head_pruning_enabled``, ``row_pruning_enabled``,
    ``channel_pruning_enabled``) to True for each technique whose schedule
    offset has been reached.
    '''

    def __init__(self, model, compression_config):
        """
        Args:
            model: model whose submodules are compressed; target submodules
                are resolved from it with ``recursive_getattr``.
            compression_config: mapping from compression-method name to that
                method's config. Each config has ``SHARED_PARAMETERS`` and
                ``DIFFERENT_GROUPS`` entries. Layer-reduction entries are
                ignored.
        """
        self.model = model
        self.compression_config = compression_config
        self.make_init()
        self.training_steps = 0
        self.weight_quantization_enabled = False
        # One-shot flags: each "... is enabled at step N" message is logged once.
        self.verbose = {
            WEIGHT_QUANTIZATION: False,
            ACTIVATION_QUANTIZATION: False,
            SPARSE_PRUNING: False,
            HEAD_PRUNING: False,
            ROW_PRUNING: False,
            CHANNEL_PRUNING: False,
        }

    def make_init(self):
        """Build ``self.different_compression_methods`` from the config.

        For every compression method except layer reduction, this records:
          * whether the method is enabled;
          * its shared parameters;
          * a list of ``[group_name, module_name_list, params]`` entries, one
            per group whose module scope matched at least one module.
        """
        self.different_compression_methods = {}
        for method, method_content in self.compression_config.items():
            if LAYER_REDUCTION in method:
                continue

            shared_parameters = method_content[SHARED_PARAMETERS]
            groups = []
            # Threaded through get_module_name across all groups of this
            # method. Presumably this stops a module from being claimed by two
            # groups — confirm in get_module_name.
            exist_module_name = set()
            for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items():
                module_name_list = []
                for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]:
                    module_name, exist_module_name = get_module_name(group_name,
                                                                     self.model,
                                                                     key_word,
                                                                     exist_module_name,
                                                                     verbose=False)
                    module_name_list.extend(module_name)
                # Groups that matched no module are dropped entirely.
                if module_name_list:
                    groups.append([group_name, module_name_list, method_parameters['params']])

            self.different_compression_methods[method] = {
                TECHNIQUE_ENABLED: shared_parameters[TECHNIQUE_ENABLED],
                SHARED_PARAMETERS: shared_parameters,
                DIFFERENT_GROUPS: groups,
            }

    def _enable_technique(self, method, flag_name, display_name, use_offset_end=False):
        """Set ``flag_name`` to True on every module scheduled under ``method``.

        Args:
            method: key into ``self.different_compression_methods`` and
                ``self.verbose``.
            flag_name: boolean attribute set to True on each target module.
            display_name: human-readable technique name used in the log line.
            use_offset_end: if True, modules are only touched while
                ``TECHNIQUE_SCHEDULE_OFFSET <= training_steps <=
                TECHNIQUE_SCHEDULE_OFFSET_END``. Otherwise they are touched
                from the start offset onward.

        Returns:
            True only on the call that first activates the technique (the one
            that logs the "is enabled" message), False otherwise.
        """
        technique = self.different_compression_methods.get(method)
        # Techniques missing from the config, or disabled in it, are no-ops.
        if technique is None or not technique[TECHNIQUE_ENABLED]:
            return False

        shared_parameters = technique[SHARED_PARAMETERS]
        if self.training_steps < shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
            return False
        if use_offset_end and self.training_steps > shared_parameters[TECHNIQUE_SCHEDULE_OFFSET_END]:
            # Past the window nothing is touched; flags set earlier are not reset here.
            return False

        for _group_name, module_name_list, _params in technique[DIFFERENT_GROUPS]:
            for module_name in module_name_list:
                module = recursive_getattr(self.model, module_name)
                setattr(module, flag_name, True)

        if self.verbose[method]:
            return False
        logger.info(f'{display_name} is enabled at step {self.training_steps}')
        self.verbose[method] = True
        return True

    def check_weight_quantization(self):
        """Enable weight quantization on its modules once its offset is reached.

        On first activation, also sets ``self.weight_quantization_enabled``.
        """
        if self._enable_technique(WEIGHT_QUANTIZATION, 'weight_quantization_enabled', 'Weight quantization'):
            self.weight_quantization_enabled = True

    def check_activation_quantization(self):
        """Enable activation quantization on its modules once its offset is reached."""
        self._enable_technique(ACTIVATION_QUANTIZATION, 'activation_quantization_enabled', 'Activation quantization')

    def check_sparse_pruning(self):
        """Enable sparse pruning on its modules while the step is in [offset, offset_end]."""
        self._enable_technique(SPARSE_PRUNING, 'sparse_pruning_enabled', 'Sparse pruning', use_offset_end=True)

    def check_head_pruning(self):
        """Enable head pruning on its modules once its offset is reached."""
        self._enable_technique(HEAD_PRUNING, 'head_pruning_enabled', 'Head pruning')

    def check_row_pruning(self):
        """Enable row pruning on its modules once its offset is reached."""
        self._enable_technique(ROW_PRUNING, 'row_pruning_enabled', 'Row pruning')

    def check_channel_pruning(self):
        """Enable channel pruning on its modules once its offset is reached."""
        self._enable_technique(CHANNEL_PRUNING, 'channel_pruning_enabled', 'Channel pruning')

    def check_all_modules(self):
        """Evaluate every compression technique's schedule at the current step."""
        self.check_weight_quantization()
        self.check_activation_quantization()
        self.check_sparse_pruning()
        self.check_head_pruning()
        self.check_row_pruning()
        self.check_channel_pruning()

    def step(self, step_zero_check=False):
        """Advance the step counter and re-apply all compression schedules.

        Args:
            step_zero_check: when True, do not increment ``training_steps``.
                Schedules are then evaluated at the current step (presumably
                used before the first training step — confirm with callers).
        """
        if not step_zero_check:
            self.training_steps += 1
        self.check_all_modules()
|