config.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. '''Copyright The Microsoft DeepSpeed Team'''
  2. """
  3. Copyright (c) Microsoft Corporation
  4. Licensed under the MIT license.
  5. """
  6. from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
  7. from deepspeed.profiling.constants import *
  8. class DeepSpeedFlopsProfilerConfig(DeepSpeedConfigObject):
  9. def __init__(self, param_dict):
  10. super(DeepSpeedFlopsProfilerConfig, self).__init__()
  11. self.enabled = None
  12. self.profile_step = None
  13. self.module_depth = None
  14. self.top_modules = None
  15. if FLOPS_PROFILER in param_dict.keys():
  16. flops_profiler_dict = param_dict[FLOPS_PROFILER]
  17. else:
  18. flops_profiler_dict = {}
  19. self._initialize(flops_profiler_dict)
  20. def _initialize(self, flops_profiler_dict):
  21. self.enabled = get_scalar_param(flops_profiler_dict,
  22. FLOPS_PROFILER_ENABLED,
  23. FLOPS_PROFILER_ENABLED_DEFAULT)
  24. self.profile_step = get_scalar_param(flops_profiler_dict,
  25. FLOPS_PROFILER_PROFILE_STEP,
  26. FLOPS_PROFILER_PROFILE_STEP_DEFAULT)
  27. self.module_depth = get_scalar_param(flops_profiler_dict,
  28. FLOPS_PROFILER_MODULE_DEPTH,
  29. FLOPS_PROFILER_MODULE_DEPTH_DEFAULT)
  30. self.top_modules = get_scalar_param(flops_profiler_dict,
  31. FLOPS_PROFILER_TOP_MODULES,
  32. FLOPS_PROFILER_TOP_MODULES_DEFAULT)
  33. self.detailed = get_scalar_param(flops_profiler_dict,
  34. FLOPS_PROFILER_DETAILED,
  35. FLOPS_PROFILER_DETAILED_DEFAULT)
  36. self.output_file = get_scalar_param(flops_profiler_dict,
  37. FLOPS_PROFILER_OUTPUT_FILE,
  38. FLOPS_PROFILER_OUTPUT_FILE_DEFAULT)