# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from abc import abstractmethod

from .hybrid_engine import HybridEngineContainer


class HybridGatedMLPContainer(HybridEngineContainer):
    """
    The HybridGatedMLPContainer supports models for which the first MLP layer
    is represented with two separate weights: one feeding the activation
    function (the up projection) and one feeding the gating function (the
    gate projection).
    """

    def set_mlp(self, _h4h_w, _h4h_b, _4hh_w, _4hh_b):
        super().set_mlp(_h4h_w, _h4h_b, _4hh_w, _4hh_b)
        self.set_mlp_gate()

    @abstractmethod
    def set_mlp_gate(self):
        """
        In `set_mlp_gate`, it is necessary to populate the following variables (where appropriate)
        for the given model:
            self.inter_up_w: inter up weight
            self.inter_up_b: inter up bias
            self.inter_gate_w: inter gate weight
            self.inter_gate_b: inter gate bias
        If the parameter does not exist in the original model, set the attribute to None.
        """
        raise NotImplementedError("A set_mlp_gate() function must be defined in the model container "
                                  "in order to set the unfused inter up and gate tensors.")
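
    # A minimal sketch of a concrete implementation, assuming a hypothetical
    # LLaMA-style container whose client module exposes `mlp.up_proj` and
    # `mlp.gate_proj` (these attribute names are illustrative, not taken from
    # this file):
    #
    #     class HybridLlamaLikeContainer(HybridGatedMLPContainer):
    #         def set_mlp_gate(self):
    #             self.inter_up_w = self.policy.client_module.mlp.up_proj.weight
    #             self.inter_up_b = None  # this hypothetical model has no MLP biases
    #             self.inter_gate_w = self.policy.client_module.mlp.gate_proj.weight
    #             self.inter_gate_b = None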

    def mlp_inter_mp(self, mp_replace, reversed_dim=False):
        # Only need to alter behavior if we can't do the normal destructive copy
        if self.module.mlp.inter_w is None:
            params = [
                (self.module.mlp.inter_up_w, self.inter_up_w),
                (self.module.mlp.inter_up_b, self.inter_up_b),
                (self.module.mlp.inter_gate_w, self.inter_gate_w),
                (self.module.mlp.inter_gate_b, self.inter_gate_b),
            ]
            for dst, src in params:
                # mp_replace.copy fills the sliced destination tensor in place
                # (in the common, non-allocating path), so the module's
                # parameters are updated even though the loop variable `dst`
                # is rebound and not used afterwards.
                dst = mp_replace.copy(dst[:self.inter_up_w.shape[0] // mp_replace.mp_size],
                                      src,
                                      int8=reversed_dim,
                                      allocate_tensor=reversed_dim) if src is not None else None
        else:
            self.module.mlp.inter_w = mp_replace.strided_copy(self.module.mlp.inter_w,
                                                              self._h4h_w,
                                                              num_splits=2,
                                                              int8=reversed_dim)
            self.module.mlp.inter_b = mp_replace.strided_copy(self.module.mlp.inter_b,
                                                              self._h4h_b,
                                                              num_splits=2,
                                                              int8=reversed_dim)
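
    # Illustrative sketch of the fused layout assumed above (a standalone
    # torch example, not part of the container): the fused inter tensor stacks
    # the up and gate projections along dim 0, which is why `strided_copy` is
    # given `num_splits=2` so each half is sharded independently across
    # model-parallel ranks.
    #
    #     import torch
    #     up, gate = torch.randn(8, 4), torch.randn(8, 4)
    #     fused = torch.cat([up, gate], dim=0)
    #     assert torch.equal(fused[:up.shape[0]], up)
    #     assert torch.equal(fused[up.shape[0]:], gate)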

    def release_mlp(self):
        super().release_mlp()
        gated_mlp_params = [
            (self.module.mlp.inter_up_w, self.inter_up_w),
            (self.module.mlp.inter_up_b, self.inter_up_b),
            (self.module.mlp.inter_gate_w, self.inter_gate_w),
            (self.module.mlp.inter_gate_b, self.inter_gate_b),
        ]
        self._release_params(gated_mlp_params)

    def reset_mlp(self):
        # Copy the unfused up/gate tensors back into the fused buffer, which
        # is laid out as [up; gate] along dim 0.
        self._h4h_w.data[:self.inter_up_w.shape[0]] = self.inter_up_w.data
        self._h4h_w.data[self.inter_up_w.shape[0]:] = self.inter_gate_w.data
        if self.inter_up_b is not None:
            self._h4h_b.data[:self.inter_up_b.shape[0]] = self.inter_up_b.data
            self._h4h_b.data[self.inter_up_b.shape[0]:] = self.inter_gate_b.data

        # Hold references to the old storage so it can be freed once the
        # unfused parameters have been re-aliased as views into the fused buffer.
        inter_data = [self.inter_up_w.data, self.inter_gate_w.data]
        if self.inter_up_b is not None:
            inter_data.extend([self.inter_up_b.data, self.inter_gate_b.data])

        self.inter_up_w.data = self._h4h_w.data[:self.inter_up_w.shape[0]]
        self.inter_gate_w.data = self._h4h_w.data[self.inter_up_w.shape[0]:]
        if self.inter_up_b is not None:
            self.inter_up_b.data = self._h4h_b.data[:self.inter_up_b.shape[0]]
            self.inter_gate_b.data = self._h4h_b.data[self.inter_up_b.shape[0]:]

        for data in inter_data:
            del data

    def set_mlp_params_wo_copy(self, Z3_enabled=False):
        self.module.mlp.output_w = self._4hh_w
        self.module.mlp.output_b = self._4hh_b

        if not Z3_enabled:
            # In initialize_tensors, we create a fused inter projection with the appropriate shape
            # and copy the up projection and gate projection into it
            self.module.mlp.inter_w = self._h4h_w
            self.module.mlp.inter_b = self._h4h_b

            # Re-alias the unfused parameters as views into the fused tensors
            # so both refer to the same storage.
            self.inter_up_w.data = self._h4h_w[:self.inter_up_w.shape[0], :]
            self.inter_gate_w.data = self._h4h_w[self.inter_up_w.shape[0]:, :]
            if self.inter_up_b is not None:
                self.inter_up_b.data = self._h4h_b[:self.inter_up_w.shape[0]] if self._h4h_b is not None else None
                self.inter_gate_b.data = self._h4h_b[self.inter_up_w.shape[0]:] if self._h4h_b is not None else None
        else:
            # With ZeRO-3 enabled, the fused copy is not used; hand the
            # unfused parameters to the module directly.
            self.module.mlp.inter_up_w = self.inter_up_w
            self.module.mlp.inter_up_b = self.inter_up_b
            self.module.mlp.inter_gate_w = self.inter_gate_w
            self.module.mlp.inter_gate_b = self.inter_gate_b

    def get_mlp_params(self):
        params = super().get_mlp_params()
        params.extend([self.inter_up_w, self.inter_up_b, self.inter_gate_w, self.inter_gate_b])
        return params
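
# Expected lifecycle, sketched under the assumption that this feature is mixed
# into a concrete HybridEngine model container (the method names below come
# from this file; the surrounding flow is illustrative):
#
#     container.set_mlp(_h4h_w, _h4h_b, _4hh_w, _4hh_b)  # also calls set_mlp_gate()
#     container.mlp_inter_mp(mp_replace)   # shard fused or unfused inter tensors
#     container.release_mlp()              # release inference copies
#     container.reset_mlp()                # restore the fused [up; gate] layout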