multi_tensor_apply.py 378 B

1234567891011121314
  1. '''
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. Copyright NVIDIA/apex
  4. This file is adapted from NVIDIA/apex, commit a109f85
  5. '''
  class MultiTensorApply:
      """Callable that forwards a fixed ``chunk_size`` to a multi-tensor op.

      The chunk size is captured once at construction. Every call then invokes
      ``op(chunk_size, noop_flag_buffer, tensor_lists, *args)`` and hands back
      whatever ``op`` returns; no other work is done here.

      NOTE(review): ``op`` is presumably a fused multi-tensor kernel of the
      kind provided by NVIDIA/apex (this file is adapted from apex) -- that is
      an assumption, not enforced here; any callable accepting this argument
      order works.
      """

      def __init__(self, chunk_size):
          # Kept as a public attribute; it becomes the first argument of every op call.
          self.chunk_size = chunk_size

      def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
          """Dispatch to ``op`` with the stored chunk size prepended.

          Args:
              op: the callable to invoke.
              noop_flag_buffer: forwarded unchanged as ``op``'s second argument.
              tensor_lists: forwarded unchanged as ``op``'s third argument.
              *args: extra positional arguments appended after ``tensor_lists``.

          Returns:
              Exactly what ``op`` returns.
          """
          leading = (self.chunk_size, noop_flag_buffer, tensor_lists)
          return op(*leading, *args)