inject.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import copy
  5. import torch
  6. from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
  7. def module_inject(layer_obj, model, config, micro_batch_size, max_seq_length, seed, preln, fp16=True):
  8. for name, child in model.named_children():
  9. if isinstance(child, layer_obj):
  10. print('REPLACING BertLayer')
  11. cuda_config = DeepSpeedTransformerConfig(batch_size=micro_batch_size,
  12. max_seq_length=max_seq_length,
  13. hidden_size=config.hidden_size,
  14. heads=config.num_attention_heads,
  15. attn_dropout_ratio=config.attention_probs_dropout_prob,
  16. hidden_dropout_ratio=config.hidden_dropout_prob,
  17. num_hidden_layers=config.num_hidden_layers,
  18. initializer_range=config.initializer_range,
  19. seed=seed,
  20. fp16=fp16,
  21. pre_layer_norm=preln)
  22. new_module = DeepSpeedTransformerLayer(cuda_config)
  23. # copy relevant state from child -> new module
  24. qw = child.attention.self.query.weight
  25. qb = child.attention.self.query.bias
  26. kw = child.attention.self.key.weight
  27. kb = child.attention.self.key.bias
  28. vw = child.attention.self.value.weight
  29. vb = child.attention.self.value.bias
  30. qkvw = torch.cat((qw, kw, vw), 0)
  31. qkvb = torch.cat((qb, kb, vb), 0)
  32. new_module.attn_qkvw.data = qkvw
  33. new_module.attn_qkvb.data = qkvb
  34. new_module.attn_ow.data = child.attention.output.dense.weight
  35. new_module.attn_ob.data = child.attention.output.dense.bias
  36. if preln:
  37. attention_layerNorm = child.PostAttentionLayerNorm
  38. else:
  39. attention_layerNorm = child.attention.output.LayerNorm
  40. new_module.attn_nw.data = attention_layerNorm.weight
  41. new_module.attn_nb.data = attention_layerNorm.bias
  42. if preln:
  43. intermediate_FF = child.intermediate.dense_act
  44. else:
  45. intermediate_FF = child.intermediate.dense
  46. new_module.inter_w.data = intermediate_FF.weight
  47. new_module.inter_b.data = intermediate_FF.bias
  48. new_module.output_w.data = child.output.dense.weight
  49. new_module.output_b.data = child.output.dense.bias
  50. if preln:
  51. transformer_LayerNorm = child.PreAttentionLayerNorm
  52. else:
  53. transformer_LayerNorm = child.output.LayerNorm
  54. new_module.norm_w.data = transformer_LayerNorm.weight
  55. new_module.norm_b.data = transformer_LayerNorm.bias
  56. setattr(model, name, copy.deepcopy(new_module))
  57. else:
  58. module_inject(layer_obj, child, config, micro_batch_size, max_seq_length, seed, preln, fp16)
  59. return model
  60. def test_hi():
  61. from turing.nvidia_modelingpreln import BertConfig as BertConfigPreLN
  62. from turing.nvidia_modelingpreln import BertForQuestionAnswering as BertForQuestionAnsweringPreLN
  63. from turing.nvidia_modelingpreln import BertLayer
  64. bert_model_config = {
  65. "vocab_size_or_config_json_file": 119547,
  66. "hidden_size": 1024,
  67. "num_hidden_layers": 1,
  68. "num_attention_heads": 16,
  69. "intermediate_size": 4096,
  70. "hidden_act": "gelu",
  71. "hidden_dropout_prob": 0.1,
  72. "attention_probs_dropout_prob": 0.1,
  73. "hidden_dropout_prob": 0.1,
  74. "attention_probs_dropout_prob": 0.1,
  75. "max_position_embeddings": 512,
  76. "type_vocab_size": 2,
  77. "initializer_range": 0.02
  78. }
  79. bert_config = BertConfigPreLN(**bert_model_config)
  80. base_model = BertForQuestionAnsweringPreLN(bert_config, args=None)
  81. #base_model = LinearStack()
  82. test_model = copy.deepcopy(base_model)
  83. test_model = module_inject(BertLayer, test_model, bert_config, 4, 384, 1234)
  84. print('BASE', base_model)
  85. print('TEST', test_model)
  86. #base_model.eval()
  87. #test_model.eval()
  88. #test_input = torch.rand(1, base_model.input_dim)
  89. #base_output = base_model(test_input)
  90. #test_output = test_model(test_input)
  91. #
  92. #assert torch.allclose(base_output, test_output, atol=3e-8)