# test_compression.py
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. import pytest
  6. import random
  7. import numpy as np
  8. from unit.megatron_model import get_gpt2_model
  9. from deepspeed.compression.compress import init_compression
  10. from unit.modeling import BertConfig
  11. from unit.modelingpreln import BertEncoder as BertEncoderPreln
  12. from deepspeed.compression.basic_layer import LinearLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress
  13. from deepspeed.compression.helper import convert_conv1d_to_linear
  14. from deepspeed.accelerator import get_accelerator
  15. from deepspeed.utils.torch import required_torch_version
  16. from unit.common import DistributedTest
# Skip this whole module on old torch builds: the Megatron-LM models exercised
# by these tests require PyTorch >= 1.5.
pytestmark = pytest.mark.skipif(not required_torch_version(min_version=1.5),
                                reason='Megatron-LM package requires Pytorch version 1.5 or above')
  19. def reset_random(seed=1234):
  20. random.seed(seed)
  21. np.random.seed(seed)
  22. torch.manual_seed(seed)
  23. get_accelerator().manual_seed_all(seed)
  24. def create_bert_model():
  25. hidden_size = 384
  26. num_layers = 2
  27. heads = 12
  28. dropout_ratio = 0.1
  29. bert_config = BertConfig(vocab_size_or_config_json_file=119547,
  30. hidden_size=hidden_size,
  31. num_hidden_layers=num_layers,
  32. num_attention_heads=heads,
  33. intermediate_size=hidden_size * 4,
  34. hidden_act="gelu",
  35. hidden_dropout_prob=dropout_ratio,
  36. attention_probs_dropout_prob=dropout_ratio,
  37. max_position_embeddings=512,
  38. type_vocab_size=2,
  39. initializer_range=0.2)
  40. weights = []
  41. biases = []
  42. for i in range(4):
  43. weights.append(torch.nn.Parameter(torch.Tensor(hidden_size, hidden_size)))
  44. weights.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
  45. weights.append(torch.nn.Parameter(torch.Tensor(hidden_size * 4, hidden_size)))
  46. weights.append(torch.nn.Parameter(torch.Tensor(hidden_size, hidden_size * 4)))
  47. weights.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
  48. biases.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
  49. for i in range(4):
  50. biases.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
  51. biases.append(torch.nn.Parameter(torch.Tensor(hidden_size * 4)))
  52. biases.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
  53. biases.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
  54. return BertEncoderPreln(bert_config, weights, biases)
  55. class Conv1D(torch.nn.Module):
  56. """
  57. 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
  58. Basically works like a linear layer but the weights are transposed.
  59. Args:
  60. nf (`int`): The number of output features.
  61. nx (`int`): The number of input features.
  62. """
  63. def __init__(self, nf, nx):
  64. super().__init__()
  65. self.nf = nf
  66. w = torch.empty(nx, nf)
  67. self.weight = torch.nn.Parameter(w)
  68. self.bias = torch.nn.Parameter(torch.zeros(nf))
  69. def forward(self, x):
  70. size_out = x.size()[:-1] + (self.nf, )
  71. x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
  72. x = x.view(size_out)
  73. return x
  74. def create_conv1d_model():
  75. nf = 128
  76. nx = 128
  77. return torch.nn.ModuleList([Conv1D(nf, nx) for i in range(4)])
class TestCompression(DistributedTest):
    """Tests for DeepSpeed compression: module replacement via init_compression
    and Conv1D-to-Linear conversion."""

    def setup_method(self, method):
        # Reseed all RNGs before each test for determinism.
        reset_random()

    def get_ds_config(self):
        """Return a DeepSpeed config dict enabling every compression feature:
        weight/activation quantization plus sparse, row, and head pruning."""
        ds_config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "optimizer": {
                "type": "Lamb",
                "params": {
                    "lr": 0.00015
                }
            },
            "fp16": {
                "enabled": True
            },
            "compression_training": {
                # Progressive weight quantization: 12 -> 8 bits for the
                # self-attention/intermediate layers, 12 -> 4 bits for the
                # attention output projection.
                "weight_quantization": {
                    "shared_parameters": {
                        "enabled": True,
                        "quantizer_kernel": False,
                        "schedule_offset": 50,
                        "quantize_groups": 1,
                        "quantize_verbose": False,
                        "quantization_type": "asymmetric",
                        "rounding": "nearest",
                        "fp16_mixed_quantize": {
                            "enabled": False,
                            "quantize_change_ratio": 0.001
                        }
                    },
                    "different_groups": {
                        "wq1": {
                            "params": {
                                "start_bits": 12,
                                "target_bits": 8,
                                "quantization_period": 50
                            },
                            "modules": ["attention.self", "intermediate"]
                        },
                        "wq2": {
                            "params": {
                                "start_bits": 12,
                                "target_bits": 4,
                                "quantization_period": 50
                            },
                            "modules": ["attention.output"]
                        }
                    }
                },
                # 8-bit dynamic activation quantization on attention output.
                "activation_quantization": {
                    "shared_parameters": {
                        "enabled": True,
                        "quantization_type": "asymmetric",
                        "range_calibration": "dynamic",
                        "schedule_offset": 50
                    },
                    "different_groups": {
                        "aq1": {
                            "params": {
                                "bits": 8
                            },
                            "modules": ["attention.output"]
                        }
                    }
                },
                # L1-magnitude sparse pruning to 50% density on self-attention.
                "sparse_pruning": {
                    "shared_parameters": {
                        "enabled": True,
                        "schedule_offset": 30,
                        "method": "l1"
                    },
                    "different_groups": {
                        "sp1": {
                            "params": {
                                "dense_ratio": 0.5
                            },
                            "modules": ["attention.self"]
                        }
                    }
                },
                # Top-k row pruning of the FFN intermediate layer; the related
                # output.dense modules must be pruned consistently (regex match).
                "row_pruning": {
                    "shared_parameters": {
                        "enabled": True,
                        "schedule_offset": 20,
                        "method": "topk"
                    },
                    "different_groups": {
                        "rp1": {
                            "params": {
                                "dense_ratio": 0.5
                            },
                            "modules": ["intermediate.dense"],
                            "related_modules": [["layer.\\w+.output.dense"]]
                        }
                    }
                },
                # Top-k head pruning (12 heads) on the attention output dense,
                # tied to the q/k/v projections.
                "head_pruning": {
                    "shared_parameters": {
                        "enabled": True,
                        "schedule_offset": 10,
                        "method": "topk",
                        "num_heads": 12
                    },
                    "different_groups": {
                        "rp1": {
                            "params": {
                                "dense_ratio": 0.5
                            },
                            "modules": ["attention.output.dense"],
                            "related_modules": [["self.query", "self.key", "self.value"]]
                        }
                    }
                }
            }
        }
        return ds_config_dict

    def test_linear_layer_compress(self, tmpdir):
        # init_compression should swap the targeted q/k/v Linear modules for
        # LinearLayer_Compress replacements.
        model = create_bert_model()
        compressed_model = init_compression(model, self.get_ds_config())
        assert isinstance(compressed_model.layer[0].attention.self.query, LinearLayer_Compress)
        assert isinstance(compressed_model.layer[0].attention.self.key, LinearLayer_Compress)
        assert isinstance(compressed_model.layer[0].attention.self.value, LinearLayer_Compress)

    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_mpu_compress(self, tmpdir):
        # With an mpu object supplied, model-parallel linears should be
        # replaced by their *ParallelLinear_Compress counterparts.
        if not required_torch_version(max_version=1.13):
            pytest.skip("megatron not compatible with torch >1.13")
        from megatron import mpu
        args_defaults = {
            'num_layers': 2,
            'hidden_size': 128,
            'num_attention_heads': 8,
            'max_position_embeddings': 128,
        }
        model = get_gpt2_model(args_defaults)
        compressed_model = init_compression(model, self.get_ds_config(), mpu=mpu)
        assert isinstance(compressed_model.module.language_model.transformer.layers[0].attention.query_key_value,
                          ColumnParallelLinear_Compress)
        assert isinstance(compressed_model.module.language_model.transformer.layers[0].attention.dense,
                          RowParallelLinear_Compress)
        assert isinstance(compressed_model.module.language_model.transformer.layers[0].mlp.dense_h_to_4h,
                          ColumnParallelLinear_Compress)
        assert isinstance(compressed_model.module.language_model.transformer.layers[0].mlp.dense_4h_to_h,
                          RowParallelLinear_Compress)

    def test_conv1d_convertion(self, tmpdir):
        # convert_conv1d_to_linear should turn every Conv1D in the ModuleList
        # into an equivalent torch.nn.Linear.
        model = create_conv1d_model()
        compressed_model = convert_conv1d_to_linear(model, Conv1D)
        assert isinstance(compressed_model[0], torch.nn.Linear)
        assert isinstance(compressed_model[1], torch.nn.Linear)
        assert isinstance(compressed_model[2], torch.nn.Linear)
        assert isinstance(compressed_model[3], torch.nn.Linear)