dataloader_test.py 1.2 KB

12345678910111213141516171819202122232425262728293031
  1. import glob
  2. import os
  3. import time
  4. import cv2
  5. from torch.utils.data import DataLoader
  6. from data_loader import create_training_datasets
  7. import numpy as np
  8. if __name__ == '__main__':
  9. data_dir = '../../dataset/anime-seg/'
  10. tra_fg_dir = 'fg/'
  11. tra_bg_dir = 'bg/'
  12. tra_img_dir = 'imgs/'
  13. tra_mask_dir = 'masks/'
  14. fg_ext = '.png'
  15. bg_ext = '.*'
  16. img_ext = '.jpg'
  17. mask_ext = '.jpg'
  18. train_dataset, val_dataset = create_training_datasets(data_dir, tra_fg_dir, tra_bg_dir, tra_img_dir, tra_mask_dir,
  19. fg_ext, bg_ext, img_ext, mask_ext, 0.95, 640, True)
  20. salobj_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2, persistent_workers=True)
  21. for data in salobj_dataloader:
  22. cv2.imshow("a", np.concatenate([data['image'][0].permute(1, 2, 0).numpy()[:, :, ::-1],
  23. cv2.cvtColor(data['label'][0].permute(1, 2, 0).numpy(), cv2.COLOR_GRAY2RGB),
  24. cv2.cvtColor(data['trimap'][0].permute(1, 2, 0).numpy(), cv2.COLOR_GRAY2RGB)],
  25. axis=1))
  26. cv2.waitKey(1000)