1234567891011121314151617 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- """
- Copyright NVIDIA/apex
- This file is adapted from NVIDIA/apex, commit a109f85
- """
class MultiTensorApply:
    """Callable adapter that binds a fixed chunk size to multi-tensor ops.

    Calling an instance forwards straight to ``op``, prepending the stored
    ``chunk_size`` to the argument list; no other work is done here.

    NOTE(review): ``op`` is presumably a fused multi-tensor kernel in the
    apex style, taking ``(chunk_size, noop_flag_buffer, tensor_lists, *extra)``.
    Confirm against the ops this adapter is used with.
    """

    def __init__(self, chunk_size):
        # Stored as-is and passed as the first argument of every op call.
        self.chunk_size = chunk_size

    def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
        """Invoke ``op`` with the bound chunk size and return its result.

        Args:
            op: Callable to invoke.
            noop_flag_buffer: Passed through unchanged as ``op``'s second
                positional argument.
            tensor_lists: Passed through unchanged as ``op``'s third
                positional argument.
            *args: Extra positional arguments, appended after
                ``tensor_lists`` in the same order.

        Returns:
            Whatever ``op`` returns.
        """
        leading = (self.chunk_size, noop_flag_buffer, tensor_lists)
        return op(*leading, *args)
|