# reader.py
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description:
"""
import json
import os

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import BertTokenizerFast
  12. class DataCollator:
  13. def __init__(self, tokenizer: BertTokenizerFast):
  14. self.tokenizer = tokenizer
  15. def __call__(self, data):
  16. ori_texts, cor_texts, wrong_idss = zip(*data)
  17. encoded_texts = [self.tokenizer(t, return_offsets_mapping=True, add_special_tokens=False) for t in ori_texts]
  18. max_len = max([len(t['input_ids']) for t in encoded_texts]) + 2
  19. det_labels = torch.zeros(len(ori_texts), max_len).long()
  20. for i, (encoded_text, wrong_ids) in enumerate(zip(encoded_texts, wrong_idss)):
  21. off_mapping = encoded_text['offset_mapping']
  22. for idx in wrong_ids:
  23. for j, (b, e) in enumerate(off_mapping):
  24. if b <= idx < e:
  25. # j+1是因为前面的 CLS token
  26. det_labels[i, j + 1] = 1
  27. break
  28. return list(ori_texts), list(cor_texts), det_labels
  29. class CscDataset(Dataset):
  30. def __init__(self, file_path):
  31. self.data = json.load(open(file_path, 'r', encoding='utf-8'))
  32. def __len__(self):
  33. return len(self.data)
  34. def __getitem__(self, index):
  35. return self.data[index]['original_text'], self.data[index]['correct_text'], self.data[index]['wrong_ids']
  36. def make_loaders(collate_fn, train_path='', valid_path='', test_path='',
  37. batch_size=32, num_workers=4):
  38. train_loader = None
  39. if train_path and os.path.exists(train_path):
  40. train_loader = DataLoader(
  41. CscDataset(train_path),
  42. batch_size=batch_size,
  43. shuffle=False,
  44. num_workers=num_workers,
  45. collate_fn=collate_fn
  46. )
  47. valid_loader = None
  48. if valid_path and os.path.exists(valid_path):
  49. valid_loader = DataLoader(
  50. CscDataset(valid_path),
  51. batch_size=batch_size,
  52. num_workers=num_workers,
  53. collate_fn=collate_fn
  54. )
  55. test_loader = None
  56. if test_path and os.path.exists(test_path):
  57. test_loader = DataLoader(
  58. CscDataset(test_path),
  59. batch_size=batch_size,
  60. num_workers=num_workers,
  61. collate_fn=collate_fn
  62. )
  63. return train_loader, valid_loader, test_loader