'''Copyright The Microsoft DeepSpeed Team'''

import torch
from deepspeed import comm as dist


def print_rank_0(message):
    if dist.get_rank() == 0:
        print(message)


class ContiguousMemoryAllocator(object):
    def __init__(self, size, dtype, device):
        self.buffer = torch.zeros(size, dtype=dtype, device=device)

        #address to contiguous size available
        self.contiguous_sizes = {}
        self.contiguous_sizes[0] = size

        #tensor id to its address
        self.tensor_addresses = {}

        #tensor address to its size
        self.tensor_sizes = {}

        #tensor address to ids
        self.tensor_ids = {}

        #id to tensors
        self.tensor_map = {}

        #id to params. Maps each tensor buffer to the list of parameters that use it
        self.id_to_params = {}

        self.total_size = size
        self.total_free = size
        self.largest_contiguous = size
        self.max_allocated = 0

        self.count = 0

    #creates a tensor of the given size from the pre-allocated buffer
    #fails if there is not enough free space in the buffer
    #if the free space is not contiguous, defragments and then allocates
    def allocate_tensor(self, size):
        free_before = self.total_free

        assert size <= self.total_free, "Not enough memory in buffer. Allocation failed"
        if self.largest_contiguous < size:
            print_rank_0("Needs defragmentation to allocate. Before defragmentation:")
            self.print_allocation(resolution=100)
            self._defragment_memory()
            #set the param data to the new tensor buffer locations
            self._reset_param_data()
            print_rank_0("After defragmentation:")
            self.print_allocation(resolution=100)

        self.total_free = self.total_free - size

        allocated = self.total_size - self.total_free
        if allocated > self.max_allocated:
            self.max_allocated = allocated

        tensor_address = self._get_new_tensor_address(size)

        ret_tensor = self._get_new_tensor(tensor_address, size)
        print_rank_0(
            f"Free before allocation {free_before}. Allocating {size}. Free after allocation {self.total_free}. Max allocated {self.max_allocated}"
        )
        assert self.total_free + size == free_before, "Allocation bookkeeping error"

        return ret_tensor
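
    #Usage sketch with hypothetical sizes (not from the original source):
    #on a 1024-element allocator,
    #    a = allocator.allocate_tensor(512)
    #    b = allocator.allocate_tensor(256)
    #    allocator.release_tensor(a)
    #    c = allocator.allocate_tensor(768)
    #the last call finds 768 elements free but only 512 of them contiguous,
    #so it defragments first: `b` is compacted to address 0, leaving a
    #contiguous 768-element region for `c`.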
    #assigns the tensor data to the param data and keeps track of the assignment
    #any change to the underlying buffer from defragmentation will cause a
    #reassignment of the param data
    def assign_to_param(self, tensor, param, numel, shape):
        tensor_id = id(tensor)

        assert tensor_id in self.tensor_map.keys(), "No such tensor allocated by the allocator."
        assert tensor.numel() >= numel, "Assigned tensor buffer is not large enough"
        assert tensor_id not in self.id_to_params.keys(), "This tensor has already been assigned to a param"

        self.id_to_params[tensor_id] = [param]

        replicated_tensor = tensor.narrow(0, 0, numel).view(shape)
        param.data = replicated_tensor.data
        param.contiguous_tensor_id = tensor_id
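
    #Usage sketch (hypothetical parameter, not from the original source):
    #backing a (16, 16) torch.nn.Parameter with a 256-element slice of the
    #shared buffer:
    #    t = allocator.allocate_tensor(256)
    #    allocator.assign_to_param(t, param, numel=256, shape=(16, 16))
    #if a later defragmentation moves `t`, _reset_param_data re-points
    #param.data at the tensor's new location in the buffer.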
    #deletes the tensor and frees up the underlying buffer
    def release_tensor(self, tensor):
        free_before = self.total_free
        tensor_id = id(tensor)
        tensor_size = tensor.numel()
        self._release_tensor(tensor_id)
        self._unassign_params(tensor_id)
        self.total_free += tensor_size
        print_rank_0(
            f"Free before release {free_before}. Released {tensor.numel()}. Total free after {self.total_free}."
        )
        assert self.total_free - tensor_size == free_before, "Release bookkeeping error"

    def release_tensor_with_id(self, tensor_id):
        free_before = self.total_free
        assert tensor_id in self.tensor_map.keys(), "Invalid tensor id"
        tensor = self.tensor_map[tensor_id]
        tensor_size = tensor.numel()
        self._release_tensor(tensor_id)
        self._unassign_params(tensor_id)
        self.total_free += tensor_size
        print_rank_0(
            f"Free before release {free_before}. Released {tensor.numel()}. Total free after {self.total_free}."
        )
        assert self.total_free - tensor_size == free_before, "Release bookkeeping error"
    #shows the current memory allocation at the specified resolution
    def print_allocation(self, resolution=200):
        total_size = self.buffer.numel() * 1.0
        empty = []
        for addr, size in self.contiguous_sizes.items():
            start = int(addr * resolution / total_size)
            end = int((addr + size) * resolution / total_size)
            empty.extend(range(start, end))
        s = ''
        for i in range(resolution):
            s += '.' if i in empty else '|'
        print_rank_0(s)
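
    #Illustration (hypothetical state, not actual logged output): for a
    #1000-element buffer holding a single 250-element tensor at address 0,
    #print_allocation(resolution=20) renders
    #    |||||...............
    #where '|' marks occupied cells and '.' marks free cells.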
    #the attribute self.max_allocated set in __init__ shadows any method of
    #the same name on instances, so the accessor needs a distinct name
    def get_max_allocated(self):
        return self.max_allocated
    #to be called after a defragmentation that moves the tensor buffers
    #this call reassigns the data of all the parameters using the tensor buffers
    def _reset_param_data(self):
        for tensor_id, tensor in self.tensor_map.items():
            #tensors that were never assigned to a param have no entry
            for param in self.id_to_params.get(tensor_id, []):
                param.data = tensor.narrow(0, 0, param.numel()).view(param.data.shape).data
    def _unassign_params(self, tensor_id):
        if tensor_id in self.id_to_params.keys():
            del self.id_to_params[tensor_id]

    def _release_tensor(self, tensor_id):
        assert tensor_id in self.tensor_addresses, f"Tensor id {tensor_id} not found"

        address = self.tensor_addresses[tensor_id]
        contiguous_size = self.tensor_map[tensor_id].numel()

        del self.tensor_addresses[tensor_id]
        del self.tensor_ids[address]
        del self.tensor_map[tensor_id]
        del self.tensor_sizes[address]

        self._consolidate_address(address, contiguous_size)
        self.largest_contiguous = self._largest_contiguous()
    def _consolidate_address(self, address, contiguous_size):
        #consolidate with the next free buffer
        end_address = address + contiguous_size
        if end_address in self.contiguous_sizes:
            contiguous_size += self.contiguous_sizes[end_address]
            del self.contiguous_sizes[end_address]

        #consolidate with the previous free buffer
        for addr, size in self.contiguous_sizes.items():
            if addr + size == address:
                del self.contiguous_sizes[addr]
                contiguous_size += size
                address = addr
                break

        self.contiguous_sizes[address] = contiguous_size
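
    #Worked example with hypothetical values: given free blocks {0: 4, 10: 6},
    #a call with address=4 and contiguous_size=6 first merges the following
    #block (end_address 4 + 6 = 10), then the preceding block ending at
    #address 4, leaving the single free block {0: 16}.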
    def _defragment_memory(self):
        empty_addresses = sorted(self.contiguous_sizes.keys())
        tensor_addresses = sorted(self.tensor_addresses.values())

        tensor_index = 0

        while tensor_index < len(tensor_addresses):

            empty_addr = empty_addresses[0]
            empty_size = self.contiguous_sizes[empty_addr]

            tensor_addr = tensor_addresses[tensor_index]
            tensor_size = self.tensor_sizes[tensor_addr]
            tensor_id = self.tensor_ids[tensor_addr]
            tensor = self.tensor_map[self.tensor_ids[tensor_addr]]

            assert tensor_size == tensor.numel(), \
                f"Size mismatch. {tensor_size} is allocated at addr {tensor_addr} but tensor size is {tensor.numel()}"

            assert empty_addr != tensor_addr, \
                f"Cannot have same empty address {empty_addr} and tensor address {tensor_addr}"

            if empty_addr < tensor_addr:

                if empty_size >= tensor_size:
                    dest_buffer = self.buffer.narrow(0, empty_addr, tensor_size)
                    src_buffer = self.buffer.narrow(0, tensor_addr, tensor_size)
                    dest_buffer.data.copy_(src_buffer.data)
                else:
                    #the free block is smaller than the tensor, so move the
                    #tensor down in chunks of at most empty_size elements;
                    #each chunk's destination never overlaps its source
                    src_addr = tensor_addr
                    dest_addr = empty_addr
                    while src_addr < (tensor_addr + tensor_size):
                        copy_size = min(empty_size, tensor_addr + tensor_size - src_addr)

                        dest_buffer = self.buffer.narrow(0, dest_addr, copy_size)
                        src_buffer = self.buffer.narrow(0, src_addr, copy_size)

                        dest_buffer.data.copy_(src_buffer.data)

                        src_addr += copy_size
                        dest_addr += copy_size

                self._replace_old_address_with_new(tensor_id, empty_addr)

                tensor_index += 1

            else:
                tensor_index += 1

            empty_addresses = sorted(self.contiguous_sizes.keys())
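
    #Worked example with hypothetical state: tensors of 256 elements at
    #addresses 256 and 768 with free blocks {0: 256, 512: 256}. The first
    #pass copies the tensor at 256 down to address 0 and consolidates the
    #freed space; the second pass copies the tensor at 768 down to 256,
    #leaving the tensors packed at the front and one free block {512: 512}.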
    def _replace_old_address_with_new(self, tensor_id, new_address):
        tensor = self.tensor_map[tensor_id]
        tensor_size = tensor.numel()

        tensor.data = self.buffer.narrow(0, new_address, tensor_size).data

        self._release_tensor(tensor_id)
        self._mark_as_occupied(new_address, tensor_size)

        self.tensor_ids[new_address] = tensor_id
        self.tensor_map[tensor_id] = tensor
        self.tensor_addresses[tensor_id] = new_address
        self.tensor_sizes[new_address] = tensor_size
    #best-fit search: returns the smallest free block that can hold `size`
    def _get_new_tensor_address(self, size):
        tensor_address = None
        for address, contiguous_size in self.contiguous_sizes.items():
            if contiguous_size >= size and \
                    (tensor_address is None or
                     contiguous_size < self.contiguous_sizes[tensor_address]):
                tensor_address = address

        assert tensor_address is not None, "address cannot be None"
        return tensor_address
    def _get_new_tensor(self, address, size):
        available_contiguous_size = self.contiguous_sizes[address]

        assert size <= available_contiguous_size, \
            f"Tensor numel {size} is larger than available contiguous size {available_contiguous_size}"

        self.count += 1
        new_tensor = self.buffer.narrow(0, address, size)
        tensor_id = id(new_tensor)
        self.tensor_addresses[tensor_id] = address
        self.tensor_sizes[address] = size
        self.tensor_ids[address] = tensor_id
        self.tensor_map[tensor_id] = new_tensor

        self._mark_as_occupied(address, size)

        return new_tensor
    def _largest_contiguous(self):
        if len(self.contiguous_sizes) > 0:
            return max(self.contiguous_sizes.values())
        else:
            return 0
    #carves `size` elements out of the free block at `address`; any
    #remainder stays registered as a smaller free block
    def _mark_as_occupied(self, address, size):
        available_contiguous_size = self.contiguous_sizes[address]
        del self.contiguous_sizes[address]

        if available_contiguous_size != size:
            self.contiguous_sizes[address + size] = available_contiguous_size - size

        self.largest_contiguous = self._largest_contiguous()
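

#Minimal usage sketch, not part of the original module. It assumes a
#single-process run where dist.get_rank() returns 0; on a bare interpreter
#without an initialized backend you could monkeypatch it first, e.g.
#`dist.get_rank = lambda: 0`.
if __name__ == "__main__":
    allocator = ContiguousMemoryAllocator(1024, torch.float32, 'cpu')

    a = allocator.allocate_tensor(512)
    b = allocator.allocate_tensor(256)
    allocator.print_allocation(resolution=32)

    #freeing `a` leaves 768 elements free but only 512 of them contiguous
    allocator.release_tensor(a)

    #this allocation therefore triggers defragmentation before it succeeds
    c = allocator.allocate_tensor(768)
    allocator.print_allocation(resolution=32)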