contiguous_memory_allocator.py

import torch


def print_rank_0(message):
    if torch.distributed.get_rank() == 0:
        print(message)

class ContiguousMemoryAllocator(object):

    def __init__(self, size, dtype, device):
        self.buffer = torch.zeros(size, dtype=dtype, device=device)

        #address to size of the contiguous free block starting at that address
        self.contiguous_sizes = {}
        self.contiguous_sizes[0] = size

        #tensor id to its address
        self.tensor_addresses = {}
        #tensor address to its size
        self.tensor_sizes = {}
        #tensor address to its id
        self.tensor_ids = {}
        #id to tensor
        self.tensor_map = {}
        #id to params. Maps each tensor buffer to the list of parameters that use it
        self.id_to_params = {}

        self.total_size = size
        self.total_free = size
        self.largest_contiguous = size
        self.max_allocated = 0
        self.count = 0

    #creates a tensor of the given size from the pre-allocated buffer
    #fails if there is not enough free space
    #if there is enough free space but not enough contiguous space,
    #defragments the buffer and then allocates
    def allocate_tensor(self, size):
        free_before = self.total_free

        assert size <= self.total_free, "Not enough memory in buffer. Allocation failed"
        if self.largest_contiguous < size:
            print_rank_0("Needs defragmentation to allocate. Before defragmentation:")
            self.print_allocation(resolution=100)
            self._defragment_memory()
            #point the param data at the new tensor buffer locations
            self._reset_param_data()
            print_rank_0("After defragmentation:")
            self.print_allocation(resolution=100)

        self.total_free = self.total_free - size

        allocated = self.total_size - self.total_free
        if allocated > self.max_allocated:
            self.max_allocated = allocated

        tensor_address = self._get_new_tensor_address(size)

        ret_tensor = self._get_new_tensor(tensor_address, size)
        print_rank_0(
            f"Free before allocation {free_before}. Allocating {size}. Free after allocation {self.total_free}. Max allocated {self.max_allocated}"
        )
        assert self.total_free + size == free_before, "Allocation bookkeeping error"

        return ret_tensor
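
    #Example (hypothetical sizes): with a 1024-element buffer, allocating three
    #256-element tensors and releasing the middle one leaves 512 elements free
    #but only 256 contiguous, so a subsequent allocate_tensor(384) must take
    #the defragmentation path above before it can succeed.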

    #assigns the tensor data to the param data and keeps track of the assignment
    #any change to the underlying buffer from defragmentation will cause a
    #reassignment of the param data
    def assign_to_param(self, tensor, param, numel, shape):
        tensor_id = id(tensor)

        assert tensor_id in self.tensor_map.keys(), "No such tensor allocated by the allocator."
        assert tensor.numel() >= numel, "Assigned tensor buffer is not large enough"
        assert tensor_id not in self.id_to_params.keys(), "This tensor has already been assigned to a param"

        self.id_to_params[tensor_id] = [param]

        replicated_tensor = tensor.narrow(0, 0, numel).view(shape)
        param.data = replicated_tensor.data
        param.contiguous_tensor_id = tensor_id
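
    #Example (hypothetical shapes): assigning a 64-element tensor to a param of
    #shape (8, 8) views the first 64 buffer elements as an 8x8 matrix, so the
    #param aliases the allocator's buffer instead of owning its own storage.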

    #deletes the tensor and frees up the underlying buffer
    def release_tensor(self, tensor):
        free_before = self.total_free
        tensor_id = id(tensor)
        tensor_size = tensor.numel()
        self._release_tensor(tensor_id)
        self._unassign_params(tensor_id)
        self.total_free += tensor_size
        print_rank_0(
            f"Free before release {free_before}. Released {tensor.numel()}. Total free after {self.total_free}."
        )
        assert self.total_free - tensor_size == free_before, "Release bookkeeping error"

    def release_tensor_with_id(self, tensor_id):
        free_before = self.total_free
        assert tensor_id in self.tensor_map.keys(), "Invalid tensor id"
        tensor = self.tensor_map[tensor_id]
        tensor_size = tensor.numel()
        self._release_tensor(tensor_id)
        self._unassign_params(tensor_id)
        self.total_free += tensor_size
        print_rank_0(
            f"Free before release {free_before}. Released {tensor.numel()}. Total free after {self.total_free}."
        )
        assert self.total_free - tensor_size == free_before, "Release bookkeeping error"

    #shows the current memory allocation at the specified resolution
    def print_allocation(self, resolution=200):
        total_size = self.buffer.numel() * 1.0
        empty = []
        for addr, size in self.contiguous_sizes.items():
            start = int(addr * resolution / total_size)
            end = int((addr + size) * resolution / total_size)
            empty.extend(range(start, end))
        s = ''
        for i in range(resolution):
            s += '.' if i in empty else '|'
        print_rank_0(s)
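
    #Example (hypothetical state): at resolution=16, a buffer whose second half
    #is free prints "||||||||........", where '|' is occupied and '.' is free.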

    #self.max_allocated (set in __init__) shadows any method of the same name,
    #so the accessor is renamed to keep it callable
    def get_max_allocated(self):
        return self.max_allocated

    #to be called after a defragmentation that moves the tensor buffers
    #this call reassigns the data of all the parameters using the tensor buffers
    def _reset_param_data(self):
        for tensor_id, tensor in self.tensor_map.items():
            #tensors that were never assigned to a param have no entry here
            for param in self.id_to_params.get(tensor_id, []):
                param.data = tensor.narrow(0, 0, param.numel()).view(param.data.shape).data

    def _unassign_params(self, tensor_id):
        if tensor_id in self.id_to_params.keys():
            del self.id_to_params[tensor_id]

    def _release_tensor(self, tensor_id):
        assert tensor_id in self.tensor_addresses, f"Tensor id {tensor_id} not found"
        address = self.tensor_addresses[tensor_id]
        contiguous_size = self.tensor_map[tensor_id].numel()

        del self.tensor_addresses[tensor_id]
        del self.tensor_ids[address]
        del self.tensor_map[tensor_id]
        del self.tensor_sizes[address]

        self._consolidate_address(address, contiguous_size)
        self.largest_contiguous = self._largest_contiguous()

    def _consolidate_address(self, address, contiguous_size):
        #merge with the next free block if it starts where this one ends
        end_address = address + contiguous_size
        if end_address in self.contiguous_sizes:
            contiguous_size += self.contiguous_sizes[end_address]
            del self.contiguous_sizes[end_address]

        #merge with the previous free block if it ends where this one starts
        for addr, size in self.contiguous_sizes.items():
            if addr + size == address:
                del self.contiguous_sizes[addr]
                contiguous_size += size
                address = addr
                break

        self.contiguous_sizes[address] = contiguous_size
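
    #Example (hypothetical state): with free blocks {0: 64, 192: 64}, freeing
    #the region [64, 192) merges all three into a single entry {0: 256}.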

    def _defragment_memory(self):
        empty_addresses = sorted(self.contiguous_sizes.keys())
        tensor_addresses = sorted(self.tensor_addresses.values())

        tensor_index = 0

        while tensor_index < len(tensor_addresses):
            empty_addr = empty_addresses[0]
            empty_size = self.contiguous_sizes[empty_addr]

            tensor_addr = tensor_addresses[tensor_index]
            tensor_size = self.tensor_sizes[tensor_addr]
            tensor_id = self.tensor_ids[tensor_addr]
            tensor = self.tensor_map[tensor_id]

            assert tensor_size == tensor.numel(), \
                f"Size mismatch. {tensor_size} is allocated at addr {tensor_addr} but tensor size is {tensor.numel()}"

            assert empty_addr != tensor_addr, \
                f"Cannot have same empty address {empty_addr} and tensor address {tensor_addr}"

            if empty_addr < tensor_addr:
                if empty_size >= tensor_size:
                    #the free block can hold the whole tensor, copy it in one shot
                    dest_buffer = self.buffer.narrow(0, empty_addr, tensor_size)
                    src_buffer = self.buffer.narrow(0, tensor_addr, tensor_size)
                    dest_buffer.data.copy_(src_buffer.data)
                else:
                    #the free block is smaller than the tensor, so the source and
                    #destination ranges overlap; copy forward in chunks no larger
                    #than the gap so unread source data is never overwritten
                    src_addr = tensor_addr
                    dest_addr = empty_addr
                    while src_addr < (tensor_addr + tensor_size):
                        copy_size = min(empty_size, tensor_addr + tensor_size - src_addr)

                        dest_buffer = self.buffer.narrow(0, dest_addr, copy_size)
                        src_buffer = self.buffer.narrow(0, src_addr, copy_size)

                        dest_buffer.data.copy_(src_buffer.data)

                        src_addr += copy_size
                        dest_addr += copy_size

                self._replace_old_address_with_new(tensor_id, empty_addr)

            tensor_index += 1
            empty_addresses = sorted(self.contiguous_sizes.keys())
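
    #Example (hypothetical state): a 10-element tensor at address 4 with a free
    #block of size 4 at address 0 is copied forward in chunks of 4, 4, and 2,
    #sliding the tensor down to address 0 without clobbering uncopied elements.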

    def _replace_old_address_with_new(self, tensor_id, new_address):
        tensor = self.tensor_map[tensor_id]
        tensor_size = tensor.numel()

        tensor.data = self.buffer.narrow(0, new_address, tensor_size).data

        self._release_tensor(tensor_id)
        self._mark_as_occupied(new_address, tensor_size)

        self.tensor_ids[new_address] = tensor_id
        self.tensor_map[tensor_id] = tensor
        self.tensor_addresses[tensor_id] = new_address
        self.tensor_sizes[new_address] = tensor_size

    #best-fit search: returns the smallest free block that can hold the request
    def _get_new_tensor_address(self, size):
        tensor_address = None
        for address, contiguous_size in self.contiguous_sizes.items():
            if contiguous_size >= size and \
                    (tensor_address is None
                     or contiguous_size < self.contiguous_sizes[tensor_address]):
                tensor_address = address

        assert tensor_address is not None, "address cannot be None"
        return tensor_address

    def _get_new_tensor(self, address, size):
        available_contiguous_size = self.contiguous_sizes[address]

        assert size <= available_contiguous_size, \
            f"Tensor numel {size} is larger than available contiguous size {available_contiguous_size}"

        self.count += 1
        new_tensor = self.buffer.narrow(0, address, size)
        tensor_id = id(new_tensor)

        self.tensor_addresses[tensor_id] = address
        self.tensor_sizes[address] = size
        self.tensor_ids[address] = tensor_id
        self.tensor_map[tensor_id] = new_tensor

        self._mark_as_occupied(address, size)

        return new_tensor

    def _largest_contiguous(self):
        if len(self.contiguous_sizes) > 0:
            return max(self.contiguous_sizes.values())
        else:
            return 0

    def _mark_as_occupied(self, address, size):
        available_contiguous_size = self.contiguous_sizes[address]
        del self.contiguous_sizes[address]

        #if the block was larger than the request, keep the remainder free
        if available_contiguous_size != size:
            self.contiguous_sizes[address + size] = available_contiguous_size - size

        self.largest_contiguous = self._largest_contiguous()
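

#A minimal usage sketch, not part of the original file: it assumes a single
#process and creates a one-rank gloo group only because print_rank_0 calls
#torch.distributed.get_rank(). The buffer size, tensor sizes, and "cpu"
#device below are arbitrary illustration choices.
if __name__ == "__main__":
    import os
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    torch.distributed.init_process_group("gloo", rank=0, world_size=1)

    allocator = ContiguousMemoryAllocator(1024, torch.float32, "cpu")

    #three back-to-back allocations occupy addresses 0, 256, and 512
    t1 = allocator.allocate_tensor(256)
    t2 = allocator.allocate_tensor(256)
    t3 = allocator.allocate_tensor(256)

    #releasing the middle tensor leaves 512 elements free, only 256 contiguous
    allocator.release_tensor(t2)

    #requesting 384 elements now forces the defragmentation path
    t4 = allocator.allocate_tensor(384)
    allocator.print_allocation(resolution=64)

    torch.distributed.destroy_process_group()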