123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- """
- Functionality of swapping tensors to/from (NVMe) storage devices.
- """
- import torch
- from deepspeed.utils.logging import logger
- from deepspeed.accelerator import get_accelerator
- from deepspeed import comm as dist
# 1 MiB. Presumably a size threshold for async I/O transfers (inferred from the
# name only -- NOTE(review): confirm at the call sites).
MIN_AIO_BYTES = 1 << 20
# 1 KiB. Presumably an alignment granularity for async I/O (name-based; confirm).
AIO_ALIGNED_BYTES = 1 << 10
def swap_in_tensors(swap_handle, tensor_buffers, swap_paths):
    """Submit one ``async_pread`` per (buffer, path) pair on ``swap_handle``.

    ``tensor_buffers[i]`` is paired with ``swap_paths[i]`` via ``zip``, so
    surplus entries in the longer sequence are ignored. This function only
    submits the reads; it does not wait for them to complete.

    Raises:
        RuntimeError: if ``async_pread`` returns a non-zero status.
    """
    for buffer, path in zip(tensor_buffers, swap_paths):
        # Keep the call outside ``assert``: ``python -O`` strips assert
        # statements entirely, which used to drop the read submission itself.
        status = swap_handle.async_pread(buffer, path)
        if status != 0:
            raise RuntimeError(f'async_pread failed for {path} (status={status})')
def swap_out_tensors(swap_handle, tensor_buffers, swap_paths):
    """Submit one ``async_pwrite`` per (buffer, path) pair on ``swap_handle``.

    ``tensor_buffers[i]`` is paired with ``swap_paths[i]`` via ``zip``, so
    surplus entries in the longer sequence are ignored. This function only
    submits the writes; it does not wait for them to complete.

    Raises:
        RuntimeError: if ``async_pwrite`` returns a non-zero status.
    """
    for buffer, path in zip(tensor_buffers, swap_paths):
        # Keep the call outside ``assert``: ``python -O`` strips assert
        # statements entirely, which used to drop the write submission itself.
        status = swap_handle.async_pwrite(buffer, path)
        if status != 0:
            raise RuntimeError(f'async_pwrite failed for {path} (status={status})')
def print_object(obj, name, exclude_list=None):
    """Log the instance attributes of ``obj`` as a dotted name/value listing.

    Args:
        obj: object whose ``vars()`` are logged, in sorted attribute-name order.
        name: heading logged before the attributes.
        exclude_list: attribute names to skip. ``None`` (the default) skips
            nothing. Any container supporting ``in`` works.
    """
    # ``None`` default replaces a shared mutable ``[]`` default.
    excluded = () if exclude_list is None else exclude_list
    logger.info('{}:'.format(name))
    for arg in sorted(vars(obj)):
        if arg in excluded:
            continue
        # Dot padding so values line up; names of 29+ characters get no dots.
        dots = '.' * (29 - len(arg))
        logger.info(' {} {} {}'.format(arg, dots, getattr(obj, arg)))
class SwapBuffer(object):
    """Bump allocator that carves sub-tensors out of a single ``buffer``.

    Each allocation reserves ``aligned_numel`` elements at the current offset
    (along dim 0). It records three things keyed by that start offset: the
    full reserved slice ("swap tensor"), its leading ``numel``-element view
    ("compute tensor"), and an associated swap path.
    NOTE(review): ``buffer`` is presumably 1-D, since capacity is checked
    against ``numel()`` while slicing is along dim 0. Confirm with callers.
    """

    def __init__(self, buffer):
        self.buffer = buffer
        self.reset()

    def reset(self):
        """Forget all allocations. The contents of ``self.buffer`` are not touched."""
        self.offset = 0
        self.swap_tensors = {}
        self.compute_tensors = {}
        self.swap_paths = {}
        self.num_elem = 0

    def insert_tensor(self, tensor, swap_path, aligned_numel):
        """Allocate room for ``tensor`` and copy its data into the compute view.

        Returns:
            ``(swap_tensor, compute_tensor)``, as returned by :meth:`allocate_tensor`.
        """
        swap_tensor, compute_tensor = self.allocate_tensor(swap_path, tensor.numel(), aligned_numel)
        compute_tensor.data.copy_(tensor.data)
        return swap_tensor, compute_tensor

    def allocate_tensor(self, swap_path, numel, aligned_numel):
        """Reserve ``aligned_numel`` elements at the current offset.

        Args:
            swap_path: path recorded for this allocation.
            numel: number of elements exposed by the compute view.
            aligned_numel: number of elements reserved and exposed by the swap view.

        Returns:
            ``(swap_tensor, compute_tensor)``: views into ``self.buffer`` of
            ``aligned_numel`` and ``numel`` elements, both starting at the
            allocation offset.
        """
        assert self.has_space(aligned_numel)
        assert self.offset not in self.swap_tensors
        allocate_offset = self.offset
        swap_tensor = self.buffer.narrow(0, allocate_offset, aligned_numel)
        dest_tensor = swap_tensor.narrow(0, 0, numel)
        self.swap_tensors[allocate_offset] = swap_tensor
        self.compute_tensors[allocate_offset] = dest_tensor
        self.swap_paths[allocate_offset] = swap_path
        # The offset advances by the aligned size; num_elem counts only useful elements.
        self.offset += aligned_numel
        self.num_elem += numel
        return swap_tensor, dest_tensor

    def has_space(self, numel):
        """Return True if ``numel`` more elements fit after the current offset."""
        return (self.offset + numel) <= self.buffer.numel()

    def get_swap_tensors(self):
        """Return the swap views in allocation order."""
        return list(self.swap_tensors.values())

    def get_swap_paths(self):
        """Return the swap paths in allocation order."""
        return list(self.swap_paths.values())

    def get_compute_tensors(self):
        """Return the compute views in allocation order."""
        return list(self.compute_tensors.values())

    def get_num_elem(self):
        """Return the total ``numel`` (unaligned) of all allocations."""
        return self.num_elem

    def get_swap_tensor(self, offset):
        """Return the swap view allocated at ``offset``, or None."""
        return self.swap_tensors.get(offset, None)

    def get_compute_tensor(self, offset):
        """Return the compute view allocated at ``offset``, or None."""
        return self.compute_tensors.get(offset, None)

    def get_swap_path(self, offset):
        """Return the swap path allocated at ``offset``, or None."""
        # Previously ``self.swap_paths(offset, None)``: calling a dict always
        # raised TypeError.
        return self.swap_paths.get(offset, None)
class SwapBufferPool(object):
    """Pool of :class:`SwapBuffer` objects that are filled one after another.

    Allocations go to the current buffer. When it runs out of room,
    :meth:`has_space` advances to the next buffer. Buffers up to and including
    the current one are treated as "used" when collecting tensors and paths
    for swapping.
    """

    def __init__(self, buffers):
        # Every backing tensor must already be pinned (checked via the accelerator).
        assert all(get_accelerator().is_pinned(buf) for buf in buffers)
        self.buffers = [SwapBuffer(buf) for buf in buffers]
        self.current_index = 0

    def reset(self):
        """Rewind to the first buffer and clear every buffer's allocations."""
        self.current_index = 0
        for buffer in self.buffers:
            buffer.reset()

    def allocate_tensor(self, numel, swap_path, aligned_numel):
        """Reserve ``aligned_numel`` elements without copying any data.

        Returns:
            ``(swap_tensor, compute_tensor)`` from the buffer that has room, or
            ``(None, None)`` if :meth:`has_space` reports no room.
        """
        if self.has_space(aligned_numel):
            swap_tensor, compute_tensor = self._get_current_buffer().allocate_tensor(swap_path, numel, aligned_numel)
            return swap_tensor, compute_tensor
        return None, None

    def insert_tensor(self, tensor, swap_path, aligned_numel):
        """Reserve room for ``tensor`` and copy its data into the compute view.

        Returns:
            ``(swap_tensor, compute_tensor)``, or ``(None, None)`` if there is no room.
        """
        if self.has_space(aligned_numel):
            swap_tensor, compute_tensor = self._get_current_buffer().insert_tensor(tensor, swap_path, aligned_numel)
            return swap_tensor, compute_tensor
        return None, None

    def get_swap_tensors(self):
        """Return swap views of all used buffers, in buffer then allocation order."""
        return [tensor for buffer in self._get_used_buffers() for tensor in buffer.get_swap_tensors()]

    def get_swap_paths(self):
        """Return swap paths of all used buffers, in buffer then allocation order."""
        return [path for buffer in self._get_used_buffers() for path in buffer.get_swap_paths()]

    def get_compute_tensors(self):
        """Return compute views of all used buffers, in buffer then allocation order."""
        return [tensor for buffer in self._get_used_buffers() for tensor in buffer.get_compute_tensors()]

    def has_space(self, numel):
        """Check for room for ``numel`` elements, advancing to the next buffer if needed.

        Side effect: if the current buffer is too full and another buffer
        follows, ``current_index`` is advanced by one *before* that buffer is
        checked. The index stays advanced even if the next buffer is also too small.
        """
        if self._get_current_buffer().has_space(numel):
            return True
        if self.current_index == len(self.buffers) - 1:
            return False
        self.current_index += 1
        return self._get_current_buffer().has_space(numel)

    def swap_out(self, aio_handle, async_op=False):
        """Submit writes of all used swap tensors to their swap paths.

        Args:
            aio_handle: handle passed to ``swap_out_tensors`` and waited on.
            async_op: when False, block on ``aio_handle.wait()`` and check that
                it reports one completion per submitted tensor.

        Raises:
            RuntimeError: synchronous mode only, if the completion count differs.
        """
        swap_tensors = self.get_swap_tensors()
        swap_paths = self.get_swap_paths()
        assert all(p is not None for p in swap_paths)
        swap_out_tensors(aio_handle, swap_tensors, swap_paths)
        if not async_op:
            self._wait_for_completion(aio_handle, len(swap_tensors))

    def swap_in(self, aio_handle, async_op=False):
        """Submit reads of all used swap tensors from their swap paths.

        Args:
            aio_handle: handle passed to ``swap_in_tensors`` and waited on.
            async_op: when False, block on ``aio_handle.wait()`` and check that
                it reports one completion per submitted tensor.

        Raises:
            RuntimeError: synchronous mode only, if the completion count differs.
        """
        swap_tensors = self.get_swap_tensors()
        swap_paths = self.get_swap_paths()
        assert all(p is not None for p in swap_paths)
        swap_in_tensors(aio_handle, swap_tensors, swap_paths)
        if not async_op:
            self._wait_for_completion(aio_handle, len(swap_tensors))

    @staticmethod
    def _wait_for_completion(aio_handle, expected):
        """Block on ``aio_handle.wait()`` and compare its result with ``expected``."""
        # ``wait()`` must not live inside an ``assert``: ``python -O`` strips the
        # whole statement, which previously skipped the wait altogether.
        completed = aio_handle.wait()
        if completed != expected:
            raise RuntimeError(f'expected {expected} completed I/O requests, got {completed}')

    def _get_current_buffer(self):
        return self.buffers[self.current_index]

    def _get_used_buffers(self):
        # Buffers 0..current_index inclusive.
        return self.buffers[:self.current_index + 1]
class SwapBufferManager(object):
    """Owns ``count`` pinned CPU buffers of ``num_elems`` elements and lends them out.

    :meth:`allocate` and :meth:`allocate_all` hand out views of a buffer's
    first ``num_elems`` elements. :meth:`free` takes them back. Lent views are
    tracked by ``id()``, so callers must pass back the exact tensor objects
    they received.
    """

    def __init__(self, num_elems, count, dtype):
        self.num_elems = num_elems
        self.count = count
        self.dtype = dtype
        self.all_buffers = [
            get_accelerator().pin_memory(torch.zeros(num_elems, device='cpu', dtype=dtype), align_bytes=0)
            for _ in range(count)
        ]
        # Indices into ``all_buffers`` that are currently available.
        self.free_buffer_index = list(range(count))
        # Maps id() of each lent view to its index in ``all_buffers``.
        self.used_buffer_index = {}
        # Take the element size from the dtype, not from ``all_buffers[0]``,
        # so that ``count == 0`` does not raise IndexError.
        element_size = torch.empty(0, dtype=dtype).element_size()
        self.gigabytes = (element_size * num_elems * count) / (1024**3)
        if dist.get_rank() == 0:
            exclude_list = ['all_buffers']
            print_object(obj=self, name='SwapBufferManager', exclude_list=exclude_list)

    def allocate(self, num_elems, count, dtype):
        """Lend ``count`` free buffers, each narrowed to its first ``num_elems`` elements.

        Returns:
            A list of ``count`` tensor views, ``[]`` when ``count == 0``, or
            ``None`` if fewer than ``count`` buffers are free.
        """
        assert dtype == self.dtype
        assert num_elems <= self.num_elems
        if count > len(self.free_buffer_index):
            return None

        # Use an explicit split point rather than ``[-count:]``: with
        # count == 0, ``[-0:]`` is the whole list and handed out every free buffer.
        split = len(self.free_buffer_index) - count
        used_indices = self.free_buffer_index[split:]
        self.free_buffer_index = self.free_buffer_index[:split]

        buffers = []
        for i in used_indices:
            tmp_buffer = self.all_buffers[i].narrow(0, 0, num_elems)
            buffers.append(tmp_buffer)
            self.used_buffer_index[id(tmp_buffer)] = i
        return buffers

    def allocate_all(self, num_elems, dtype):
        """Lend every currently free buffer (see :meth:`allocate`)."""
        return self.allocate(num_elems=num_elems, count=len(self.free_buffer_index), dtype=dtype)

    def free(self, buffers):
        """Return views previously obtained from :meth:`allocate` to the free pool."""
        buffer_ids = [id(buf) for buf in buffers]
        assert all(b_id in self.used_buffer_index for b_id in buffer_ids)
        for b_id in buffer_ids:
            self.free_buffer_index.append(self.used_buffer_index.pop(b_id))
def get_sized_buffer(buffer, num_elems):
    """Return ``buffer`` limited to its first ``num_elems`` elements along dim 0.

    The original tensor object is returned unchanged when no narrowing is
    needed; otherwise a ``narrow`` view sharing its storage is returned.
    """
    capacity = buffer.numel()
    assert num_elems <= capacity, \
        f'num_elems {num_elems} > buffer {capacity}'
    if num_elems < capacity:
        return buffer.narrow(0, 0, num_elems)
    return buffer
def get_sized_buffers(buffer_list, num_elems_list):
    """Apply ``get_sized_buffer`` pairwise over ``buffer_list`` and ``num_elems_list``.

    Returns a new list; pairing stops at the shorter of the two inputs.
    """
    return list(map(get_sized_buffer, buffer_list, num_elems_list))
|