# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from abc import abstractmethod

from .hybrid_engine import HybridEngineContainer
  6. class HybridGatedMLPContainer(HybridEngineContainer):
  7. """
  8. The HybridGatedMLPContainer supports models for which the first MLP layer
  9. is represented with two separate weights, one for the activation function
  10. and one for the gating function.
  11. """
  12. def set_mlp(self, _h4h_w, _h4h_b, _4hh_w, _4hh_b):
  13. super().set_mlp(_h4h_w, _h4h_b, _4hh_w, _4hh_b)
  14. self.set_mlp_gate()
  15. @abstractmethod
  16. def set_mlp_gate(self):
  17. """
  18. In `set_mlp_gate`, it is necessary to populate the following variables (where appropriate)
  19. for the given model:
  20. self.inter_up_w: inter up weight
  21. self.inter_up_b: inter up bias
  22. self.inter_gate_w: inter gate weight
  23. self.inter_gate_b: inter gate bias
  24. If the parameter does not exist in the original model, set the attribute to None.
  25. """
  26. raise NotImplementedError("A set_mlp_gate() function must be defined in the model container \
  27. in order to set the unfused inter up and gate tensors.")
  28. def mlp_inter_mp(self, mp_replace, reversed_dim=False):
  29. # Only need to alter behavior if we can't do the normal destructive copy
  30. if self.module.mlp.inter_w is None:
  31. params = [
  32. (self.module.mlp.inter_up_w, self.inter_up_w),
  33. (self.module.mlp.inter_up_b, self.inter_up_b),
  34. (self.module.mlp.inter_gate_w, self.inter_gate_w),
  35. (self.module.mlp.inter_gate_b, self.inter_gate_b),
  36. ]
  37. for dst, src in params:
  38. dst = mp_replace.copy(dst[:self.inter_up_w.shape[0] // mp_replace.mp_size],
  39. src,
  40. int8=reversed_dim,
  41. allocate_tensor=reversed_dim) if src is not None else None
  42. else:
  43. self.module.mlp.inter_w = mp_replace.strided_copy(self.module.mlp.inter_w,
  44. self._h4h_w,
  45. num_splits=2,
  46. int8=reversed_dim)
  47. self.module.mlp.inter_b = mp_replace.strided_copy(self.module.mlp.inter_b,
  48. self._h4h_b,
  49. num_splits=2,
  50. int8=reversed_dim)
  51. def release_mlp(self):
  52. super().release_mlp()
  53. gated_mlp_params = [
  54. (self.module.mlp.inter_up_w, self.inter_up_w),
  55. (self.module.mlp.inter_up_b, self.inter_up_b),
  56. (self.module.mlp.inter_gate_w, self.inter_gate_w),
  57. (self.module.mlp.inter_gate_b, self.inter_gate_b),
  58. ]
  59. self._release_params(gated_mlp_params)
  60. def reset_mlp(self):
  61. self._h4h_w.data[:self.inter_up_w.shape[0]] = self.inter_up_w.data
  62. self._h4h_w.data[self.inter_up_w.shape[0]:] = self.inter_gate_w.data
  63. if self.inter_up_b is not None:
  64. self._h4h_b.data[:self.inter_up_b.shape[0]] = self.inter_up_b.data
  65. self._h4h_b.data[self.inter_up_b.shape[0]:] = self.inter_gate_b.data
  66. inter_data = [self.inter_up_w.data, self.inter_gate_w.data]
  67. if self.inter_up_b is not None:
  68. inter_data.extend([self.inter_up_b.data, self.inter_gate_b.data])
  69. self.inter_up_w.data = self._h4h_w.data[:self.inter_up_w.shape[0]]
  70. self.inter_gate_w.data = self._h4h_w.data[self.inter_up_w.shape[0]:]
  71. if self.inter_up_b is not None:
  72. self.inter_up_b.data = self._h4h_b.data[:self.inter_up_b.shape[0]]
  73. self.inter_gate_b.data = self._h4h_b.data[self.inter_up_b.shape[0]:]
  74. for data in inter_data:
  75. del data
  76. def set_mlp_params_wo_copy(self, Z3_enabled=False):
  77. self.module.mlp.output_w = self._4hh_w
  78. self.module.mlp.output_b = self._4hh_b
  79. if not Z3_enabled:
  80. # In initialize_tensors, we create a fused inter projection with the appropriate shape
  81. # and copy the up projection and gate projection into it
  82. self.module.mlp.inter_w = self._h4h_w
  83. self.module.mlp.inter_b = self._h4h_b
  84. self.inter_up_w.data = self._h4h_w[:self.inter_up_w.shape[0], :]
  85. self.inter_gate_w.data = self._h4h_w[self.inter_up_w.shape[0]:, :]
  86. if self.inter_up_b is not None:
  87. self.inter_up_b.data = self._h4h_b[:self.inter_up_w.shape[0]] if self._h4h_b is not None else None
  88. self.inter_gate_b.data = self._h4h_b[self.inter_up_w.shape[0]:] if self._h4h_b is not None else None
  89. else:
  90. self.module.mlp.inter_up_w = self.inter_up_w
  91. self.module.mlp.inter_up_b = self.inter_up_b
  92. self.module.mlp.inter_gate_w = self.inter_gate_w
  93. self.module.mlp.inter_gate_b = self.inter_gate_b
  94. def get_mlp_params(self):
  95. params = super().get_mlp_params()
  96. params.extend([self.inter_up_w, self.inter_up_b, self.inter_gate_w, self.inter_gate_b])
  97. return params