clip.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
import torch
from torch.nn.parameter import Parameter
from ..policy import TransformerPolicy


class DS_CLIPContainer(BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model-specific things should be defined here instead of the base class.

    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        return self.module


class HFCLIPLayerPolicy(TransformerPolicy):

    def __init__(self, client_module, inference=False):
        super().__init__(inference, pre_attn_norm=True, scale_attention=True)
        self.client_module = client_module
        self.cuda_graph_supported = True

        if HFCLIPLayerPolicy._orig_layer_class is None:
            try:
                import transformers
                HFCLIPLayerPolicy._orig_layer_class = transformers.models.clip.modeling_clip.CLIPEncoderLayer
            except:
                # transformers is missing or too old to ship CLIP; leave the class unset.
                HFCLIPLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        # Returns hidden size, number of attention heads, layernorm eps, and intermediate size.
        return self.client_module.self_attn.q_proj.weight.shape[1], \
                self.client_module.self_attn.num_heads, \
                self.client_module.layer_norm1.eps, \
                DEFAULT_INTERMEDIATE_SIZE

    def attention(self, enable_training=False):
        qw = self.client_module.self_attn.q_proj.weight
        qb = self.client_module.self_attn.q_proj.bias
        kw = self.client_module.self_attn.k_proj.weight
        kb = self.client_module.self_attn.k_proj.bias
        vw = self.client_module.self_attn.v_proj.weight
        vb = self.client_module.self_attn.v_proj.bias

        # Fuse the separate q/k/v projections into a single QKV weight and bias.
        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training)

        return qkvw, \
               qkvb, \
               self.client_module.self_attn.out_proj.weight, \
               self.client_module.self_attn.out_proj.bias

    def mlp(self, enable_training=False):
        return self.client_module.mlp.fc1.weight, \
               self.client_module.mlp.fc1.bias, \
               self.client_module.mlp.fc2.weight, \
               self.client_module.mlp.fc2.bias

    def layernorm(self):
        # layer_norm2 (applied before the MLP) first, then layer_norm1 (applied before attention).
        return self.client_module.layer_norm2.weight, \
               self.client_module.layer_norm2.bias, \
               self.client_module.layer_norm1.weight, \
               self.client_module.layer_norm1.bias
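
For context, the snippet below is a minimal, self-contained sketch of the QKV fusion that HFCLIPLayerPolicy.attention() performs. It uses randomly initialized torch.nn.Linear layers as stand-ins for a real CLIPEncoderLayer's q/k/v projections; the sizes and variable names are illustrative assumptions, not values taken from clip.py.

# Hedged sketch: not part of clip.py; dimensions below are hypothetical.
import torch
from torch import nn
from torch.nn.parameter import Parameter

hidden_size = 768  # illustrative only; a real CLIP layer defines its own size

# Stand-ins for self_attn.{q,k,v}_proj of a CLIPEncoderLayer.
q_proj = nn.Linear(hidden_size, hidden_size)
k_proj = nn.Linear(hidden_size, hidden_size)
v_proj = nn.Linear(hidden_size, hidden_size)

# Concatenate the three projections along dim=0 into a single
# [3 * hidden_size, hidden_size] weight and [3 * hidden_size] bias,
# mirroring the qkvw/qkvb tensors that attention() returns.
qkvw = Parameter(torch.cat((q_proj.weight, k_proj.weight, v_proj.weight), dim=0),
                 requires_grad=False)
qkvb = Parameter(torch.cat((q_proj.bias, k_proj.bias, v_proj.bias), dim=0),
                 requires_grad=False)

assert qkvw.shape == (3 * hidden_size, hidden_size)
assert qkvb.shape == (3 * hidden_size,)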