multi_tensor_apply.py 429 B

1234567891011121314151617
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. """
  5. Copyright NVIDIA/apex
  6. This file is adapted from NVIDIA/apex, commit a109f85
  7. """
  8. class MultiTensorApply(object):
  9. def __init__(self, chunk_size):
  10. self.chunk_size = chunk_size
  11. def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
  12. return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)