""" Copyright 2020 The Microsoft DeepSpeed Team Licensed under the MIT license. Functionality of swapping optimizer tensors to/from (NVMe) storage devices. """ import os import torch from deepspeed.utils.logging import logger from deepspeed.ops.aio import AsyncIOBuilder from deepspeed.runtime.swap_tensor.constants import * from deepspeed.runtime.swap_tensor.utils import swap_in_tensors, swap_out_tensors, print_object, \ MIN_AIO_BYTES, AIO_ALIGNED_BYTES, get_sized_buffers, get_sized_buffer from deepspeed.runtime.swap_tensor.async_swapper import AsyncTensorSwapper from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper DEBUG_MODE = False SWAP_IN_PARAM_TIMER = 'swap_in_param' SWAP_OUT_PARAM_TIMER = 'swap_out_param' SWAP_IN_GRADIENT_TIMER = 'swap_in_gradient' class PartitionedOptimizerSwapper(OptimizerSwapper): def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_numel, device, dtype, timers): super(PartitionedOptimizerSwapper, self).__init__(swap_config, aio_config, base_folder, optimizer, largest_numel, device, dtype, timers) aio_op = AsyncIOBuilder().load() self.aio_handle = aio_op.aio_handle(aio_config[AIO_BLOCK_SIZE], aio_config[AIO_QUEUE_DEPTH], aio_config[AIO_SINGLE_SUBMIT], aio_config[AIO_OVERLAP_EVENTS], aio_config[AIO_THREAD_COUNT]) # Overlap swapping out self.gradient_swapper = AsyncTensorSwapper(aio_handle=self.aio_handle, numel_alignment=self.numel_alignment, timers=self.timers) self.print_exclude_list += [ 'aio_handle', 'gradient_swapper', 'print_exclude_list' ] if torch.distributed.get_rank() == 0: print_object(obj=self, name='PartitionedOptimizerSwapper', exclude_list=self.print_exclude_list) def initialize_parameters(self, parameters, src_tensors): self._initialize_parameters(parameters=parameters, src_tensors=src_tensors, aio_handle=self.aio_handle) def initialize_from_swapped_fp16_params(self, fp16_partitions_info, fp16_num_elems, fp16_pinned_buffers, fp32_parameters): self._initialize_from_swapped_fp16_params( aio_handle=self.aio_handle, fp16_partitions_info=fp16_partitions_info, fp16_num_elems=fp16_num_elems, fp16_pinned_buffers=fp16_pinned_buffers, fp32_parameters=fp32_parameters) def flush_gradients(self): self._flush_gradient_swapper(self.gradient_swapper) def swap_in_optimizer_state(self, parameter, async_parameter=None): swap_info = self._get_param_swap_info(parameter) if swap_info is None: return self._flush_gradient_swapper(self.gradient_swapper) required_buffer_count = len( swap_info.tensors) + (1 if swap_info.has_gradients() else 0) aligned_numel = self._io_aligned_numel(swap_info.numel()) pinned_buffers = self.swap_buffer_manager.allocate(num_elems=aligned_numel, count=required_buffer_count, dtype=parameter.dtype) assert pinned_buffers is not None self.allocated_swap_buffers = pinned_buffers.copy() self._start_timer(SWAP_IN_PARAM_TIMER) self._swap_in_parameter(aio_handle=self.aio_handle, parameter=parameter, dest_buffers=pinned_buffers[:required_buffer_count]) self._stop_timer(SWAP_IN_PARAM_TIMER) self.timer_names.add(SWAP_IN_PARAM_TIMER) self._start_timer(SWAP_IN_GRADIENT_TIMER) self._swap_in_gradients(aio_handle=self.aio_handle, parameter=parameter, dest_buffer=pinned_buffers[-1]) self._stop_timer(SWAP_IN_GRADIENT_TIMER) self.timer_names.add(SWAP_IN_GRADIENT_TIMER) def swap_out_optimizer_state(self, parameter, async_swap=False): swap_info = self._get_param_swap_info(parameter=parameter) if swap_info is None: return self._start_timer(SWAP_OUT_PARAM_TIMER) pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths = 
self._separate_pinned_tensors(swap_info) swap_bytes = sum([ self._io_aligned_numel(t.numel()) * t.element_size() for t in swap_info.tensors ]) WRITE_TIMER = 'swap_submit_write' self._start_timer(WRITE_TIMER) swap_out_tensors(self.aio_handle, pinned_tensors, pinned_paths) assert self.aio_handle.wait() == len(pinned_tensors) for t in pinned_tensors: t.data = torch.Tensor() if len(unpinned_tensors) > 0: pinned_buffers = self.swap_buffer_manager.allocate_all( num_elems=self.largest_numel, dtype=self.dtype) self._swap_out_unpinned_tensors(aio_handle=self.aio_handle, unpinned_tensors=unpinned_tensors, dest_paths=unpinned_paths, pinned_buffers=pinned_buffers) self.allocated_swap_buffers += pinned_buffers for t in unpinned_tensors: t.data = torch.Tensor() self._stop_timer(WRITE_TIMER) self.swap_buffer_manager.free(self.allocated_swap_buffers) self.allocated_swap_buffers = [] self._stop_timer(SWAP_OUT_PARAM_TIMER) self.timer_names.add(SWAP_OUT_PARAM_TIMER) self._log_timers([WRITE_TIMER]) if DEBUG_MODE and torch.distributed.get_rank() == 0: logger.info(f'optimizer_param_swap_out: {(swap_bytes/(1024**3)):5.2f} GB') def swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors): self._swap_out_gradients(parameter=parameter, gradient_offsets=gradient_offsets, gradient_tensors=gradient_tensors, gradient_swapper=self.gradient_swapper) def _swap_in_parameter(self, aio_handle, parameter, dest_buffers): swap_info = self._get_param_swap_info(parameter) if swap_info is None: return assert len(swap_info.tensors) <= len(dest_buffers) swap_lengths = [self._io_aligned_numel(swap_info.numel())] * len( swap_info.tensors) swap_buffers = get_sized_buffers(dest_buffers, swap_lengths) READ_TIMER = 'swap_submit_read_param' WAIT_TIMER = 'swap_wait_read_param' self._start_timer(READ_TIMER) swap_in_tensors(aio_handle, swap_buffers, swap_info.swap_paths) self._stop_timer(READ_TIMER) swap_bytes = sum( [buffer.numel() * buffer.element_size() for buffer in swap_buffers]) self._start_timer(WAIT_TIMER) aio_handle.wait() self._stop_timer(WAIT_TIMER) compute_lengths = [swap_info.numel()] * len(swap_info.tensors) compute_buffers = get_sized_buffers(dest_buffers, compute_lengths) for t, buffer in zip(swap_info.tensors, compute_buffers): t.data = buffer.data self._log_timers([READ_TIMER, WAIT_TIMER]) if DEBUG_MODE and torch.distributed.get_rank() == 0: logger.info(f'optimizer_param_swap_in: {(swap_bytes/(1024**3)):5.2f} GB') def _separate_pinned_tensors(self, swap_info): pinned_tensors = [] pinned_paths = [] unpinned_tensors = [] unpinned_paths = [] for tensor, path in zip(swap_info.tensors, swap_info.swap_paths): if tensor.is_pinned(): pinned_tensors.append(tensor) pinned_paths.append(path) else: unpinned_tensors.append(tensor) unpinned_paths.append(path) return pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths def _swap_in_pinned_gradients(self, aio_handle, parameter, gradient_tensor): swap_info = self.swap_params_info[id(parameter)] param_gradients = swap_info.swapped_gradients.values() swap_buffers = [ gradient_tensor.narrow(0, grad.offset, grad.length) for grad in param_gradients ] swap_paths = [grad.path for grad in param_gradients] SWAP_READ_GRADIENTS = 'swap_submit_read_gradient' SWAP_WAIT_GRADIENTS = 'swap_submit_wait_gradient' self._start_timer(SWAP_READ_GRADIENTS) swap_in_tensors(aio_handle, swap_buffers, swap_paths) self._stop_timer(SWAP_READ_GRADIENTS) self._start_timer(SWAP_WAIT_GRADIENTS) assert len(swap_buffers) == aio_handle.wait() self._stop_timer(SWAP_WAIT_GRADIENTS) 
self._log_timers([SWAP_READ_GRADIENTS, SWAP_WAIT_GRADIENTS]) def _swap_in_gradients(self, aio_handle, parameter, dest_buffer): swap_info = self.swap_params_info.get(id(parameter), None) if not (swap_info and swap_info.has_gradients()): return assert dest_buffer.is_pinned() assert parameter.numel() <= dest_buffer.numel() parameter.grad = dest_buffer.narrow(0, 0, parameter.numel()) if swap_info.swapped_gradients: self._swap_in_pinned_gradients(aio_handle, parameter, parameter.grad) if swap_info.unswapped_gradients: self._retrieve_unswapped_grad_partitions(swap_info=swap_info, dest_buffer=parameter.grad)
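

# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of DeepSpeed's public API). In practice
# this swapper is constructed and driven by the ZeRO-Offload optimizer rather
# than by user code. The AIO_* keys mirror the constants imported above, but
# the concrete values, the '/local_nvme/...' path, and the placeholder objects
# (swap_config, optimizer, largest_partition_numel, timers) are assumptions
# made for illustration only.
#
#   aio_config = {
#       AIO_BLOCK_SIZE: 1048576,
#       AIO_QUEUE_DEPTH: 8,
#       AIO_SINGLE_SUBMIT: False,
#       AIO_OVERLAP_EVENTS: True,
#       AIO_THREAD_COUNT: 1,
#   }
#   swapper = PartitionedOptimizerSwapper(swap_config=swap_config,
#                                         aio_config=aio_config,
#                                         base_folder='/local_nvme/zero_offload',
#                                         optimizer=optimizer,
#                                         largest_numel=largest_partition_numel,
#                                         device='cpu',
#                                         dtype=torch.float32,
#                                         timers=timers)
#
#   # Per-parameter lifecycle around an optimizer step:
#   swapper.swap_in_optimizer_state(parameter)    # read state (+ gradient) from NVMe
#   # ... optimizer updates the swapped-in state ...
#   swapper.swap_out_optimizer_state(parameter)   # write state back, free buffers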