1234567891011121314151617181920212223242526272829303132333435363738 |
- import numpy as np
- from .common import CommonInpainter, OfflineInpainter
- from .inpainting_aot import AotInpainter
- from .inpainting_lama_mpe import LamaMPEInpainter, LamaLargeInpainter
- from .inpainting_sd import StableDiffusionInpainter
- from .none import NoneInpainter
- from .original import OriginalInpainter
# Registry of the selectable inpainter classes, keyed by the name callers pass
# to get_inpainter(). The insertion order is also the order in which the
# choices are listed in get_inpainter()'s error message.
INPAINTERS = {
    'default': AotInpainter,
    'lama_large': LamaLargeInpainter,
    'lama_mpe': LamaMPEInpainter,
    'sd': StableDiffusionInpainter,
    'none': NoneInpainter,
    'original': OriginalInpainter,
}

# Instances already constructed by get_inpainter(), keyed by the same names as
# INPAINTERS, so each inpainter is only built once per process.
inpainter_cache = {}
def get_inpainter(key: str, *args, **kwargs) -> CommonInpainter:
    """Return the shared inpainter instance registered under ``key``.

    The instance is constructed on first request and stored in
    ``inpainter_cache``; later calls with the same key return the cached
    object.

    Args:
        key: Name of the inpainter; must be a key of ``INPAINTERS``.
        *args: Positional arguments forwarded to the inpainter constructor.
        **kwargs: Keyword arguments forwarded to the inpainter constructor.

    Note:
        ``*args`` and ``**kwargs`` are only used when the instance is first
        created; they are ignored once an instance for ``key`` is cached.

    Returns:
        The cached inpainter instance for ``key``.

    Raises:
        ValueError: If ``key`` is not a registered inpainter name.
    """
    if key not in INPAINTERS:
        # Build the message with a single f-string. Applying the % operator to
        # an f-string that already contains the user-supplied key would
        # re-interpret any '%' in the key as a format directive and raise an
        # unrelated error instead of this one.
        choices = ','.join(INPAINTERS)
        raise ValueError(f'Could not find inpainter for: "{key}". Choose from the following: {choices}')

    # Compare against None instead of relying on truthiness, so an instance
    # that happens to evaluate as falsy is not rebuilt on every call.
    if inpainter_cache.get(key) is None:
        inpainter_cls = INPAINTERS[key]
        inpainter_cache[key] = inpainter_cls(*args, **kwargs)
    return inpainter_cache[key]
async def prepare(inpainter_key: str, device: str = 'cpu'):
    """Get the inpainter for ``inpainter_key`` ready ahead of time.

    For inpainters that are ``OfflineInpainter`` instances this awaits
    ``download()`` followed by ``load(device)``. Any other inpainter needs no
    preparation and the coroutine returns immediately.

    Args:
        inpainter_key: Name of the inpainter, as accepted by ``get_inpainter``.
        device: Device identifier passed through to ``load``.

    Raises:
        ValueError: Propagated from ``get_inpainter`` for an unknown key.
    """
    instance = get_inpainter(inpainter_key)
    if not isinstance(instance, OfflineInpainter):
        return
    await instance.download()
    await instance.load(device)
async def dispatch(
    inpainter_key: str,
    image: np.ndarray,
    mask: np.ndarray,
    inpainting_size: int = 1024,
    device: str = 'cpu',
    verbose: bool = False,
) -> np.ndarray:
    """Run the inpainter selected by ``inpainter_key`` on ``image``.

    If the selected inpainter is an ``OfflineInpainter`` it is loaded onto
    ``device`` first; the remaining arguments are then handed to its
    ``inpaint`` coroutine unchanged.

    Args:
        inpainter_key: Name of the inpainter, as accepted by ``get_inpainter``.
        image: Image array passed to ``inpaint``.
        mask: Mask array passed to ``inpaint``.
        inpainting_size: Size value passed to ``inpaint``.
        device: Device identifier passed to ``load`` for offline inpainters.
        verbose: Verbosity flag passed to ``inpaint``.

    Returns:
        Whatever the inpainter's ``inpaint`` coroutine returns.

    Raises:
        ValueError: Propagated from ``get_inpainter`` for an unknown key.
    """
    instance = get_inpainter(inpainter_key)
    needs_loading = isinstance(instance, OfflineInpainter)
    if needs_loading:
        await instance.load(device)
    result = await instance.inpaint(image, mask, inpainting_size, verbose)
    return result
|