import numpy as np from .common import CommonInpainter, OfflineInpainter from .inpainting_aot import AotInpainter from .inpainting_lama_mpe import LamaMPEInpainter, LamaLargeInpainter from .inpainting_sd import StableDiffusionInpainter from .none import NoneInpainter from .original import OriginalInpainter INPAINTERS = { 'default': AotInpainter, 'lama_large': LamaLargeInpainter, 'lama_mpe': LamaMPEInpainter, 'sd': StableDiffusionInpainter, 'none': NoneInpainter, 'original': OriginalInpainter, } inpainter_cache = {} def get_inpainter(key: str, *args, **kwargs) -> CommonInpainter: if key not in INPAINTERS: raise ValueError(f'Could not find inpainter for: "{key}". Choose from the following: %s' % ','.join(INPAINTERS)) if not inpainter_cache.get(key): inpainter = INPAINTERS[key] inpainter_cache[key] = inpainter(*args, **kwargs) return inpainter_cache[key] async def prepare(inpainter_key: str, device: str = 'cpu'): inpainter = get_inpainter(inpainter_key) if isinstance(inpainter, OfflineInpainter): await inpainter.download() await inpainter.load(device) async def dispatch(inpainter_key: str, image: np.ndarray, mask: np.ndarray, inpainting_size: int = 1024, device: str = 'cpu', verbose: bool = False) -> np.ndarray: inpainter = get_inpainter(inpainter_key) if isinstance(inpainter, OfflineInpainter): await inpainter.load(device) return await inpainter.inpaint(image, mask, inpainting_size, verbose)