inject.py
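
# Replaces stock BertLayer modules (pre- or post-LayerNorm variants) with
# DeepSpeed's fused CUDA transformer kernel, copying the trained weights
# into the fused layer's packed parameters.
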
import copy

import torch
from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig


def module_inject(layer_obj,
                  model,
                  config,
                  micro_batch_size,
                  max_seq_length,
                  seed,
                  preln,
                  fp16=True):
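    """Recursively replace every `layer_obj` instance in `model` with a
    DeepSpeedTransformerLayer configured from `config`.

    With `preln=True` the source layer is assumed to use the pre-LayerNorm
    attribute layout (PreAttentionLayerNorm, PostAttentionLayerNorm,
    intermediate.dense_act); otherwise the standard post-LayerNorm layout.
    Returns the modified model.
    """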
    for name, child in model.named_children():
        if isinstance(child, layer_obj):
            print('REPLACING BertLayer')

            cuda_config = DeepSpeedTransformerConfig(
                batch_size=micro_batch_size,
                max_seq_length=max_seq_length,
                hidden_size=config.hidden_size,
                heads=config.num_attention_heads,
                attn_dropout_ratio=config.attention_probs_dropout_prob,
                hidden_dropout_ratio=config.hidden_dropout_prob,
                num_hidden_layers=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                seed=seed,
                fp16=fp16,
                pre_layer_norm=preln)

            new_module = DeepSpeedTransformerLayer(cuda_config)

            # copy relevant state from child -> new module
            qw = child.attention.self.query.weight
            qb = child.attention.self.query.bias
            kw = child.attention.self.key.weight
            kb = child.attention.self.key.bias
            vw = child.attention.self.value.weight
            vb = child.attention.self.value.bias

            qkvw = torch.cat((qw, kw, vw), 0)
            qkvb = torch.cat((qb, kb, vb), 0)
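            # The fused kernel keeps a single packed QKV projection; the
            # row-wise concatenation above matches its [Q; K; V] parameter
            # layout, assigned below.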
            new_module.attn_qkvw.data = qkvw
            new_module.attn_qkvb.data = qkvb
            new_module.attn_ow.data = child.attention.output.dense.weight
            new_module.attn_ob.data = child.attention.output.dense.bias
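            # The pre-LN and post-LN BertLayer variants expose their
            # LayerNorms and intermediate FF under different attribute
            # names, so select the source modules accordingly.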
            if preln:
                attention_layerNorm = child.PostAttentionLayerNorm
            else:
                attention_layerNorm = child.attention.output.LayerNorm
            new_module.attn_nw.data = attention_layerNorm.weight
            new_module.attn_nb.data = attention_layerNorm.bias
            if preln:
                intermediate_FF = child.intermediate.dense_act
            else:
                intermediate_FF = child.intermediate.dense
            new_module.inter_w.data = intermediate_FF.weight
            new_module.inter_b.data = intermediate_FF.bias
            new_module.output_w.data = child.output.dense.weight
            new_module.output_b.data = child.output.dense.bias
            if preln:
                transformer_LayerNorm = child.PreAttentionLayerNorm
            else:
                transformer_LayerNorm = child.output.LayerNorm
            new_module.norm_w.data = transformer_LayerNorm.weight
            new_module.norm_b.data = transformer_LayerNorm.bias

            setattr(model, name, copy.deepcopy(new_module))
        else:
            module_inject(layer_obj,
                          child,
                          config,
                          micro_batch_size,
                          max_seq_length,
                          seed,
                          preln,
                          fp16)

    return model


def test_hi():
    from turing.nvidia_modelingpreln import BertConfig as BertConfigPreLN
    from turing.nvidia_modelingpreln import BertForQuestionAnswering as BertForQuestionAnsweringPreLN
    from turing.nvidia_modelingpreln import BertLayer

    bert_model_config = {
        "vocab_size_or_config_json_file": 119547,
        "hidden_size": 1024,
        "num_hidden_layers": 1,
        "num_attention_heads": 16,
        "intermediate_size": 4096,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
  85. "hidden_dropout_prob": 0.1,
  86. "attention_probs_dropout_prob": 0.1,
  87. "max_position_embeddings": 512,
  88. "type_vocab_size": 2,
  89. "initializer_range": 0.02
  90. }
    bert_config = BertConfigPreLN(**bert_model_config)
    base_model = BertForQuestionAnsweringPreLN(bert_config, args=None)
    #base_model = LinearStack()

    test_model = copy.deepcopy(base_model)
    # micro batch size 4, max sequence length 384, seed 1234; preln=True
    # because BertForQuestionAnsweringPreLN uses the pre-LayerNorm layout
    test_model = module_inject(BertLayer, test_model, bert_config, 4, 384, 1234, preln=True)

    print('BASE', base_model)
    print('TEST', test_model)

    #base_model.eval()
    #test_model.eval()

    #test_input = torch.rand(1, base_model.input_dim)

    #base_output = base_model(test_input)
    #test_output = test_model(test_input)
    #
    #assert torch.allclose(base_output, test_output, atol=3e-8)
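

# Minimal entry point (an addition, not in the original file) so the
# injection check can be run directly as a script:
if __name__ == '__main__':
    test_hi()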