# multiprocess_utils.py
import os
import traceback
from collections import deque
from functools import partial

from tqdm import tqdm
  5. def chunked_worker(worker_id, args_queue=None, results_queue=None, init_ctx_func=None):
  6. ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None
  7. while True:
  8. args = args_queue.get()
  9. if args == '<KILL>':
  10. return
  11. job_idx, map_func, arg = args
  12. try:
  13. map_func_ = partial(map_func, ctx=ctx) if ctx is not None else map_func
  14. if isinstance(arg, dict):
  15. res = map_func_(**arg)
  16. elif isinstance(arg, (list, tuple)):
  17. res = map_func_(*arg)
  18. else:
  19. res = map_func_(arg)
  20. results_queue.put((job_idx, res))
  21. except:
  22. traceback.print_exc()
  23. results_queue.put((job_idx, None))
  24. class MultiprocessManager:
  25. def __init__(self, num_workers=None, init_ctx_func=None, multithread=False, queue_max=-1):
  26. if multithread:
  27. from multiprocessing.dummy import Queue, Process
  28. else:
  29. from multiprocessing import Queue, Process
  30. if num_workers is None:
  31. num_workers = int(os.getenv('N_PROC', os.cpu_count()))
  32. self.num_workers = num_workers
  33. self.results_queue = Queue(maxsize=-1)
  34. self.jobs_pending = []
  35. self.args_queue = Queue(maxsize=queue_max)
  36. self.workers = []
  37. self.total_jobs = 0
  38. self.multithread = multithread
  39. for i in range(num_workers):
  40. if multithread:
  41. p = Process(target=chunked_worker,
  42. args=(i, self.args_queue, self.results_queue, init_ctx_func))
  43. else:
  44. p = Process(target=chunked_worker,
  45. args=(i, self.args_queue, self.results_queue, init_ctx_func),
  46. daemon=True)
  47. self.workers.append(p)
  48. p.start()
  49. def add_job(self, func, args):
  50. if not self.args_queue.full():
  51. self.args_queue.put((self.total_jobs, func, args))
  52. else:
  53. self.jobs_pending.append((self.total_jobs, func, args))
  54. self.total_jobs += 1
  55. def get_results(self):
  56. self.n_finished = 0
  57. while self.n_finished < self.total_jobs:
  58. while len(self.jobs_pending) > 0 and not self.args_queue.full():
  59. self.args_queue.put(self.jobs_pending[0])
  60. self.jobs_pending = self.jobs_pending[1:]
  61. job_id, res = self.results_queue.get()
  62. yield job_id, res
  63. self.n_finished += 1
  64. for w in range(self.num_workers):
  65. self.args_queue.put("<KILL>")
  66. for w in self.workers:
  67. w.join()
  68. def close(self):
  69. if not self.multithread:
  70. for w in self.workers:
  71. w.terminate()
  72. def __len__(self):
  73. return self.total_jobs
  74. def multiprocess_run_tqdm(map_func, args, num_workers=None, ordered=True, init_ctx_func=None,
  75. multithread=False, queue_max=-1, desc=None):
  76. for i, res in tqdm(
  77. multiprocess_run(map_func, args, num_workers, ordered, init_ctx_func, multithread,
  78. queue_max=queue_max),
  79. total=len(args), desc=desc):
  80. yield i, res
  81. def multiprocess_run(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, multithread=False,
  82. queue_max=-1):
  83. """
  84. Multiprocessing running chunked jobs.
  85. Examples:
  86. >>> for res in tqdm(multiprocess_run(job_func, args):
  87. >>> print(res)
  88. :param map_func:
  89. :param args:
  90. :param num_workers:
  91. :param ordered:
  92. :param init_ctx_func:
  93. :param q_max_size:
  94. :param multithread:
  95. :return:
  96. """
  97. if num_workers is None:
  98. num_workers = int(os.getenv('N_PROC', os.cpu_count()))
  99. # num_workers = 1
  100. manager = MultiprocessManager(num_workers, init_ctx_func, multithread, queue_max=queue_max)
  101. for arg in args:
  102. manager.add_job(map_func, arg)
  103. if ordered:
  104. n_jobs = len(args)
  105. results = ['<WAIT>' for _ in range(n_jobs)]
  106. i_now = 0
  107. for job_i, res in manager.get_results():
  108. results[job_i] = res
  109. while i_now < n_jobs and (not isinstance(results[i_now], str) or results[i_now] != '<WAIT>'):
  110. yield i_now, results[i_now]
  111. results[i_now] = None
  112. i_now += 1
  113. else:
  114. for job_i, res in manager.get_results():
  115. yield job_i, res
  116. manager.close()