config.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
  5. from deepspeed.profiling.constants import *
  6. class DeepSpeedFlopsProfilerConfig(DeepSpeedConfigObject):
  7. def __init__(self, param_dict):
  8. super(DeepSpeedFlopsProfilerConfig, self).__init__()
  9. self.enabled = None
  10. self.recompute_fwd_factor = None
  11. self.profile_step = None
  12. self.module_depth = None
  13. self.top_modules = None
  14. if FLOPS_PROFILER in param_dict.keys():
  15. flops_profiler_dict = param_dict[FLOPS_PROFILER]
  16. else:
  17. flops_profiler_dict = {}
  18. self._initialize(flops_profiler_dict)
  19. def _initialize(self, flops_profiler_dict):
  20. self.enabled = get_scalar_param(flops_profiler_dict, FLOPS_PROFILER_ENABLED, FLOPS_PROFILER_ENABLED_DEFAULT)
  21. self.recompute_fwd_factor = get_scalar_param(flops_profiler_dict, FLOPS_PROFILER_RECOMPUTE_FWD_FACTOR,
  22. FLOPS_PROFILER_RECOMPUTE_FWD_FACTOR_DEFAULT)
  23. self.profile_step = get_scalar_param(flops_profiler_dict, FLOPS_PROFILER_PROFILE_STEP,
  24. FLOPS_PROFILER_PROFILE_STEP_DEFAULT)
  25. self.module_depth = get_scalar_param(flops_profiler_dict, FLOPS_PROFILER_MODULE_DEPTH,
  26. FLOPS_PROFILER_MODULE_DEPTH_DEFAULT)
  27. self.top_modules = get_scalar_param(flops_profiler_dict, FLOPS_PROFILER_TOP_MODULES,
  28. FLOPS_PROFILER_TOP_MODULES_DEFAULT)
  29. self.detailed = get_scalar_param(flops_profiler_dict, FLOPS_PROFILER_DETAILED, FLOPS_PROFILER_DETAILED_DEFAULT)
  30. self.output_file = get_scalar_param(flops_profiler_dict, FLOPS_PROFILER_OUTPUT_FILE,
  31. FLOPS_PROFILER_OUTPUT_FILE_DEFAULT)