123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526 |
- """
- Copyright 2020 The Microsoft DeepSpeed Team.
- Licensed under the MIT license.
- Functionality of swapping tensors to/from (NVMe) storage devices.
- """
- import os
- import torch
- from deepspeed.utils.logging import logger
- from deepspeed.runtime.zero.offload_constants import *
- from deepspeed.runtime.swap_tensor.constants import *
- from deepspeed.runtime.swap_tensor.utils import swap_in_tensors, swap_out_tensors, \
- MIN_AIO_BYTES, AIO_ALIGNED_BYTES, get_sized_buffers, get_sized_buffer
- from deepspeed.runtime.swap_tensor.utils import SwapBufferManager, SwapBufferPool
class FlattenedTensorSwapInfo(object):
    """Records where one flattened tensor slice lives in a swap file.

    Attributes:
        path: swap file path backing this slice.
        length: number of elements in the slice.
        offset: element offset of the slice within its parent tensor.
    """
    def __init__(self, path, length, offset):
        # Plain value object: keep the (path, length, offset) triple together
        # so a slice can later be located on storage.
        self.path = path
        self.length = length
        self.offset = offset
class OptimizerStateSwapInfo(object):
    """Bookkeeping for one optimizer parameter whose state is swapped to storage.

    Tracks the parameter's state tensors, the swap file path assigned to each,
    and the gradient partitions that either have been swapped out (keyed by
    offset) or were too small to swap and are held in memory instead.
    """
    def __init__(self, parameter, numel, base_folder):
        """
        Args:
            parameter: tensor whose identity, dtype and device anchor this record.
            numel: element count used when sizing compute buffers.
            base_folder: directory in which swap files for this parameter live.
        """
        self.tensors = []
        self.param_id = id(parameter)
        self.swap_folder = base_folder
        self.swap_paths = []
        # offset -> FlattenedTensorSwapInfo for gradient partitions on storage.
        self.swapped_gradients = {}
        # offset -> tensor for gradient partitions kept in memory (too small to swap).
        self.unswapped_gradients = {}
        self.tensor_numel = numel
        self.tensor_dtype = parameter.dtype
        self.tensor_device = parameter.device
        self.has_state_tensors = False
        self._add_tensors([parameter])

    def numel(self):
        """Element count each tracked state tensor is assumed to have."""
        return self.tensor_numel

    def has_gradients(self):
        """Return True if any gradient partition (swapped or unswapped) is recorded."""
        # Fix: return an actual bool rather than whichever dict happens to be
        # truthy, so callers get a consistent type.
        return bool(self.swapped_gradients) or bool(self.unswapped_gradients)

    def _add_tensors(self, tensor_list):
        """Register tensors and assign each a unique swap file path."""
        for t in tensor_list:
            self.tensors.append(t)
            self.swap_paths.append(os.path.join(self.swap_folder, f'{id(t)}.tensor.swp'))

    def add_state_tensors(self, tensor_list):
        """Register optimizer state tensors discovered after construction."""
        self.has_state_tensors = True
        self._add_tensors(tensor_list)

    def device(self):
        return self.tensor_device

    def dtype(self):
        return self.tensor_dtype

    def release_memory(self):
        """Drop the storage of all tracked tensors (their contents live on disk)."""
        for tensor in self.tensors:
            tensor.data = torch.Tensor()

    def get_or_create_gradient_paths(self, offsets, lengths):
        """Return swap file paths for gradient partitions, creating entries as needed.

        Args:
            offsets: element offsets of the partitions within the parameter.
            lengths: element counts of the partitions, parallel to offsets.

        Returns:
            List of swap file paths, one per (offset, length) pair.
        """
        gradient_paths = []
        for offset, length in zip(offsets, lengths):
            # Idiomatic membership test ('x not in d') instead of
            # 'not x in d.keys()'.
            if offset not in self.swapped_gradients:
                path = os.path.join(
                    self.swap_folder,
                    f'{self.param_id}_gradient_{offset}_{length}.tensor.swp')
                self.swapped_gradients[offset] = FlattenedTensorSwapInfo(
                    path,
                    length,
                    offset)
            gradient_paths.append(self.swapped_gradients[offset].path)
        return gradient_paths

    def set_swap_buffers(self, buffers):
        """Point every state tensor's data at a sized view of the given buffers."""
        compute_lengths = [self.numel()] * len(self.tensors)
        compute_buffers = get_sized_buffers(buffers, compute_lengths)
        for t, buffer in zip(self.tensors, compute_buffers):
            t.data = buffer.data

    def get_swap_gradient_buffers(self, swap_buffer):
        """Return views of swap_buffer matching the swapped gradient partitions."""
        assert self.numel() <= swap_buffer.numel()
        return [
            swap_buffer.narrow(0, grad.offset, grad.length)
            for grad in self.swapped_gradients.values()
        ]

    def get_swap_gradient_paths(self):
        """Swap file paths of all swapped gradient partitions, in insertion order."""
        return [grad.path for grad in self.swapped_gradients.values()]

    def get_unpinned_state_tensors(self):
        """Tracked tensors not in pinned (page-locked) memory."""
        return [t for t in self.tensors if not t.is_pinned()]

    def read_unswapped_gradients(self, dest_buffer):
        """Copy in-memory gradient partitions into dest_buffer at their offsets.

        Returns:
            Total number of elements copied.
        """
        num_elem_count = 0
        for offset, grad_partition in self.unswapped_gradients.items():
            dst_tensor = dest_buffer.narrow(0, offset, grad_partition.numel())
            dst_tensor.data.copy_(grad_partition.data)
            num_elem_count += grad_partition.numel()
        return num_elem_count

    def release_unswapped_gradients(self):
        """Discard all in-memory gradient partitions."""
        self.unswapped_gradients = {}
# Set True to emit verbose per-swap logging (rank 0 only in the call sites below).
SWAPPER_DEBUG_MODE = False

# Timer label used when measuring gradient swap-out operations.
SWAP_OUT_GRADIENT_TIMER = 'swap_out_gradient'
class OptimizerSwapper(object):
    """Base machinery for swapping optimizer state tensors to/from NVMe storage.

    Manages per-parameter swap records (OptimizerStateSwapInfo), pinned swap
    buffers, aio-aligned sizing, and timing. Subclasses are expected to drive
    the actual swap scheduling; this class provides the shared plumbing.
    """
    def __init__(self,
                 swap_config,
                 aio_config,
                 base_folder,
                 optimizer,
                 largest_numel,
                 device,
                 dtype,
                 timers):
        """
        Args:
            swap_config: offload config mapping; provides the pinned buffer count.
            aio_config: asynchronous I/O config mapping; provides block size and
                thread count used for alignment.
            base_folder: root directory under which per-rank swap folders are created.
            optimizer: wrapped optimizer whose per-parameter state is swapped.
            largest_numel: element count of the largest tensor to be swapped
                (rounded up here to the I/O alignment).
            device: target device — NOTE(review): not stored; presumably consumed
                by subclasses. Confirm against derived classes.
            dtype: element dtype of swapped tensors.
            timers: optional timer factory; when falsy all timing is a no-op.
        """
        self.swap_config = swap_config
        self.aio_config = aio_config

        # NVMe swap management
        self.swap_params_info = {}
        # Bytes per element for the swap dtype, via a zero-sized probe tensor.
        self.swap_element_size = torch.tensor([], dtype=dtype).element_size()
        # Per-rank subfolder so ranks never collide on swap file names.
        self.swap_folder = os.path.join(base_folder,
                                        'optimizer',
                                        f'rank{torch.distributed.get_rank()}')
        os.makedirs(self.swap_folder, exist_ok=True)

        self.optimizer = optimizer

        # Read/Write alignment for each thread during Intra-request parallelism
        self.min_aio_bytes = max(MIN_AIO_BYTES, aio_config[AIO_BLOCK_SIZE])
        self.aligned_bytes = AIO_ALIGNED_BYTES * aio_config[AIO_THREAD_COUNT]
        self.numel_alignment = self.aligned_bytes // self.swap_element_size

        # Swap buffer management
        self.largest_numel = self._io_aligned_numel(largest_numel)
        self.dtype = dtype
        self.swap_buffer_manager = SwapBufferManager(
            num_elems=self.largest_numel,
            count=swap_config[OFFLOAD_OPTIMIZER_BUFFER_COUNT],
            dtype=dtype)

        # Timers
        self.timers = timers
        self.timer_names = set()

        # Print exclusion list
        self.print_exclude_list = [
            'optimizer',
            'swap_buffer_manager',
            'swap_params_info',
            'timers',
            'timer_names',
        ]

    def swappable_tensor(self, param=None, numel=None):
        """Return True if the tensor (or raw element count) is large enough to swap.

        A tensor is swappable when its byte size meets the minimum aio request
        size; smaller tensors stay in memory.
        """
        assert param is not None or numel is not None, "Either param or numel must be provided"
        if param is not None:
            return self.min_aio_bytes <= (param.numel() * self.swap_element_size)
        return self.min_aio_bytes <= (numel * self.swap_element_size)

    def init_timers(self):
        """Reset the set of timer names accumulated for the current step."""
        self.timer_names = set()

    def log_timers(self):
        """Force-log every timer recorded since the last init_timers()."""
        if self.timer_names:
            self._log_timers(list(self.timer_names), force=True)

    def pre_backward(self):
        # Fresh timer set for the upcoming backward pass.
        self.init_timers()

    def post_backward(self):
        # Hook for subclasses; intentionally a no-op here.
        pass

    def _flush_gradient_swapper(self, gradient_swapper):
        """Release any pinned buffers held by gradient_swapper back to the pool."""
        if gradient_swapper.has_buffers():
            self._start_timer(SWAP_OUT_GRADIENT_TIMER)
            pinned_buffers = gradient_swapper.release_buffers()
            self.swap_buffer_manager.free(pinned_buffers)
            self._stop_timer(SWAP_OUT_GRADIENT_TIMER)
            self.timer_names.add(SWAP_OUT_GRADIENT_TIMER)
            self.timer_names.update(gradient_swapper.get_timer_names())

    def _swap_out_gradients(self,
                            parameter,
                            gradient_offsets,
                            gradient_tensors,
                            gradient_swapper):
        """Swap out gradient partitions of `parameter` that are large enough.

        Partitions below the aio size threshold are stashed in
        swap_info.unswapped_gradients instead of being written to storage.
        No-op for parameters without a swap record.
        """
        if not id(parameter) in self.swap_params_info.keys():
            return
        swap_info = self.swap_params_info[id(parameter)]

        swappable_tensors = []
        swappable_offsets = []
        swappable_lengths = []

        # Split misaligned tails off so writes stay I/O-aligned.
        aligned_gradients, aligned_offsets = self._adjust_for_misaligned_lengths(
            tensors=gradient_tensors,
            offsets=gradient_offsets
        )

        self._start_timer(SWAP_OUT_GRADIENT_TIMER)
        for tensor, offset in zip(aligned_gradients, aligned_offsets):
            if not self.swappable_tensor(param=tensor):
                # Too small to swap: keep in memory, keyed by offset.
                swap_info.unswapped_gradients[offset] = tensor
                continue
            swappable_tensors.append(tensor)
            swappable_offsets.append(offset)
            swappable_lengths.append(tensor.numel())

        if len(swappable_tensors) > 0:
            if not gradient_swapper.has_buffers():
                pinned_buffers = self.swap_buffer_manager.allocate_all(
                    num_elems=self.largest_numel,
                    dtype=self.dtype)
                gradient_swapper.add_buffers(pinned_buffers)

            swappable_paths = swap_info.get_or_create_gradient_paths(
                swappable_offsets,
                swappable_lengths)

            gradient_swapper.swap_out_tensors(tensor_list=swappable_tensors,
                                              path_list=swappable_paths)

        self._stop_timer(SWAP_OUT_GRADIENT_TIMER)
        self.timer_names.add(SWAP_OUT_GRADIENT_TIMER)

    def _initialize_from_swapped_fp16_params(self,
                                             aio_handle,
                                             fp16_partitions_info,
                                             fp16_num_elems,
                                             fp16_pinned_buffers,
                                             fp32_parameters):
        """Initialize fp32 parameter swap files from swapped fp16 partitions.

        Streams fp16 partitions into pinned buffers, then writes them out to the
        fp32 swap paths in batches sized by available buffer space.
        """
        assert len(fp32_parameters) == len(fp16_partitions_info)
        assert len(fp32_parameters) == len(fp16_num_elems)
        assert all([buffer.is_pinned() for buffer in fp16_pinned_buffers])

        fp32_swap_paths = self._get_swap_paths(parameters=fp32_parameters,
                                               num_elems=fp16_num_elems)

        fp32_pinned_buffers = self.swap_buffer_manager.allocate_all(
            num_elems=self.largest_numel,
            dtype=self.dtype)

        fp16_buffer_numel = [buf.numel() for buf in fp16_pinned_buffers]
        assert all([numel >= self.largest_numel for numel in fp16_buffer_numel]), \
            f"numel of fp16 buffers {fp16_buffer_numel} is too small for initializing fp32 params {self.largest_numel}"

        fp32_swap_buffers = SwapBufferPool(fp32_pinned_buffers)
        fp16_swap_buffers = SwapBufferPool(fp16_pinned_buffers)

        # Process parameters in buffer-limited batches until all are written.
        curr_index = 0
        while curr_index < len(fp32_parameters):
            fp16_pinned_tensors = self._swap_in_fp16_params(
                aio_handle=aio_handle,
                fp16_num_elems=fp16_num_elems[curr_index:],
                fp16_partitions_info=fp16_partitions_info[curr_index:],
                fp16_swap_buffers=fp16_swap_buffers)

            if torch.distributed.get_rank() == 0 and SWAPPER_DEBUG_MODE:
                for i, tensor in enumerate(fp16_pinned_tensors):
                    true_index = curr_index + i
                    logger.info(
                        f'swap_in_fp16_param: fp32_id = {id(fp32_parameters[true_index])} index = {true_index} orig_num_elem = {fp16_num_elems[true_index]}, swap_num_elem = {fp16_pinned_tensors[i].numel()}'
                    )

            swap_out_count = self._swap_out_fp16_params(
                aio_handle=aio_handle,
                fp32_swap_paths=fp32_swap_paths[curr_index:],
                fp32_swap_buffers=fp32_swap_buffers,
                fp16_pinned_tensors=fp16_pinned_tensors)
            assert swap_out_count == len(fp16_pinned_tensors), \
                f"{swap_out_count} does not match {len(fp16_pinned_tensors)}"

            fp16_swap_buffers.reset()
            fp32_swap_buffers.reset()
            curr_index += swap_out_count

        self.swap_buffer_manager.free(fp32_pinned_buffers)

    def _swap_in_fp16_params(self,
                             aio_handle,
                             fp16_num_elems,
                             fp16_partitions_info,
                             fp16_swap_buffers):
        """Gather as many fp16 params as fit into pinned buffers; return the tensors.

        Partitions with a swap path are read via aio; partitions without one
        (still in memory) are copied directly. Stops allocating when the buffer
        pool runs out of space — callers loop until all params are processed.
        """
        assert len(fp16_num_elems) > 0

        swapped_fp16_tensors = []
        swap_tensors = []
        swap_paths = []
        unswapped_srcs = []
        unswapped_dsts = []

        for i, numel in enumerate(fp16_num_elems):
            pinned_tensor, _ = fp16_swap_buffers.allocate_tensor(numel, None, numel)
            if pinned_tensor is None:
                # Buffer pool exhausted; process what we have so far.
                break

            swapped_fp16_tensors.append(pinned_tensor)
            offset = 0
            for tensor, partition_numel, partition_path in fp16_partitions_info[i]:
                dst_tensor = pinned_tensor.narrow(0, offset, partition_numel)
                if partition_path is None:
                    # Partition still resides in memory: plain copy later.
                    unswapped_srcs.append(tensor)
                    unswapped_dsts.append(dst_tensor)
                else:
                    swap_paths.append(partition_path)
                    swap_tensors.append(dst_tensor)
                offset += partition_numel

        assert len(swapped_fp16_tensors) + len(unswapped_srcs) > 0
        # NOTE(review): return value unused; completion is confirmed by
        # aio_handle.wait() below.
        ret = swap_in_tensors(aio_handle, swap_tensors, swap_paths)
        for src, dst in zip(unswapped_srcs, unswapped_dsts):
            dst.data.copy_(src.data)

        assert len(swap_tensors) == aio_handle.wait()

        return swapped_fp16_tensors

    def _swap_out_fp16_params(self,
                              aio_handle,
                              fp32_swap_paths,
                              fp32_swap_buffers,
                              fp16_pinned_tensors):
        """Write the gathered fp16 tensors to their fp32 swap paths.

        Returns:
            Number of tensors written (always equals len(fp16_pinned_tensors);
            asserted by the caller).
        """
        assert len(fp16_pinned_tensors) <= len(fp32_swap_paths)
        swap_out_count = 0
        for i, fp16_tensor in enumerate(fp16_pinned_tensors):
            if not fp32_swap_buffers.has_space(fp16_tensor.numel()):
                # Flush full buffers to storage before inserting more tensors.
                fp32_swap_buffers.swap_out(aio_handle)
                fp32_swap_buffers.reset()

            pinned_tensor, _ = fp32_swap_buffers.insert_tensor(
                fp16_tensor,
                fp32_swap_paths[i],
                self._io_aligned_numel(fp16_tensor.numel())
            )
            assert pinned_tensor is not None
            swap_out_count += 1

        if len(fp32_swap_buffers.get_swap_tensors()) > 0:
            fp32_swap_buffers.swap_out(aio_handle)

        return swap_out_count

    def _initialize_parameters(self, parameters, src_tensors, aio_handle):
        """Create swap files for `parameters` and write `src_tensors` into them."""
        assert len(parameters) == len(src_tensors)

        swap_paths = self._get_swap_paths(parameters=parameters,
                                          num_elems=[src.numel() for src in src_tensors])

        SWAP_INIT_TIMER = "swap_init_write"
        self._start_timer(SWAP_INIT_TIMER)

        pinned_buffers = self.swap_buffer_manager.allocate_all(
            num_elems=self.largest_numel,
            dtype=self.dtype)
        assert pinned_buffers is not None

        self._swap_out_unpinned_tensors(aio_handle=aio_handle,
                                        unpinned_tensors=src_tensors,
                                        dest_paths=swap_paths,
                                        pinned_buffers=pinned_buffers)

        if torch.distributed.get_rank() == 0 and SWAPPER_DEBUG_MODE:
            for i, tensor in enumerate(src_tensors):
                logger.info(
                    f'copy_in_fp16_param: fp32_id = {id(parameters[i])} index = {i}, swap_num_elem = {src_tensors[i].numel()}'
                )

        self.swap_buffer_manager.free(pinned_buffers)

        self._stop_timer(SWAP_INIT_TIMER)
        self._log_timers([SWAP_INIT_TIMER])

    def _get_swap_paths(self, parameters, num_elems):
        """Create swap records for each parameter and return their primary swap paths."""
        swap_info_list = [
            self._create_param_swap_info(parameter=p,
                                         numel=numel) \
            for p, numel in zip(parameters, num_elems)
        ]
        assert len(swap_info_list) == len(num_elems)

        swap_paths = [info.swap_paths[0] for info in swap_info_list]
        return swap_paths

    def _swap_out_unpinned_tensors(self,
                                   aio_handle,
                                   unpinned_tensors,
                                   dest_paths,
                                   pinned_buffers):
        """Stage unpinned tensors through pinned buffers and write them to dest_paths.

        Processes tensors in batches of len(pinned_buffers), waiting for each
        batch's aio writes to complete before reusing the buffers.
        """
        swap_buffer_count = len(pinned_buffers)
        unpinned_tensor_count = len(unpinned_tensors)

        for i in range(0, unpinned_tensor_count, swap_buffer_count):
            swap_tensor_count = min((unpinned_tensor_count - i), swap_buffer_count)

            src_tensors = unpinned_tensors[i:(i + swap_tensor_count)]
            compute_lengths = [t.numel() for t in src_tensors]
            compute_buffers = get_sized_buffers(pinned_buffers, compute_lengths)

            for dst, src in zip(compute_buffers, src_tensors):
                dst.data.copy_(src.data)

            # Writes use I/O-aligned (possibly larger) lengths than the copies above.
            swap_lengths = [self._io_aligned_numel(t.numel()) for t in src_tensors]
            swap_buffers = get_sized_buffers(pinned_buffers, swap_lengths)

            swap_paths = dest_paths[i:(i + swap_tensor_count)]
            swap_out_tensors(aio_handle, swap_buffers, swap_paths)

            assert aio_handle.wait() == swap_tensor_count

    def _adjust_for_misaligned_lengths(self, tensors, offsets):
        """Split swappable tensors whose length is not a multiple of the alignment.

        Each misaligned tensor becomes an aligned head plus a remainder tail
        (the tail is later treated as unswappable by size). Non-swappable
        tensors pass through unchanged.

        Returns:
            (new_tensors, new_offsets) parallel lists.
        """
        new_tensors = []
        new_offsets = []

        for orig_tensor, orig_offset in zip(tensors, offsets):
            if not self.swappable_tensor(param=orig_tensor):
                new_tensors.append(orig_tensor)
                new_offsets.append(orig_offset)
                continue

            remainder = orig_tensor.numel() % self.numel_alignment
            if remainder == 0:
                new_tensors.append(orig_tensor)
                new_offsets.append(orig_offset)
                continue

            # Split into two by making remainder a tensor
            aligned_length = (orig_tensor.numel() //
                              self.numel_alignment) * self.numel_alignment
            new_tensors.append(orig_tensor.narrow(0, 0, aligned_length))
            new_offsets.append(orig_offset)

            # remainder tensor
            new_tensors.append(orig_tensor.narrow(0, aligned_length, remainder))
            new_offsets.append(orig_offset + aligned_length)

        return new_tensors, new_offsets

    def _retrieve_unswapped_grad_partitions(self, swap_info, dest_buffer):
        """Copy in-memory gradient partitions into dest_buffer, then discard them."""
        UNSWAPPED_READ_GRADIENTS = 'unswapped_read_gradients'
        self._start_timer(UNSWAPPED_READ_GRADIENTS)
        tensor_count = len(swap_info.unswapped_gradients)
        num_elem_count = swap_info.read_unswapped_gradients(dest_buffer)
        self._stop_timer(UNSWAPPED_READ_GRADIENTS)
        self._log_timers([UNSWAPPED_READ_GRADIENTS])

        # It should be safe to discard unswapped gradient partitions
        swap_info.release_unswapped_gradients()

        if SWAPPER_DEBUG_MODE:
            logger.info(
                f'optimizer_retrieve_unswapped_gradients: param={swap_info.param_id} tensor_count={tensor_count} elem_count={num_elem_count}'
            )

    def _get_state_tensors(self, parameter):
        """Return all tensor-valued entries of the optimizer state for `parameter`."""
        if not parameter in self.optimizer.state:
            return []

        tensor_list = []
        for value in self.optimizer.state[parameter].values():
            if torch.is_tensor(value):
                tensor_list.append(value)

        return tensor_list

    def _update_param_state_info(self, swap_info, parameter):
        """Lazily attach optimizer state tensors to swap_info once they exist."""
        if not swap_info.has_state_tensors:
            state_tensors = self._get_state_tensors(parameter)
            if state_tensors:
                swap_info.add_state_tensors(state_tensors)

    def _create_param_swap_info(self, parameter, numel):
        """Create and register a new swap record for `parameter`; must not exist yet."""
        param_id = id(parameter)
        assert not param_id in self.swap_params_info

        self.swap_params_info[param_id] = OptimizerStateSwapInfo(
            parameter=parameter,
            numel=numel,
            base_folder=self.swap_folder)
        swap_info = self.swap_params_info[param_id]

        self._update_param_state_info(swap_info, parameter)

        return swap_info

    def _get_param_swap_info(self, parameter):
        """Return the swap record for `parameter`, or None if not registered."""
        param_id = id(parameter)
        swap_info = self.swap_params_info.get(param_id, None)

        if swap_info is not None:
            self._update_param_state_info(swap_info, parameter)

        return swap_info

    def _start_timer(self, name):
        # Timing is optional: no-op when no timer factory was supplied.
        if self.timers:
            self.timers(name).start()

    def _stop_timer(self, name):
        if self.timers:
            self.timers(name).stop()

    def _log_timers(self, name_list, force=False):
        # Only log in debug mode unless explicitly forced (see log_timers()).
        if self.timers and (SWAPPER_DEBUG_MODE or force):
            self.timers.log(name_list)

    def _io_aligned_numel(self, numel):
        """Round `numel` up to the nearest multiple of the I/O alignment."""
        remainder = numel % self.numel_alignment
        return numel if remainder == 0 else (numel + self.numel_alignment - remainder)
|