test_stable_diffusion.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import os
  5. import torch
  6. import pytest
  7. import deepspeed
  8. import numpy
  9. from unit.common import DistributedTest
  10. from deepspeed.accelerator import get_accelerator
  11. # Setup for these models is different from other pipelines, so we add a separate test
  @pytest.mark.stable_diffusion
  class TestStableDiffusion(DistributedTest):
      """Check that a DeepSpeed-wrapped diffusion pipeline reproduces the baseline image."""

      # Class-level setting consumed by the DistributedTest base class (defined outside this file).
      world_size = 1

      def test(self):
          """Render one prompt with the plain and the DeepSpeed-wrapped pipeline and compare.

          Both renders reuse the same generator, re-seeded with the same value
          before each run; the test passes when the RMSE between the two
          resulting images is at most 0.01.
          """
          # Imported at call time rather than at module import time.
          from diffusers import DiffusionPipeline
          from image_similarity_measures.quality_metrics import rmse

          accelerator_name = get_accelerator().device_name()
          rng = torch.Generator(device=accelerator_name)
          rng_seed = 0xABEDABE7
          rng.manual_seed(rng_seed)

          text_prompt = "a dog on a rocket"
          model_id = "prompthero/midjourney-v4-diffusion"

          def render(pipeline):
              # Single image for the fixed prompt, drawing randomness from the shared generator.
              return pipeline(text_prompt, guidance_scale=7.5, generator=rng).images[0]

          rank = int(os.getenv("LOCAL_RANK", "0"))
          target_device = torch.device(f"{accelerator_name}:{rank}")

          # Baseline: half-precision pipeline moved onto the rank-local device.
          sd_pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.half)
          sd_pipeline = sd_pipeline.to(target_device)
          reference_image = render(sd_pipeline)

          # Same pipeline, now wrapped by DeepSpeed inference with kernel injection and CUDA graph enabled.
          inference_options = dict(
              mp_size=1,
              dtype=torch.half,
              replace_with_kernel_inject=True,
              enable_cuda_graph=True,
          )
          sd_pipeline = deepspeed.init_inference(sd_pipeline, **inference_options)

          # Re-seed so the second render starts from the same generator state as the first.
          rng.manual_seed(rng_seed)
          accelerated_image = render(sd_pipeline)

          pixel_error = rmse(
              org_img=numpy.asarray(reference_image),
              pred_img=numpy.asarray(accelerated_image),
          )

          # RMSE threshold value is arbitrary, may need to adjust as needed
          assert pixel_error <= 0.01