accelerated_dag_gpu_microbenchmark.py

# coding: utf-8
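"""Microbenchmark for sending torch.Tensors between Ray actors on GPU.

Compares several transports exercised below: plain torch copies (CPU<->GPU),
pickle, ray.put/ray.get, raw CUDA IPC handles, torch.distributed NCCL
send/recv, and Ray accelerated DAGs with TorchTensorType type hints (object
store or NCCL transport). Results are written as JSON to the path given by
TEST_OUTPUT_JSON (default /tmp/microbenchmark_gpu.json).
"""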

import logging
import torch
import ray.cloudpickle as pickle
import io
import cupy
import numpy as np
import time
import os
import json
import socket

import ray
from ray.air._internal import torch_utils
import ray.cluster_utils
from ray.dag import InputNode, DAGContext
from ray.util.collective.collective_group import nccl_util
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from ray._private.ray_microbenchmark_helpers import timeit

logger = logging.getLogger(__name__)

SHAPE = None
DTYPE = torch.float16

NUM_ITERS = 10
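

# TorchIpcWorker passes GPU tensors by raw CUDA IPC handle: send() exports a
# handle for the tensor's device memory (cupy.cuda.runtime.ipcGetMemHandle),
# and recv() reopens it (ipcOpenMemHandle), wraps it as UnownedMemory, and
# views it as a cupy.ndarray before copying it into a torch tensor to verify
# the value.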
@ray.remote
class TorchIpcWorker:
    def __init__(self):
        self.device = torch_utils.get_devices()[0]

    def send(self, shape, dtype, value: int):
        t = torch.ones(shape, dtype=dtype, device=self.device) * value
        if self.device.type == "cuda":
            # NOTE(swang): This is needed because the IPC can get sent before
            # the value has been written to memory. But somehow the read value
            # is still the wrong one?
            torch.cuda.synchronize()
        h = cupy.cuda.runtime.ipcGetMemHandle(t.data_ptr())
        return h

    def recv(self, device_ptr, num_bytes, shape, dtype):
        h = cupy.cuda.runtime.ipcOpenMemHandle(device_ptr)
        m = cupy.cuda.UnownedMemory(h, num_bytes, None)
        m_ptr = cupy.cuda.MemoryPointer(m, 0)
        tensor = torch.tensor(cupy.ndarray(shape, dtype, m_ptr), device=self.device)
        assert tensor.device == self.device
        return (tensor[0].item(), tensor.shape, tensor.dtype)
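

# TorchTensorWorker passes tensors by value through Ray. With accelerated
# DAGs, the TorchTensorType hint on the DAG edge (see exec_ray_dag below)
# decides whether the tensor travels through the object store or over NCCL.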
@ray.remote
class TorchTensorWorker:
    def __init__(self):
        self.device = torch_utils.get_devices()[0]

    def send(self, shape, dtype, value: int):
        t = torch.ones(shape, dtype=dtype, device=self.device) * value
        return t

    def recv(self, tensor):
        assert tensor.device == self.device
        return (tensor[0].item(), tensor.shape, tensor.dtype)
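

# NcclWorker benchmarks a raw torch.distributed send/recv over the NCCL
# backend between two ranks, as a baseline for the NCCL-backed DAG transport.
# It relies on the MASTER_ADDR / MASTER_PORT environment variables set in
# main().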
@ray.remote(num_gpus=1)
class NcclWorker:
    def __init__(self, rank):
        self.rank = rank

    def get_node_id(self):
        return ray.get_runtime_context().get_node_id()

    def init(self, world_size):
        from ray.air._internal import torch_utils

        self.device = torch_utils.get_devices()[0]
        self.world_size = world_size

        torch.distributed.init_process_group(
            backend="nccl",
            world_size=world_size,
            rank=self.rank,
        )

    def _send(self, buf, num_el, rank):
        torch.distributed.send(buf, rank)

    def _recv(self, buf, num_el, rank):
        torch.distributed.recv(buf, rank)

    def do_send_recv(self, shape, dtype):
        other_rank = (self.rank + 1) % self.world_size

        def _run():
            if self.rank == 0:
                i = np.random.randint(100)
                input_buffer = torch.ones(shape, dtype=dtype, device=self.device) * i
                self._send(input_buffer, input_buffer.numel(), other_rank)
            else:
                input_buffer = torch.zeros(shape, dtype=dtype, device=self.device)
                self._recv(input_buffer, input_buffer.numel(), other_rank)
            torch.cuda.synchronize()

        return timeit("exec_nccl_gpu", _run)
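

# exec_ray_dag builds a two-actor DAG (sender.send -> receiver.recv) and times
# one round trip per iteration. `use_adag` toggles experimental_compile vs. a
# classic Ray DAG, `use_nccl` selects the transport for the tensor edge, and
# `dynamic_shape` declares the shape/dtype as "auto" rather than statically.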
def exec_ray_dag(
    label, sender, receiver, use_nccl=False, use_adag=True, dynamic_shape=False
):
    # Test torch.Tensor sent between actors.
    with InputNode() as inp:
        dag = sender.send.bind(SHAPE, DTYPE, inp)

        if use_adag:
            dag = dag.with_type_hint(
                TorchTensorType(
                    "auto" if dynamic_shape else SHAPE,
                    "auto" if dynamic_shape else DTYPE,
                    transport="nccl" if use_nccl else "auto",
                )
            )

        dag = receiver.recv.bind(dag)

    if use_adag:
        dag = dag.experimental_compile()

        def _run():
            i = np.random.randint(100)
            ref = dag.execute(i)
            result = ray.get(ref)
            assert result == (i, SHAPE, DTYPE)

    else:

        def _run():
            i = np.random.randint(100)
            result = ray.get(dag.execute(i))
            assert result == (i, SHAPE, DTYPE)

    results = timeit(label, _run)

    if use_adag:
        dag.teardown()

    # Workaround for Ray bug in reusing GPUs too quickly.
    # See https://github.com/ray-project/ray/issues/44821.
    ray.kill(sender)
    ray.kill(receiver)
    time.sleep(1)

    return results


def exec_ray_dag_ipc(label, sender, receiver, use_nccl=False):
    # Test torch.Tensor sent between actors.
    with InputNode() as inp:
        dag = sender.send.bind(SHAPE, DTYPE, inp)
        dag = receiver.recv.bind(
            dag,
            # torch.float16 has item size of 2 bytes.
            SHAPE[0] * 2,
            SHAPE,
            nccl_util.TORCH_NUMPY_DTYPE_MAP[DTYPE],
        )

    compiled_dag = dag.experimental_compile(_buffer_size_bytes=int(SHAPE[0] * 3))
    # Flag that each run can set if it sees incorrect results.
    ok = [True]

    def _run():
        i = np.random.randint(100)
        ref = compiled_dag.execute(i)
        result = ray.get(ref)
        if result != (i, SHAPE, DTYPE):
            ok[0] = False

    results = timeit(label, _run)
    if not ok[0]:
        logger.warning("IPC DAG returned incorrect result")

    compiled_dag.teardown()
    return results


def _exec_torch_cpu_cpu():
    i = np.random.randint(100)
    t = torch.ones(SHAPE, dtype=DTYPE) * i
    t2 = t.to(copy=True)
    assert (t2[0].item(), t2.shape, t2.dtype) == (i, SHAPE, DTYPE)


def _exec_torch_gpu():
    i = np.random.randint(100)
    device_from = torch.device("cuda:1")
    device_to = torch.device("cuda:0")

    t = torch.ones(SHAPE, dtype=DTYPE, device=device_from) * i
    t2 = t.to(device_to)
    torch.cuda.synchronize(device_to)

    assert (t2[0].item(), t2.shape, t2.dtype) == (i, SHAPE, DTYPE)


def exec_nccl_gpu(sender_hint, receiver_hint):
    workers = [
        NcclWorker.options(scheduling_strategy=sender_hint).remote(0),
        NcclWorker.options(scheduling_strategy=receiver_hint).remote(1),
    ]

    # node_id = ray.get(workers[0].get_node_id.remote())
    # head_node = [node for node in ray.nodes() if node["NodeID"] == node_id]
    # assert len(head_node) == 1
    # head_node = head_node[0]
    # rank_0_addr = f"{head_node['NodeManagerAddress']}:8888"

    ray.get([worker.init.remote(2) for worker in workers])

    tasks = [worker.do_send_recv.remote(SHAPE, DTYPE) for worker in workers]
    done_refs, _ = ray.wait(tasks, num_returns=1)
    results = ray.get(done_refs[0])

    # Workaround for Ray bug in reusing GPUs too quickly.
    # See https://github.com/ray-project/ray/issues/44821.
    for worker in workers:
        ray.kill(worker)
    time.sleep(1)

    return results


def _exec_torch_gpu_cpu_gpu():
    i = np.random.randint(100)
    device_from = torch.device("cuda:0")
    device_to = torch.device("cuda:1")

    t = torch.ones(SHAPE, dtype=DTYPE, device=device_from) * i
    t = t.to("cpu")
    t2 = t.to(device_to)
    torch.cuda.synchronize(device_to)

    assert (t2[0].item(), t2.shape, t2.dtype) == (i, SHAPE, DTYPE)


def _exec_pickle_cpu():
    i = np.random.randint(100)
    t = torch.ones(SHAPE, dtype=DTYPE) * i
    byte_stream = io.BytesIO()
    pickle.dump(t, byte_stream)
    byte_stream.seek(0)
    pickle.load(byte_stream)


def _exec_pickle_gpu():
    i = np.random.randint(100)
    t = torch.ones(SHAPE, dtype=DTYPE, device="cuda") * i
    byte_stream = io.BytesIO()
    pickle.dump(t, byte_stream)
    byte_stream.seek(0)
    pickle.load(byte_stream)


def _exec_ray_put_cpu():
    i = np.random.randint(100)
    t = torch.ones(SHAPE, dtype=DTYPE) * i
    ray.get(ray.put(t))


def _exec_ray_put_np_zero_copy():
    i = np.random.randint(100)
    t = torch.ones(SHAPE, dtype=DTYPE) * i
    torch.as_tensor(ray.get(ray.put(t.numpy())))


def _exec_ray_put_gpu():
    i = np.random.randint(100)
    t = torch.ones(SHAPE, dtype=DTYPE, device="cuda") * i
    ray.get(ray.put(t))


def exec_ray_dag_cpu(sender_hint, receiver_hint):
    sender = TorchTensorWorker.options(scheduling_strategy=sender_hint).remote()
    receiver = TorchTensorWorker.options(scheduling_strategy=receiver_hint).remote()
    return exec_ray_dag("exec_ray_dag_cpu", sender, receiver)


def exec_ray_core_cpu(sender_hint, receiver_hint):
    time.sleep(1)
    sender = TorchTensorWorker.options(scheduling_strategy=sender_hint).remote()
    receiver = TorchTensorWorker.options(scheduling_strategy=receiver_hint).remote()
    return exec_ray_dag("exec_ray_core_cpu", sender, receiver, use_adag=False)


def exec_ray_dag_gpu_ipc_gpu():
    time.sleep(1)
    sender = TorchIpcWorker.options(num_gpus=1).remote()
    receiver = TorchIpcWorker.options(num_gpus=1).remote()
    return exec_ray_dag_ipc("exec_ray_dag_gpu_ipc_gpu", sender, receiver)


def exec_ray_dag_gpu_cpu_gpu(sender_hint, receiver_hint):
    time.sleep(1)
    sender = TorchTensorWorker.options(
        num_gpus=1, scheduling_strategy=sender_hint
    ).remote()
    receiver = TorchTensorWorker.options(
        num_gpus=1, scheduling_strategy=receiver_hint
    ).remote()
    return exec_ray_dag("exec_ray_dag_gpu_cpu_gpu", sender, receiver)


def exec_ray_dag_gpu_nccl(sender_hint, receiver_hint, dynamic_shape: bool = False):
    time.sleep(1)
    sender = TorchTensorWorker.options(
        num_gpus=1, scheduling_strategy=sender_hint
    ).remote()
    receiver = TorchTensorWorker.options(
        num_gpus=1, scheduling_strategy=receiver_hint
    ).remote()
    return exec_ray_dag(
        "exec_ray_dag_gpu_nccl" + ("_dynamic" if dynamic_shape else ""),
        sender,
        receiver,
        use_nccl=True,
        dynamic_shape=dynamic_shape,
    )


def exec_ray_core_gpu(sender_hint, receiver_hint):
    time.sleep(1)
    sender = TorchTensorWorker.options(
        num_gpus=1, scheduling_strategy=sender_hint
    ).remote()
    receiver = TorchTensorWorker.options(
        num_gpus=1, scheduling_strategy=receiver_hint
    ).remote()
    return exec_ray_dag("exec_ray_core_gpu", sender, receiver, use_adag=False)
def main(distributed):
    results = []

    ray.init(
        runtime_env={
            "env_vars": {
                "CUDA_VISIBLE_DEVICES": "0,1",
                # Needed for torch distributed.
                "MASTER_ADDR": socket.gethostbyname(socket.gethostname()),
                "MASTER_PORT": "8888",
            }
        }
    )

    # NCCL takes a while to warm up on multi node so increase the default
    # timeout.
    ctx = DAGContext.get_current()
    ctx.retrieval_timeout = 120

    sender_hint, receiver_hint = None, None
    if distributed:
        local_node_id = ray.get_runtime_context().get_node_id()
        node_ids = [node["NodeID"] for node in ray.nodes()]
        remote_node_ids = [node_id for node_id in node_ids if node_id != local_node_id]
        assert remote_node_ids
        remote_node_id = remote_node_ids[0]

        # Pin sender on local node and receiver on the other node for
        # consistent results.
        sender_hint = ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
            local_node_id, soft=False
        )
        receiver_hint = ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
            remote_node_id, soft=False
        )

    if not distributed:
        results += timeit("exec_torch_cpu_cpu", _exec_torch_cpu_cpu)
        results += timeit("exec_torch_gpu", _exec_torch_gpu)
        results += timeit("exec_torch_gpu_cpu_gpu", _exec_torch_gpu_cpu_gpu)

    results += exec_nccl_gpu(sender_hint, receiver_hint)

    if not distributed:
        results += timeit("exec_ray_put_cpu", _exec_ray_put_cpu)
        results += timeit("exec_ray_put_np_zero_copy", _exec_ray_put_np_zero_copy)
        results += timeit("exec_ray_put_gpu", _exec_ray_put_gpu)

    results += exec_ray_core_cpu(sender_hint, receiver_hint)
    results += exec_ray_dag_cpu(sender_hint, receiver_hint)
    results += exec_ray_core_gpu(sender_hint, receiver_hint)
    results += exec_ray_dag_gpu_cpu_gpu(sender_hint, receiver_hint)
    results += exec_ray_dag_gpu_nccl(sender_hint, receiver_hint, dynamic_shape=True)
    results += exec_ray_dag_gpu_nccl(sender_hint, receiver_hint, dynamic_shape=False)

    return results


def to_dict_key(key: str):
    for r in [" ", ":", "-"]:
        key = key.replace(r, "_")
    for r in ["(", ")"]:
        key = key.replace(r, "")
    return key
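

# Example invocation (assumes a machine with at least two visible GPUs):
#   python accelerated_dag_gpu_microbenchmark.py --tensor-size-bytes 100000
# Pass --distributed when the cluster spans more than one node. Results are
# written to the path in TEST_OUTPUT_JSON (default /tmp/microbenchmark_gpu.json).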
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tensor-size-bytes",
        type=int,
        # 100KB
        default=100_000,
    )
    parser.add_argument(
        "--distributed",
        action="store_true",
        help="Whether this is running on more than one node",
    )
    args = parser.parse_args()

    # Divide by 2 because we're using torch.float16.
    SHAPE = (args.tensor_size_bytes // 2,)

    results = main(args.distributed)

    result_dict = {
        f"{to_dict_key(v[0])}": (v[1], v[2]) for v in results if v is not None
    }

    perf_metrics = [
        {
            "perf_metric_name": to_dict_key(v[0]),
            "perf_metric_value": v[1],
            "perf_metric_type": "THROUGHPUT",
        }
        for v in results
        if v is not None
    ]
    result_dict["perf_metrics"] = perf_metrics

    test_output_json = os.environ.get(
        "TEST_OUTPUT_JSON", "/tmp/microbenchmark_gpu.json"
    )

    with open(test_output_json, "wt") as f:
        json.dump(result_dict, f)