clip.py

'''Copyright The Microsoft DeepSpeed Team'''

from .base import *
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
import torch
from torch.nn.parameter import Parameter
from ..policy import TransformerPolicy


class DS_CLIPContainer(BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.

    def create_module(self, config=None):
        # CLIP layers reuse the GPT-style DeepSpeed inference module.
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        return self.module


class HFCLIPLayerPolicy(TransformerPolicy):

    def __init__(self, client_module, inference=False):
        super().__init__(inference, pre_attn_norm=True, scale_attention=True)
        self.client_module = client_module
        self.cuda_graph_supported = True

        if HFCLIPLayerPolicy._orig_layer_class is None:
            try:
                import transformers
                HFCLIPLayerPolicy._orig_layer_class = transformers.models.clip.modeling_clip.CLIPEncoderLayer
            except:
                HFCLIPLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        # Hidden size (q_proj input features) and number of attention heads.
        return self.client_module.self_attn.q_proj.weight.shape[1], \
               self.client_module.self_attn.num_heads

    def attention(self):
        qw = self.client_module.self_attn.q_proj.weight
        qb = self.client_module.self_attn.q_proj.bias
        kw = self.client_module.self_attn.k_proj.weight
        kb = self.client_module.self_attn.k_proj.bias
        vw = self.client_module.self_attn.v_proj.weight
        vb = self.client_module.self_attn.v_proj.bias

        # Fuse the separate q/k/v projections into single QKV weight/bias tensors.
        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=False)

        return qkvw, \
               qkvb, \
               self.client_module.self_attn.out_proj.weight, \
               self.client_module.self_attn.out_proj.bias

    def mlp(self):
        return self.client_module.mlp.fc1.weight, \
               self.client_module.mlp.fc1.bias, \
               self.client_module.mlp.fc2.weight, \
               self.client_module.mlp.fc2.bias

    def layernorm(self):
        # Returned in the order the base container expects: the MLP-side norm
        # (layer_norm2) first, then the pre-attention/input norm (layer_norm1).
        return self.client_module.layer_norm2.weight, \
               self.client_module.layer_norm2.bias, \
               self.client_module.layer_norm1.weight, \
               self.client_module.layer_norm1.bias
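
As a quick sanity check (not part of the original file), the policy can be exercised on its own against a stock Hugging Face CLIP encoder layer to confirm the fused QKV shapes it hands back. This is a minimal sketch: the import path for HFCLIPLayerPolicy assumes the file lives at deepspeed.module_inject.containers.clip, and the config values are just the transformers defaults.

# Hedged usage sketch: run HFCLIPLayerPolicy against a plain Hugging Face
# CLIPEncoderLayer and check the fused QKV shapes.
import torch
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPEncoderLayer

# Assumed location of this file inside DeepSpeed.
from deepspeed.module_inject.containers.clip import HFCLIPLayerPolicy

config = CLIPVisionConfig()            # defaults: hidden_size=768, 12 heads
layer = CLIPEncoderLayer(config)
policy = HFCLIPLayerPolicy(layer, inference=True)

hidden_size, num_heads = policy.get_hidden_heads()
qkvw, qkvb, ow, ob = policy.attention()

assert hidden_size == config.hidden_size
assert num_heads == config.num_attention_heads
assert qkvw.shape == (3 * hidden_size, hidden_size)   # q, k, v stacked along dim 0
assert qkvb.shape == (3 * hidden_size,)
assert ow.shape == (hidden_size, hidden_size)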