inject.py

'''Copyright The Microsoft DeepSpeed Team'''
import copy

import torch

from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig


def module_inject(layer_obj,
                  model,
                  config,
                  micro_batch_size,
                  max_seq_length,
                  seed,
                  preln,
                  fp16=True):
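    """Recursively replace every ``layer_obj`` instance in ``model`` with a
    ``DeepSpeedTransformerLayer``, copying the original weights into the
    fused parameter layout the DeepSpeed transformer kernel expects."""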
    for name, child in model.named_children():
        if isinstance(child, layer_obj):
            print('REPLACING BertLayer')
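
            # Build a DeepSpeed transformer config mirroring the original
            # BertConfig hyper-parameters so the fused layer is a drop-in
            # replacement.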
            cuda_config = DeepSpeedTransformerConfig(
                batch_size=micro_batch_size,
                max_seq_length=max_seq_length,
                hidden_size=config.hidden_size,
                heads=config.num_attention_heads,
                attn_dropout_ratio=config.attention_probs_dropout_prob,
                hidden_dropout_ratio=config.hidden_dropout_prob,
                num_hidden_layers=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                seed=seed,
                fp16=fp16,
                pre_layer_norm=preln)

            new_module = DeepSpeedTransformerLayer(cuda_config)

            # copy relevant state from child -> new module
            qw = child.attention.self.query.weight
            qb = child.attention.self.query.bias
            kw = child.attention.self.key.weight
            kb = child.attention.self.key.bias
            vw = child.attention.self.value.weight
            vb = child.attention.self.value.bias
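
            # The DeepSpeed kernel uses a single fused QKV projection, so the
            # three matrices are concatenated along the output dimension.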
            qkvw = torch.cat((qw, kw, vw), 0)
            qkvb = torch.cat((qb, kb, vb), 0)

            new_module.attn_qkvw.data = qkvw
            new_module.attn_qkvb.data = qkvb
            new_module.attn_ow.data = child.attention.output.dense.weight
            new_module.attn_ob.data = child.attention.output.dense.bias
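
            # Pre-LN and post-LN BERT variants keep their LayerNorms (and the
            # intermediate dense layer) under different attribute names, so
            # pick the source modules according to ``preln``.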
            if preln:
                attention_layerNorm = child.PostAttentionLayerNorm
            else:
                attention_layerNorm = child.attention.output.LayerNorm
            new_module.attn_nw.data = attention_layerNorm.weight
            new_module.attn_nb.data = attention_layerNorm.bias
            if preln:
                intermediate_FF = child.intermediate.dense_act
            else:
                intermediate_FF = child.intermediate.dense
            new_module.inter_w.data = intermediate_FF.weight
            new_module.inter_b.data = intermediate_FF.bias
            new_module.output_w.data = child.output.dense.weight
            new_module.output_b.data = child.output.dense.bias
            if preln:
                transformer_LayerNorm = child.PreAttentionLayerNorm
            else:
                transformer_LayerNorm = child.output.LayerNorm
            new_module.norm_w.data = transformer_LayerNorm.weight
            new_module.norm_b.data = transformer_LayerNorm.bias

            setattr(model, name, copy.deepcopy(new_module))
        else:
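            # not a target layer: recurse into the child until one is found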
            module_inject(layer_obj,
                          child,
                          config,
                          micro_batch_size,
                          max_seq_length,
                          seed,
                          preln,
                          fp16)

    return model


def test_hi():
    from turing.nvidia_modelingpreln import BertConfig as BertConfigPreLN
    from turing.nvidia_modelingpreln import BertForQuestionAnswering as BertForQuestionAnsweringPreLN
    from turing.nvidia_modelingpreln import BertLayer

    bert_model_config = {
        "vocab_size_or_config_json_file": 119547,
        "hidden_size": 1024,
        "num_hidden_layers": 1,
        "num_attention_heads": 16,
        "intermediate_size": 4096,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
  86. "hidden_dropout_prob": 0.1,
  87. "attention_probs_dropout_prob": 0.1,
  88. "max_position_embeddings": 512,
  89. "type_vocab_size": 2,
  90. "initializer_range": 0.02
  91. }
  92. bert_config = BertConfigPreLN(**bert_model_config)
  93. base_model = BertForQuestionAnsweringPreLN(bert_config, args=None)
  94. #base_model = LinearStack()
  95. test_model = copy.deepcopy(base_model)
    # the model under test is a pre-LN variant, so pass preln=True explicitly
    test_model = module_inject(BertLayer, test_model, bert_config, 4, 384, 1234, preln=True)
    print('BASE', base_model)
    print('TEST', test_model)

    #base_model.eval()
    #test_model.eval()
    #test_input = torch.rand(1, base_model.input_dim)
    #base_output = base_model(test_input)
    #test_output = test_model(test_input)
    #
    #assert torch.allclose(base_output, test_output, atol=3e-8)
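

# A minimal sketch of running the test directly, assuming the ``turing``
# package and the DeepSpeed transformer CUDA ops are installed and importable.
if __name__ == '__main__':
    test_hi()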