1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- import os
- import argparse
- import cv2
- import numpy as np
- import torch
- from torch.utils.data import Dataset, DataLoader
- from tqdm.asyncio import tqdm
- from data_loader import create_training_datasets
- from train import AnimeSegmentation, net_names
- from inference import get_mask
- import warnings
- # warnings.filterwarnings("ignore")
def _compose_preview(image, mask, label):
    """Return a uint8 strip showing ``image | predicted mask | label``.

    The three panels are concatenated along the width axis (axis 1), so they
    must share the same height.

    Args:
        image: (H, W, 3) float array already scaled to [0, 255].
        mask: predicted mask of shape (H, W, 1) or (H, W). It is multiplied
            by 255 here, so it is assumed to be in [0, 1] — TODO confirm
            against ``get_mask``.
        label: ground-truth mask of shape (H, W, 1) or (H, W), already
            scaled to [0, 255].

    Returns:
        (H, 3 * W, 3) uint8 array in the same channel order as ``image``.
    """
    # Single-channel masks are tiled to 3 channels so they can sit next to the
    # colour image; atleast_3d additionally accepts masks with no channel axis.
    mask_rgb = np.atleast_3d(mask).repeat(3, axis=2) * 255
    label_rgb = np.atleast_3d(label).repeat(3, axis=2)
    strip = np.concatenate((image, mask_rgb, label_rgb), axis=1)
    # NumPy's float -> uint8 cast does not saturate, so clip first to keep any
    # value slightly outside [0, 255] from wrapping around.
    return np.clip(strip, 0, 255).astype(np.uint8)


def main(opt):
    """Run the model over the training set and dump side-by-side previews.

    For every sample, writes ``<opt.out>/<index:06d>.jpg`` containing the
    input image, the predicted mask and the ground-truth label, left to right.

    Args:
        opt: parsed command-line namespace (see the ``argparse`` setup under
            the ``__main__`` guard) providing the dataset paths/extensions,
            ``net``/``ckpt`` model selection, ``img_size``, ``device``,
            ``fp32`` and the output directory ``out``.

    Raises:
        OSError: if a preview image cannot be written.
    """
    # Only the first dataset returned is visualised; the second is discarded.
    train_dataset, _ = create_training_datasets(
        opt.data_dir, opt.fg_dir, opt.bg_dir, opt.img_dir, opt.mask_dir,
        opt.fg_ext, opt.bg_ext, opt.img_ext, opt.mask_ext,
        1, opt.img_size)
    salobj_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2)

    device = torch.device(opt.device)
    model = AnimeSegmentation.try_load(opt.net, opt.ckpt, img_size=opt.img_size)
    model.eval()
    model.to(device)

    # makedirs(exist_ok=True) also creates missing parent directories and does
    # not fail if the directory appears between a check and the creation.
    os.makedirs(opt.out, exist_ok=True)

    for i, data in enumerate(tqdm(salobj_dataloader)):
        # batch_size=1, so index 0 is the only sample. permute(1, 2, 0) turns
        # the (assumed CHW) tensors into HWC arrays; * 255 assumes [0, 1] input.
        image = data["image"][0].permute(1, 2, 0).numpy() * 255
        label = data["label"][0].permute(1, 2, 0).numpy() * 255
        mask = get_mask(model, image, use_amp=not opt.fp32, s=opt.img_size)

        # OpenCV writes BGR, the composed strip is RGB.
        preview = cv2.cvtColor(_compose_preview(image, mask, label), cv2.COLOR_RGB2BGR)
        out_path = os.path.join(opt.out, f"{i:06d}.jpg")
        # cv2.imwrite signals failure through its return value instead of
        # raising, so check it explicitly rather than silently losing output.
        if not cv2.imwrite(out_path, preview):
            raise OSError(f"failed to write preview image: {out_path}")
if __name__ == "__main__":
    # Command-line entry point: parse options, echo them, and run the
    # preview-dumping loop in main().
    parser = argparse.ArgumentParser()

    # model args
    parser.add_argument('--net', type=str, default='isnet_is',
                        choices=net_names,
                        help='net name')
    # This script only loads the checkpoint for inference; the old help text
    # ("resume training from ckpt") described the training script instead.
    parser.add_argument('--ckpt', type=str, default='saved_models/isnetis.ckpt',
                        help='model checkpoint to load')
    parser.add_argument('--out', type=str, default='out',
                        help='output dir')
    parser.add_argument('--img-size', type=int, default=1024,
                        help='input image size')

    # dataset args
    parser.add_argument('--data-dir', type=str, default='../../dataset/anime-seg',
                        help='root dir of dataset')
    parser.add_argument('--fg-dir', type=str, default='fg',
                        help='relative dir of foreground')
    parser.add_argument('--bg-dir', type=str, default='bg',
                        help='relative dir of background')
    parser.add_argument('--img-dir', type=str, default='imgs',
                        help='relative dir of images')
    parser.add_argument('--mask-dir', type=str, default='masks',
                        help='relative dir of masks')
    parser.add_argument('--fg-ext', type=str, default='.png',
                        help='extension name of foreground')
    parser.add_argument('--bg-ext', type=str, default='.jpg',
                        help='extension name of background')
    parser.add_argument('--img-ext', type=str, default='.jpg',
                        help='extension name of images')
    parser.add_argument('--mask-ext', type=str, default='.jpg',
                        help='extension name of masks')

    # runtime args
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='cpu or cuda:0')
    # main() passes use_amp=not opt.fp32 to get_mask.
    parser.add_argument('--fp32', action='store_true', default=False,
                        help='disable mixed precision')

    opt = parser.parse_args()
    print(opt)
    main(opt)
|