unet.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. '''
  2. Copyright 2022 The Microsoft DeepSpeed Team
  3. '''
  4. import torch
  5. from torch.nn.parameter import Parameter
  6. from ..policy import DSPolicy
  7. from ...model_implementations.diffusers.unet import DSUNet
  8. class UNetPolicy(DSPolicy):
  9. def __init__(self):
  10. super().__init__()
  11. try:
  12. import diffusers
  13. self._orig_layer_class = diffusers.models.unet_2d_condition.UNet2DConditionModel
  14. except ImportError:
  15. self._orig_layer_class = None
  16. def match(self, module):
  17. return isinstance(module, self._orig_layer_class)
  18. def match_replaced(self, module):
  19. return isinstance(module, DSUNet)
  20. def apply(self, module, enable_cuda_graph=True):
  21. # TODO(cmikeh2): Enable cuda graph should be an inference configuration
  22. return DSUNet(module, enable_cuda_graph=enable_cuda_graph)
  23. def attention(self, client_module):
  24. qw = client_module.to_q.weight
  25. kw = client_module.to_k.weight
  26. vw = client_module.to_v.weight
  27. if qw.shape[1] == kw.shape[1]:
  28. qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
  29. return qkvw, \
  30. client_module.to_out[0].weight, \
  31. client_module.to_out[0].bias, \
  32. qw.shape[-1], \
  33. client_module.heads
  34. else:
  35. #return None
  36. #kvw = Parameter(torch.cat((kw, vw), dim=0), requires_grad=False)
  37. return qw, \
  38. kw, vw, \
  39. client_module.to_out[0].weight, \
  40. client_module.to_out[0].bias, \
  41. qw.shape[-1], \
  42. client_module.heads