__init__.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. import numpy as np
  2. from .common import CommonInpainter, OfflineInpainter
  3. from .inpainting_aot import AotInpainter
  4. from .inpainting_lama_mpe import LamaMPEInpainter, LamaLargeInpainter
  5. from .inpainting_sd import StableDiffusionInpainter
  6. from .none import NoneInpainter
  7. from .original import OriginalInpainter
  8. INPAINTERS = {
  9. 'default': AotInpainter,
  10. 'lama_large': LamaLargeInpainter,
  11. 'lama_mpe': LamaMPEInpainter,
  12. 'sd': StableDiffusionInpainter,
  13. 'none': NoneInpainter,
  14. 'original': OriginalInpainter,
  15. }
  16. inpainter_cache = {}
  17. def get_inpainter(key: str, *args, **kwargs) -> CommonInpainter:
  18. if key not in INPAINTERS:
  19. raise ValueError(f'Could not find inpainter for: "{key}". Choose from the following: %s' % ','.join(INPAINTERS))
  20. if not inpainter_cache.get(key):
  21. inpainter = INPAINTERS[key]
  22. inpainter_cache[key] = inpainter(*args, **kwargs)
  23. return inpainter_cache[key]
  24. async def prepare(inpainter_key: str, device: str = 'cpu'):
  25. inpainter = get_inpainter(inpainter_key)
  26. if isinstance(inpainter, OfflineInpainter):
  27. await inpainter.download()
  28. await inpainter.load(device)
  29. async def dispatch(inpainter_key: str, image: np.ndarray, mask: np.ndarray, inpainting_size: int = 1024, device: str = 'cpu', verbose: bool = False) -> np.ndarray:
  30. inpainter = get_inpainter(inpainter_key)
  31. if isinstance(inpainter, OfflineInpainter):
  32. await inpainter.load(device)
  33. return await inpainter.inpaint(image, mask, inpainting_size, verbose)