# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from torch import nn
from typing import Dict
import gc
from deepspeed.inference.quantization import layers
from .layers import QUANTIZATION_LAYER_MAPPINGS
from .utils import get_AsyncPartitionedParameterSwapper, recursive_setattr
from deepspeed.utils.logging import logger
from collections import deque
from transformers.utils.generic import ContextManagers
from .quantization_context import QuantizationContext
import contextlib


def _init_group_wise_weight_quantization(model: nn.Module, ds_config: Dict) -> nn.Module:
    """[Experimental] Apply group-wise weight quantization to a model, replacing supported
    layer modules according to ds_config['weight_quantization']['post_init_quant'].

    Args:
        model (nn.Module): The nn.Module to quantize.
        ds_config (Dict): The DeepSpeed config dictionary; must contain a
            'weight_quantization' section with a 'post_init_quant' entry.

    Returns:
        nn.Module: The quantized nn.Module.
    """
    # global quantized_weight_registry

    matched_module_list_by_key = {}
    matched_module_count = 0

    assert 'weight_quantization' in ds_config, 'Please provide quantization config in ds_config'
    quantization_config = ds_config['weight_quantization']['post_init_quant']

    # Returns the nvme swapper if one exists, else None.
    # For nvme offloading we must use the same swapper here as the model was initialized with.
    nvme_swapper = get_AsyncPartitionedParameterSwapper(model)

    is_zero3_enabled = 'zero_optimization' in ds_config and \
        'stage' in ds_config['zero_optimization'] and \
        ds_config['zero_optimization']['stage'] == 3
    is_offloading_enabled = 'zero_optimization' in ds_config and \
        'offload_param' in ds_config['zero_optimization']

    layers.is_zero3_enabled = is_zero3_enabled

    context_mgr = ContextManagers([QuantizationContext(config_dict_or_path=ds_config, param_swapper=nvme_swapper)]) \
        if is_zero3_enabled else contextlib.suppress()
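    # Under ZeRO-3 the replacement modules below are constructed inside QuantizationContext
    # (sharing the NVMe swapper above) so their parameters are managed by the ZeRO-3 runtime
    # and offloaded like the originals; otherwise the context is a no-op.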
    with context_mgr:
        module_list = list(
            filter(lambda named_module: type(named_module[1]) in QUANTIZATION_LAYER_MAPPINGS, model.named_modules()))

        # Quantize small weights first, then large ones.
        if not is_offloading_enabled:
            module_list.sort(key=lambda named_module: named_module[1].weight.ds_tensor.numel()
                             if is_zero3_enabled else named_module[1].weight.numel())
        module_list = deque(module_list)

        while len(module_list) > 0:
            # Use popleft so the replaced module's memory can be released promptly after each loop iteration.
            module_name, module = module_list.popleft()

            matched_key = None
            matched_quantization_config = None

            for key, config in quantization_config.items():
                if key in module_name:
                    assert matched_key is None, f'{module_name} matched multiple quantization key words {matched_key} and {key}'
                    matched_key = key
                    matched_quantization_config = config

            if matched_key is None:
                continue

            if is_zero3_enabled:
                module.weight.all_gather()

            assert module.weight.dtype == torch.float16, 'Model weight is expected in half.'

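            # QUANTIZATION_LAYER_MAPPINGS maps each supported module class to its quantized
            # counterpart (e.g., an nn.Linear is expected to map to a quantized linear
            # implementation); the replacement is built from the gathered fp16 module and
            # the matched per-key config.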
            new_module = QUANTIZATION_LAYER_MAPPINGS[type(module)](matched_quantization_config, module)

            if is_zero3_enabled:
                module.weight.partition()

            recursive_setattr(model, module_name, new_module)

            if matched_key not in matched_module_list_by_key:
                matched_module_list_by_key[matched_key] = []
            matched_module_list_by_key[matched_key].append(module_name)
            matched_module_count += 1

            # Recycle memory promptly to prevent OOM on large models.
            gc.collect()

    # Clear the registry after model construction.
    layers.quantized_weight_registry.clear()

    logger.info(
        f'Group-wise weight quantization summary: converted {matched_module_count} node(s) to quantized implementation')
    summary_str = '\n'

    for key, module_list in matched_module_list_by_key.items():
        summary_str += f'Key: {key}, matched modules:\n'
        for module_name in module_list:
            summary_str += f'\t{module_name}\n'
    logger.info(summary_str)

    return model
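

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this module). The config shape below
# mirrors the ds_config['weight_quantization']['post_init_quant'] section read
# above; keys are matched against module names by substring, and the per-key
# fields ('num_bits', 'group_size', 'group_dim', 'symmetric') are assumed from
# DeepSpeed's weight-quantization config and should be verified against the
# installed version. The module names shown are hypothetical, and the model
# must already hold fp16 weights (see the assert in the loop above).
#
#   ds_config = {
#       'weight_quantization': {
#           'post_init_quant': {
#               'attention.dense': {'num_bits': 8, 'group_size': 64, 'group_dim': 0, 'symmetric': False},
#               'mlp.dense_4h_to_h': {'num_bits': 4, 'group_size': 32, 'group_dim': 0, 'symmetric': False},
#           }
#       }
#   }
#   model = model.half()
#   model = _init_group_wise_weight_quantization(model, ds_config)
# ---------------------------------------------------------------------------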