- """
- Copyright 2020 The Microsoft DeepSpeed Team
- Licensed under the MIT license.
- Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
- """
- import os
- import torch
- from deepspeed.utils.logging import logger
- from deepspeed.ops.aio import AsyncIOBuilder
- from deepspeed.runtime.swap_tensor.constants import *
- from deepspeed.runtime.swap_tensor.utils import swap_in_tensors, swap_out_tensors, print_object, \
- MIN_AIO_BYTES, AIO_ALIGNED_BYTES, get_sized_buffers, get_sized_buffer
- from deepspeed.runtime.swap_tensor.async_swapper import AsyncTensorSwapper
- from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper
- DEBUG_MODE = False
- SWAP_IN_PARAM_TIMER = 'swap_in_param'
- SWAP_OUT_PARAM_TIMER = 'swap_out_param'
- SWAP_IN_GRADIENT_TIMER = 'swap_in_gradient'
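

# Swaps partitioned optimizer state and gradient tensors between memory and NVMe
# storage through DeepSpeed's asynchronous I/O (AIO) handle.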
class PartitionedOptimizerSwapper(OptimizerSwapper):
    def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_numel, device, dtype, timers):
        super(PartitionedOptimizerSwapper, self).__init__(swap_config, aio_config, base_folder, optimizer,
                                                          largest_numel, device, dtype, timers)

        aio_op = AsyncIOBuilder().load()
        self.aio_handle = aio_op.aio_handle(aio_config[AIO_BLOCK_SIZE],
                                            aio_config[AIO_QUEUE_DEPTH],
                                            aio_config[AIO_SINGLE_SUBMIT],
                                            aio_config[AIO_OVERLAP_EVENTS],
                                            aio_config[AIO_THREAD_COUNT])

        # Overlap swapping out
        self.gradient_swapper = AsyncTensorSwapper(aio_handle=self.aio_handle,
                                                   numel_alignment=self.numel_alignment,
                                                   timers=self.timers)

        self.print_exclude_list += ['aio_handle', 'gradient_swapper', 'print_exclude_list']

        if torch.distributed.get_rank() == 0:
            print_object(obj=self, name='PartitionedOptimizerSwapper', exclude_list=self.print_exclude_list)
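
    # Populate the swapped optimizer parameters from their source tensors using the shared AIO handle.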
    def initialize_parameters(self, parameters, src_tensors):
        self._initialize_parameters(parameters=parameters, src_tensors=src_tensors, aio_handle=self.aio_handle)

    def initialize_from_swapped_fp16_params(self, fp16_partitions_info, fp16_num_elems, fp16_pinned_buffers,
                                             fp32_parameters):
        self._initialize_from_swapped_fp16_params(aio_handle=self.aio_handle,
                                                   fp16_partitions_info=fp16_partitions_info,
                                                   fp16_num_elems=fp16_num_elems,
                                                   fp16_pinned_buffers=fp16_pinned_buffers,
                                                   fp32_parameters=fp32_parameters)

    def flush_gradients(self):
        self._flush_gradient_swapper(self.gradient_swapper)
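
    # Read a parameter's optimizer state (and any swapped gradients) from NVMe into pinned
    # buffers, then point the optimizer tensors at those buffers.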
    def swap_in_optimizer_state(self, parameter, async_parameter=None):
        swap_info = self._get_param_swap_info(parameter)
        if swap_info is None:
            return

        self._flush_gradient_swapper(self.gradient_swapper)

        required_buffer_count = len(swap_info.tensors) + (1 if swap_info.has_gradients() else 0)
        aligned_numel = self._io_aligned_numel(swap_info.numel())
        pinned_buffers = self.swap_buffer_manager.allocate(num_elems=aligned_numel,
                                                           count=required_buffer_count,
                                                           dtype=parameter.dtype)
        assert pinned_buffers is not None
        self.allocated_swap_buffers = pinned_buffers.copy()

        self._start_timer(SWAP_IN_PARAM_TIMER)
        self._swap_in_parameter(aio_handle=self.aio_handle,
                                parameter=parameter,
                                dest_buffers=pinned_buffers[:required_buffer_count])
        self._stop_timer(SWAP_IN_PARAM_TIMER)
        self.timer_names.add(SWAP_IN_PARAM_TIMER)

        self._start_timer(SWAP_IN_GRADIENT_TIMER)
        self._swap_in_gradients(aio_handle=self.aio_handle, parameter=parameter, dest_buffer=pinned_buffers[-1])
        self._stop_timer(SWAP_IN_GRADIENT_TIMER)
        self.timer_names.add(SWAP_IN_GRADIENT_TIMER)
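
    # Write a parameter's optimizer state back to NVMe. Pinned tensors are written directly;
    # unpinned tensors are staged through pinned swap buffers first.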
    def swap_out_optimizer_state(self, parameter, async_swap=False):
        swap_info = self._get_param_swap_info(parameter=parameter)
        if swap_info is None:
            return

        self._start_timer(SWAP_OUT_PARAM_TIMER)
        pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths = self._separate_pinned_tensors(swap_info)
        swap_bytes = sum([self._io_aligned_numel(t.numel()) * t.element_size() for t in swap_info.tensors])

        WRITE_TIMER = 'swap_submit_write'
        self._start_timer(WRITE_TIMER)

        swap_out_tensors(self.aio_handle, pinned_tensors, pinned_paths)
        assert self.aio_handle.wait() == len(pinned_tensors)
        for t in pinned_tensors:
            t.data = torch.Tensor()

        if len(unpinned_tensors) > 0:
            pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype)
            self._swap_out_unpinned_tensors(aio_handle=self.aio_handle,
                                            unpinned_tensors=unpinned_tensors,
                                            dest_paths=unpinned_paths,
                                            pinned_buffers=pinned_buffers)
            self.allocated_swap_buffers += pinned_buffers

            for t in unpinned_tensors:
                t.data = torch.Tensor()
        self._stop_timer(WRITE_TIMER)

        self.swap_buffer_manager.free(self.allocated_swap_buffers)
        self.allocated_swap_buffers = []

        self._stop_timer(SWAP_OUT_PARAM_TIMER)
        self.timer_names.add(SWAP_OUT_PARAM_TIMER)
        self._log_timers([WRITE_TIMER])

        if DEBUG_MODE and torch.distributed.get_rank() == 0:
            logger.info(f'optimizer_param_swap_out: {(swap_bytes/(1024**3)):5.2f} GB')
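
    # Queue gradient partitions on the async swapper so their writes overlap with computation.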
    def swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors):
        self._swap_out_gradients(parameter=parameter,
                                 gradient_offsets=gradient_offsets,
                                 gradient_tensors=gradient_tensors,
                                 gradient_swapper=self.gradient_swapper)
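
    # Read every optimizer state tensor of a parameter from its swap file into the aligned
    # destination buffers, then rebind the tensors to compute-sized views of those buffers.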
    def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):
        swap_info = self._get_param_swap_info(parameter)
        if swap_info is None:
            return

        assert len(swap_info.tensors) <= len(dest_buffers)

        swap_lengths = [self._io_aligned_numel(swap_info.numel())] * len(swap_info.tensors)
        swap_buffers = get_sized_buffers(dest_buffers, swap_lengths)

        READ_TIMER = 'swap_submit_read_param'
        WAIT_TIMER = 'swap_wait_read_param'

        self._start_timer(READ_TIMER)
        swap_in_tensors(aio_handle, swap_buffers, swap_info.swap_paths)
        self._stop_timer(READ_TIMER)

        swap_bytes = sum([buffer.numel() * buffer.element_size() for buffer in swap_buffers])

        self._start_timer(WAIT_TIMER)
        aio_handle.wait()
        self._stop_timer(WAIT_TIMER)

        compute_lengths = [swap_info.numel()] * len(swap_info.tensors)
        compute_buffers = get_sized_buffers(dest_buffers, compute_lengths)
        for t, buffer in zip(swap_info.tensors, compute_buffers):
            t.data = buffer.data

        self._log_timers([READ_TIMER, WAIT_TIMER])

        if DEBUG_MODE and torch.distributed.get_rank() == 0:
            logger.info(f'optimizer_param_swap_in: {(swap_bytes/(1024**3)):5.2f} GB')
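
    # Split a swap_info's tensors and their swap paths into pinned and unpinned groups;
    # pinned tensors are written directly, unpinned ones are staged through pinned buffers.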
    def _separate_pinned_tensors(self, swap_info):
        pinned_tensors = []
        pinned_paths = []

        unpinned_tensors = []
        unpinned_paths = []

        for tensor, path in zip(swap_info.tensors, swap_info.swap_paths):
            if tensor.is_pinned():
                pinned_tensors.append(tensor)
                pinned_paths.append(path)
            else:
                unpinned_tensors.append(tensor)
                unpinned_paths.append(path)

        return pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths
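
    # Read previously swapped gradient partitions from NVMe into their offsets within the
    # pinned gradient tensor.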
    def _swap_in_pinned_gradients(self, aio_handle, parameter, gradient_tensor):
        swap_info = self.swap_params_info[id(parameter)]
        param_gradients = swap_info.swapped_gradients.values()
        swap_buffers = [gradient_tensor.narrow(0, grad.offset, grad.length) for grad in param_gradients]
        swap_paths = [grad.path for grad in param_gradients]
        SWAP_READ_GRADIENTS = 'swap_submit_read_gradient'
        SWAP_WAIT_GRADIENTS = 'swap_submit_wait_gradient'

        self._start_timer(SWAP_READ_GRADIENTS)
        swap_in_tensors(aio_handle, swap_buffers, swap_paths)
        self._stop_timer(SWAP_READ_GRADIENTS)

        self._start_timer(SWAP_WAIT_GRADIENTS)
        assert len(swap_buffers) == aio_handle.wait()
        self._stop_timer(SWAP_WAIT_GRADIENTS)

        self._log_timers([SWAP_READ_GRADIENTS, SWAP_WAIT_GRADIENTS])
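
    # Attach the parameter's gradient to a view of the pinned destination buffer, then fill it
    # from swapped-out gradient partitions on NVMe and any partitions still held in memory.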
    def _swap_in_gradients(self, aio_handle, parameter, dest_buffer):
        swap_info = self.swap_params_info.get(id(parameter), None)
        if not (swap_info and swap_info.has_gradients()):
            return

        assert dest_buffer.is_pinned()
        assert parameter.numel() <= dest_buffer.numel()

        parameter.grad = dest_buffer.narrow(0, 0, parameter.numel())

        if swap_info.swapped_gradients:
            self._swap_in_pinned_gradients(aio_handle, parameter, parameter.grad)

        if swap_info.unswapped_gradients:
            self._retrieve_unswapped_grad_partitions(swap_info=swap_info, dest_buffer=parameter.grad)