vae.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from ..policy import DSPolicy
  5. from ...model_implementations.diffusers.vae import DSVAE
  6. class VAEPolicy(DSPolicy):
  7. def __init__(self):
  8. super().__init__()
  9. try:
  10. import diffusers
  11. if hasattr(diffusers.models.vae, "AutoencoderKL"):
  12. self._orig_layer_class = diffusers.models.vae.AutoencoderKL
  13. else:
  14. # Diffusers >= 0.12.0 changes location of AutoencoderKL
  15. self._orig_layer_class = diffusers.models.autoencoder_kl.AutoencoderKL
  16. except ImportError:
  17. self._orig_layer_class = None
  18. def match(self, module):
  19. return isinstance(module, self._orig_layer_class)
  20. def match_replaced(self, module):
  21. return isinstance(module, DSVAE)
  22. def apply(self, module, enable_cuda_graph=True):
  23. # TODO(cmikeh2): Enable cuda graph should be an inference configuration
  24. return DSVAE(module, enable_cuda_graph=enable_cuda_graph)
  25. # NOTE (lekurile): Should we have a diffusers policy class?
  26. def attention(self, client_module):
  27. pass