test.py

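"""Qualitative test script: runs an AnimeSegmentation checkpoint over the
training images and writes side-by-side panels (input | predicted mask |
ground-truth mask) to the output directory."""
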
import os
import argparse
import warnings

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_loader import create_training_datasets
from train import AnimeSegmentation, net_names
from inference import get_mask

# warnings.filterwarnings("ignore")


def main(opt):
    # Build the datasets; the second returned split is not needed here.
    train_dataset, _ = create_training_datasets(opt.data_dir, opt.fg_dir, opt.bg_dir, opt.img_dir,
                                                opt.mask_dir, opt.fg_ext, opt.bg_ext, opt.img_ext,
                                                opt.mask_ext, 1, opt.img_size)
    salobj_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2)
    device = torch.device(opt.device)
    model = AnimeSegmentation.try_load(opt.net, opt.ckpt, img_size=opt.img_size)
    model.eval()
    model.to(device)
    os.makedirs(opt.out, exist_ok=True)
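    # Run inference on every sample and save one comparison panel per image.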
    for i, data in enumerate(tqdm(salobj_dataloader)):
        image, label = data["image"][0], data["label"][0]
        # CHW float tensors in [0, 1] -> HWC numpy arrays in [0, 255]
        image = image.permute(1, 2, 0).numpy() * 255
        label = label.permute(1, 2, 0).numpy() * 255
        mask = get_mask(model, image, use_amp=not opt.fp32, s=opt.img_size)
        # Tile the single-channel mask and label to 3 channels, then stack the
        # panels horizontally: input | predicted mask | ground truth.
        image = np.concatenate((image, mask.repeat(3, 2) * 255, label.repeat(3, 2)), axis=1).astype(np.uint8)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        cv2.imwrite(f'{opt.out}/{i:06d}.jpg', image)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # model args
    parser.add_argument('--net', type=str, default='isnet_is',
                        choices=net_names,
                        help='net name')
    parser.add_argument('--ckpt', type=str, default='saved_models/isnetis.ckpt',
                        help='model checkpoint path')
    parser.add_argument('--out', type=str, default='out',
                        help='output dir')
    parser.add_argument('--img-size', type=int, default=1024,
                        help='input image size')
    # dataset args
    parser.add_argument('--data-dir', type=str, default='../../dataset/anime-seg',
                        help='root dir of dataset')
    parser.add_argument('--fg-dir', type=str, default='fg',
                        help='relative dir of foregrounds')
    parser.add_argument('--bg-dir', type=str, default='bg',
                        help='relative dir of backgrounds')
    parser.add_argument('--img-dir', type=str, default='imgs',
                        help='relative dir of images')
    parser.add_argument('--mask-dir', type=str, default='masks',
                        help='relative dir of masks')
    parser.add_argument('--fg-ext', type=str, default='.png',
                        help='file extension of foregrounds')
    parser.add_argument('--bg-ext', type=str, default='.jpg',
                        help='file extension of backgrounds')
    parser.add_argument('--img-ext', type=str, default='.jpg',
                        help='file extension of images')
    parser.add_argument('--mask-ext', type=str, default='.jpg',
                        help='file extension of masks')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='device to run on, e.g. cpu or cuda:0')
    parser.add_argument('--fp32', action='store_true', default=False,
                        help='disable mixed precision')

    opt = parser.parse_args()
    print(opt)
    main(opt)
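
# Example invocation (illustrative; paths depend on your checkout and dataset layout):
#   python test.py --net isnet_is --ckpt saved_models/isnetis.ckpt \
#       --data-dir ../../dataset/anime-seg --out out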