# test_compression.py (9.7 KB)
  1. import torch
  2. import pytest
  3. import random
  4. import numpy as np
  5. from .megatron_model import get_gpt2_model
  6. from deepspeed.compression.compress import init_compression
  7. from .modeling import BertConfig
  8. from .modelingpreln import BertEncoder as BertEncoderPreln
  9. from deepspeed.compression.basic_layer import LinearLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress
  10. from deepspeed.compression.helper import convert_conv1d_to_linear
  11. TORCH_MAJOR = int(torch.__version__.split('.')[0])
  12. TORCH_MINOR = int(torch.__version__.split('.')[1])
  13. pytestmark = pytest.mark.skipif(
  14. TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 5),
  15. reason='Megatron-LM package requires Pytorch version 1.5 or above')
  16. def reset_random(seed=1234):
  17. random.seed(seed)
  18. np.random.seed(seed)
  19. torch.manual_seed(seed)
  20. torch.cuda.manual_seed_all(seed)
  21. def create_bert_model():
  22. hidden_size = 384
  23. num_layers = 2
  24. heads = 12
  25. dropout_ratio = 0.1
  26. bert_config = BertConfig(vocab_size_or_config_json_file=119547,
  27. hidden_size=hidden_size,
  28. num_hidden_layers=num_layers,
  29. num_attention_heads=heads,
  30. intermediate_size=hidden_size * 4,
  31. hidden_act="gelu",
  32. hidden_dropout_prob=dropout_ratio,
  33. attention_probs_dropout_prob=dropout_ratio,
  34. max_position_embeddings=512,
  35. type_vocab_size=2,
  36. initializer_range=0.2)
  37. weights = []
  38. biases = []
  39. for i in range(4):
  40. weights.append(torch.nn.Parameter(torch.Tensor(hidden_size, hidden_size)))
  41. weights.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
  42. weights.append(torch.nn.Parameter(torch.Tensor(hidden_size * 4, hidden_size)))
  43. weights.append(torch.nn.Parameter(torch.Tensor(hidden_size, hidden_size * 4)))
  44. weights.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
  45. biases.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
  46. for i in range(4):
  47. biases.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
  48. biases.append(torch.nn.Parameter(torch.Tensor(hidden_size * 4)))
  49. biases.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
  50. biases.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
  51. return BertEncoderPreln(bert_config, weights, biases)
  52. class Conv1D(torch.nn.Module):
  53. """
  54. 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
  55. Basically works like a linear layer but the weights are transposed.
  56. Args:
  57. nf (`int`): The number of output features.
  58. nx (`int`): The number of input features.
  59. """
  60. def __init__(self, nf, nx):
  61. super().__init__()
  62. self.nf = nf
  63. w = torch.empty(nx, nf)
  64. self.weight = torch.nn.Parameter(w)
  65. self.bias = torch.nn.Parameter(torch.zeros(nf))
  66. def forward(self, x):
  67. size_out = x.size()[:-1] + (self.nf, )
  68. x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
  69. x = x.view(size_out)
  70. return x
  71. def create_conv1d_model():
  72. nf = 128
  73. nx = 128
  74. return torch.nn.ModuleList([Conv1D(nf, nx) for i in range(4)])
  75. class TestCompression:
  76. def setup_method(self, method):
  77. reset_random()
  78. def get_ds_config(self):
  79. ds_config_dict = {
  80. "train_micro_batch_size_per_gpu": 1,
  81. "optimizer": {
  82. "type": "Lamb",
  83. "params": {
  84. "lr": 0.00015
  85. }
  86. },
  87. "fp16": {
  88. "enabled": True
  89. },
  90. "compression_training": {
  91. "weight_quantization": {
  92. "shared_parameters": {
  93. "enabled": True,
  94. "quantizer_kernel": False,
  95. "schedule_offset": 50,
  96. "quantize_groups": 1,
  97. "quantize_verbose": False,
  98. "quantization_type": "asymmetric",
  99. "rounding": "nearest",
  100. "fp16_mixed_quantize": {
  101. "enabled": False,
  102. "quantize_change_ratio": 0.001
  103. }
  104. },
  105. "different_groups": {
  106. "wq1": {
  107. "params": {
  108. "start_bits": 12,
  109. "target_bits": 8,
  110. "quantization_period": 50
  111. },
  112. "modules": ["attention.self",
  113. "intermediate"]
  114. },
  115. "wq2": {
  116. "params": {
  117. "start_bits": 12,
  118. "target_bits": 4,
  119. "quantization_period": 50
  120. },
  121. "modules": ["attention.output"]
  122. }
  123. }
  124. },
  125. "activation_quantization": {
  126. "shared_parameters": {
  127. "enabled": True,
  128. "quantization_type": "asymmetric",
  129. "range_calibration": "dynamic",
  130. "schedule_offset": 50
  131. },
  132. "different_groups": {
  133. "aq1": {
  134. "params": {
  135. "bits": 8
  136. },
  137. "modules": ["attention.output"]
  138. }
  139. }
  140. },
  141. "sparse_pruning": {
  142. "shared_parameters": {
  143. "enabled": True,
  144. "schedule_offset": 30,
  145. "method": "l1"
  146. },
  147. "different_groups": {
  148. "sp1": {
  149. "params": {
  150. "dense_ratio": 0.5
  151. },
  152. "modules": ["attention.self"]
  153. }
  154. }
  155. },
  156. "row_pruning": {
  157. "shared_parameters": {
  158. "enabled": True,
  159. "schedule_offset": 20,
  160. "method": "topk"
  161. },
  162. "different_groups": {
  163. "rp1": {
  164. "params": {
  165. "dense_ratio": 0.5
  166. },
  167. "modules": ["intermediate.dense"],
  168. "related_modules": [["layer.\\w+.output.dense"]]
  169. }
  170. }
  171. },
  172. "head_pruning": {
  173. "shared_parameters": {
  174. "enabled": True,
  175. "schedule_offset": 10,
  176. "method": "topk",
  177. "num_heads": 12
  178. },
  179. "different_groups": {
  180. "rp1": {
  181. "params": {
  182. "dense_ratio": 0.5
  183. },
  184. "modules": ["attention.output.dense"],
  185. "related_modules": [["self.query",
  186. "self.key",
  187. "self.value"]]
  188. }
  189. }
  190. }
  191. }
  192. }
  193. return ds_config_dict
  194. def test_linear_layer_compress(self, tmpdir):
  195. model = create_bert_model()
  196. compressed_model = init_compression(model, self.get_ds_config())
  197. assert isinstance(compressed_model.layer[0].attention.self.query,
  198. LinearLayer_Compress)
  199. assert isinstance(compressed_model.layer[0].attention.self.key,
  200. LinearLayer_Compress)
  201. assert isinstance(compressed_model.layer[0].attention.self.value,
  202. LinearLayer_Compress)
  203. def test_mpu_compress(self, tmpdir):
  204. from megatron import mpu
  205. args_defaults = {
  206. 'num_layers': 2,
  207. 'hidden_size': 128,
  208. 'num_attention_heads': 8,
  209. 'max_position_embeddings': 128,
  210. }
  211. model = get_gpt2_model(args_defaults)
  212. compressed_model = init_compression(model, self.get_ds_config(), mpu=mpu)
  213. assert isinstance(
  214. compressed_model.module.language_model.transformer.layers[0].attention.
  215. query_key_value,
  216. ColumnParallelLinear_Compress)
  217. assert isinstance(
  218. compressed_model.module.language_model.transformer.layers[0].attention.dense,
  219. RowParallelLinear_Compress)
  220. assert isinstance(
  221. compressed_model.module.language_model.transformer.layers[0].mlp.
  222. dense_h_to_4h,
  223. ColumnParallelLinear_Compress)
  224. assert isinstance(
  225. compressed_model.module.language_model.transformer.layers[0].mlp.
  226. dense_4h_to_h,
  227. RowParallelLinear_Compress)
  228. def test_conv1d_convertion(self, tmpdir):
  229. model = create_conv1d_model()
  230. compressed_model = convert_conv1d_to_linear(model, Conv1D)
  231. assert isinstance(compressed_model[0], torch.nn.Linear)
  232. assert isinstance(compressed_model[1], torch.nn.Linear)
  233. assert isinstance(compressed_model[2], torch.nn.Linear)
  234. assert isinstance(compressed_model[3], torch.nn.Linear)