torch_benchmark.py

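"""Torch benchmark: Ray Train (TorchTrainer) vs. vanilla PyTorch DDP.

Trains the same FashionMNIST classifier with both setups, measures total and
per-worker (local) training times, writes the collected metrics to the file
named by ``TEST_OUTPUT_JSON`` (default: /tmp/result.json), and raises if the
Ray run is more than 1.15x slower than the vanilla run. Exposes two click
commands: ``run`` (the full benchmark) and ``worker`` (a single vanilla DDP
worker, launched internally by the benchmark).
"""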
import json
import os
import time
from pathlib import Path
from typing import Dict, Tuple

import click
import numpy as np
import torch
from torch import nn, distributed
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data.dataloader import default_collate
from torchvision import datasets
from torchvision.transforms import ToTensor

CONFIG = {"lr": 1e-3, "batch_size": 64}
VANILLA_RESULT_JSON = "/tmp/vanilla_out.json"


def find_network_interface():
    # Prefer an "ens*" interface for NCCL; otherwise fall back to an
    # exclusion pattern (NCCL_SOCKET_IFNAME syntax: "^" excludes interfaces).
    for iface in os.listdir("/sys/class/net"):
        if iface.startswith("ens"):
            network_interface = iface
            break
    else:
        network_interface = "^lo,docker"
    return network_interface


# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
def train_epoch(
    dataloader, model, loss_fn, optimizer, world_size: int, local_rank: int
):
    size = len(dataloader.dataset) // world_size
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"[rank={local_rank}] loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


def validate_epoch(dataloader, model, loss_fn, world_size: int, local_rank: int):
    size = len(dataloader.dataset) // world_size
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"[rank={local_rank}] Test Error: \n "
        f"Accuracy: {(100 * correct):>0.1f}%, "
        f"Avg loss: {test_loss:>8f} \n"
    )
    return test_loss
def train_func(use_ray: bool, config: Dict):
    local_start_time = time.monotonic()

    if use_ray:
        from ray.air import session
        import ray.train as train

    batch_size = config["batch_size"]
    lr = config["lr"]
    epochs = config["epochs"]
    shuffle = config.get("shuffle", False)

    if use_ray:
        world_size = session.get_world_size()
        local_rank = distributed.get_rank()
    else:
        world_size = distributed.get_world_size()
        local_rank = distributed.get_rank()

    worker_batch_size = batch_size // world_size

    # Load datasets. Use download=False to catch errors in preparation, as the
    # data should have already been downloaded.
    training_data = datasets.FashionMNIST(
        root="/tmp/data_fashion_mnist",
        train=True,
        download=False,
        transform=ToTensor(),
    )

    test_data = datasets.FashionMNIST(
        root="/tmp/data_fashion_mnist",
        train=False,
        download=False,
        transform=ToTensor(),
    )

    if use_ray:
        # Ray adds DistributedSampler in train.torch.prepare_data_loader below
        training_sampler = None
        test_sampler = None
    else:
        # In vanilla PyTorch we create the distributed sampler here
        training_sampler = DistributedSampler(training_data, shuffle=shuffle)
        test_sampler = DistributedSampler(test_data, shuffle=shuffle)

    if not use_ray and config.get("use_gpu", False):
        assert torch.cuda.is_available(), "No GPUs available"
        gpu_id = config.get("gpu_id", 0)
        vanilla_device = torch.device(f"cuda:{gpu_id}")
        torch.cuda.set_device(vanilla_device)

        print(
            "Setting GPU ID to",
            gpu_id,
            "with visible devices",
            os.environ.get("CUDA_VISIBLE_DEVICES"),
        )

        def collate_fn(x):
            return tuple(x_.to(vanilla_device) for x_ in default_collate(x))

    else:
        vanilla_device = torch.device("cpu")
        collate_fn = None

    # Create data loaders and potentially pass distributed sampler
    train_dataloader = DataLoader(
        training_data,
        shuffle=shuffle,
        batch_size=worker_batch_size,
        sampler=training_sampler,
        collate_fn=collate_fn,
    )

    test_dataloader = DataLoader(
        test_data,
        shuffle=shuffle,
        batch_size=worker_batch_size,
        sampler=test_sampler,
        collate_fn=collate_fn,
    )

    if use_ray:
        # In Ray, we now retrofit the DistributedSampler
        train_dataloader = train.torch.prepare_data_loader(train_dataloader)
        test_dataloader = train.torch.prepare_data_loader(test_dataloader)

    # Create model.
    model = NeuralNetwork()

    # Prepare model
    if use_ray:
        model = train.torch.prepare_model(model)
    else:
        model = model.to(vanilla_device)
        if config.get("use_gpu", False):
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[gpu_id], output_device=gpu_id
            )
        else:
            model = nn.parallel.DistributedDataParallel(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    for _ in range(epochs):
        train_epoch(
            train_dataloader,
            model,
            loss_fn,
            optimizer,
            world_size=world_size,
            local_rank=local_rank,
        )
        loss = validate_epoch(
            test_dataloader,
            model,
            loss_fn,
            world_size=world_size,
            local_rank=local_rank,
        )

    local_time_taken = time.monotonic() - local_start_time

    if use_ray:
        session.report(dict(loss=loss, local_time_taken=local_time_taken))
    else:
        print(f"Reporting loss: {loss:.4f}")
        if local_rank == 0:
            with open(VANILLA_RESULT_JSON, "w") as f:
                json.dump({"loss": loss, "local_time_taken": local_time_taken}, f)
def train_torch_ray_air(
    *,
    config: dict,
    num_workers: int = 4,
    cpus_per_worker: int = 8,
    use_gpu: bool = False,
) -> Tuple[float, float, float]:
    # This function is kicked off by the main() function and runs a full training
    # run using Ray AIR.
    from ray.train.torch import TorchTrainer
    from ray.air.config import ScalingConfig

    def train_loop(config):
        train_func(use_ray=True, config=config)

    start_time = time.monotonic()

    trainer = TorchTrainer(
        train_loop_per_worker=train_loop,
        train_loop_config=config,
        scaling_config=ScalingConfig(
            trainer_resources={"CPU": 0},
            num_workers=num_workers,
            resources_per_worker={"CPU": cpus_per_worker},
            use_gpu=use_gpu,
        ),
    )
    result = trainer.fit()

    time_taken = time.monotonic() - start_time

    print(f"Last result: {result.metrics}")
    return time_taken, result.metrics["local_time_taken"], result.metrics["loss"]
def train_torch_vanilla_worker(
    *,
    config: dict,
    rank: int,
    world_size: int,
    master_addr: str,
    master_port: int,
    use_gpu: bool = False,
    gpu_id: int = 0,
):
    # This function is kicked off by the main() function and runs the vanilla
    # training script on a single worker.
    backend = "nccl" if use_gpu else "gloo"

    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    os.environ["NCCL_BLOCKING_WAIT"] = "1"

    distributed.init_process_group(
        backend=backend, rank=rank, world_size=world_size, init_method="env://"
    )

    config["use_gpu"] = use_gpu
    config["gpu_id"] = gpu_id
    train_func(use_ray=False, config=config)

    distributed.destroy_process_group()
def train_torch_vanilla(
    *,
    config: dict,
    num_workers: int = 4,
    cpus_per_worker: int = 8,
    use_gpu: bool = False,
) -> Tuple[float, float, float]:
    # This function is kicked off by the main() function and subsequently kicks
    # off tasks that run train_torch_vanilla_worker() on the worker nodes.
    from benchmark_util import (
        upload_file_to_all_nodes,
        create_actors_with_options,
        run_commands_on_actors,
        run_fn_on_actors,
        get_ip_port_actors,
        get_gpu_ids_actors,
        map_ips_to_gpus,
        set_cuda_visible_devices,
    )

    path = os.path.abspath(__file__)
    upload_file_to_all_nodes(path)

    num_epochs = config["epochs"]

    try:
        nccl_network_interface = find_network_interface()
        runtime_env = {"env_vars": {"NCCL_SOCKET_IFNAME": nccl_network_interface}}
    except Exception:
        runtime_env = {}

    actors = create_actors_with_options(
        num_actors=num_workers,
        resources={
            "CPU": cpus_per_worker,
            "GPU": int(use_gpu),
        },
        runtime_env=runtime_env,
    )

    run_fn_on_actors(actors=actors, fn=lambda: os.environ.pop("OMP_NUM_THREADS", None))

    # Get IPs and ports for all actors
    ip_ports = get_ip_port_actors(actors=actors)

    # Rank 0 is the master addr/port
    master_addr, master_port = ip_ports[0]

    if use_gpu:
        # Extract IPs
        actor_ips = [ipp[0] for ipp in ip_ports]

        # Get allocated GPU IDs for all actors
        gpu_ids = get_gpu_ids_actors(actors=actors)

        # Build a map of IP to all allocated GPUs on that machine
        ip_to_gpu_map = map_ips_to_gpus(ips=actor_ips, gpus=gpu_ids)

        # Set the environment variables on the workers
        set_cuda_visible_devices(
            actors=actors, actor_ips=actor_ips, ip_to_gpus=ip_to_gpu_map
        )

        use_gpu_ids = [gi[0] for gi in gpu_ids]
    else:
        use_gpu_ids = [0] * num_workers

    cmds = [
        [
            "python",
            path,
            "worker",
            "--num-epochs",
            str(num_epochs),
            "--num-workers",
            str(num_workers),
            "--rank",
            str(rank),
            "--master-addr",
            master_addr,
            "--master-port",
            str(master_port),
            "--batch-size",
            str(config["batch_size"]),
        ]
        + (["--use-gpu"] if use_gpu else [])
        + (["--gpu-id", str(use_gpu_ids[rank])] if use_gpu else [])
        for rank in range(num_workers)
    ]

    run_fn_on_actors(
        actors=actors, fn=lambda: os.environ.setdefault("OMP_NUM_THREADS", "1")
    )

    start_time = time.monotonic()
    run_commands_on_actors(actors=actors, cmds=cmds)
    time_taken = time.monotonic() - start_time
    # Default to 0.0 so we don't hit a NameError if rank 0 never wrote the
    # result file.
    loss = 0.0
    local_time_taken = 0.0
    if os.path.exists(VANILLA_RESULT_JSON):
        with open(VANILLA_RESULT_JSON, "r") as f:
            result = json.load(f)
        loss = result["loss"]
        local_time_taken = result["local_time_taken"]

    return time_taken, local_time_taken, loss
@click.group(help="Run Torch benchmarks")
def cli():
    pass


@cli.command(help="Kick off Ray and vanilla benchmarks")
@click.option("--num-runs", type=int, default=1)
@click.option("--num-epochs", type=int, default=4)
@click.option("--num-workers", type=int, default=4)
@click.option("--cpus-per-worker", type=int, default=8)
@click.option("--use-gpu", is_flag=True, default=False)
@click.option("--batch-size", type=int, default=64)
@click.option("--smoke-test", is_flag=True, default=False)
@click.option("--local", is_flag=True, default=False)
def run(
    num_runs: int = 1,
    num_epochs: int = 4,
    num_workers: int = 4,
    cpus_per_worker: int = 8,
    use_gpu: bool = False,
    batch_size: int = 64,
    smoke_test: bool = False,
    local: bool = False,
):
    # Note: smoke_test is ignored as we just adjust the batch size.
    # The parameter is passed by the release test pipeline.
    import ray
    from benchmark_util import upload_file_to_all_nodes, run_command_on_all_nodes

    config = CONFIG.copy()
    config["epochs"] = num_epochs
    config["batch_size"] = batch_size

    if local:
        ray.init(num_cpus=4)
    else:
        ray.init("auto")

    print("Preparing Torch benchmark: Downloading MNIST")

    path = str((Path(__file__).parent / "_torch_prepare.py").absolute())
    upload_file_to_all_nodes(path)
    run_command_on_all_nodes(["python", path])

    times_ray = []
    times_local_ray = []
    losses_ray = []

    times_vanilla = []
    times_local_vanilla = []
    losses_vanilla = []

    for run in range(1, num_runs + 1):
        time.sleep(2)

        print(f"[Run {run}/{num_runs}] Running Torch Ray benchmark")

        time_ray, time_local_ray, loss_ray = train_torch_ray_air(
            num_workers=num_workers,
            cpus_per_worker=cpus_per_worker,
            use_gpu=use_gpu,
            config=config,
        )

        print(
            f"[Run {run}/{num_runs}] Finished Ray training ({num_epochs} epochs) in "
            f"{time_ray:.2f} seconds (local training time: {time_local_ray:.2f}s). "
            f"Observed loss = {loss_ray:.4f}"
        )

        time.sleep(2)

        print(f"[Run {run}/{num_runs}] Running Torch vanilla benchmark")

        time_vanilla, time_local_vanilla, loss_vanilla = train_torch_vanilla(
            num_workers=num_workers,
            cpus_per_worker=cpus_per_worker,
            use_gpu=use_gpu,
            config=config,
        )

        print(
            f"[Run {run}/{num_runs}] Finished vanilla training ({num_epochs} epochs) "
            f"in {time_vanilla:.2f} seconds "
            f"(local training time: {time_local_vanilla:.2f}s). "
            f"Observed loss = {loss_vanilla:.4f}"
        )
        # These are Torch results, so use torch_mnist_* keys (the original
        # tensorflow_mnist_* names were a copy-paste leftover).
        print(
            f"[Run {run}/{num_runs}] Observed results: ",
            {
                "torch_mnist_ray_time_s": time_ray,
                "torch_mnist_ray_local_time_s": time_local_ray,
                "torch_mnist_ray_loss": loss_ray,
                "torch_mnist_vanilla_time_s": time_vanilla,
                "torch_mnist_vanilla_local_time_s": time_local_vanilla,
                "torch_mnist_vanilla_loss": loss_vanilla,
            },
        )
        times_ray.append(time_ray)
        times_local_ray.append(time_local_ray)
        losses_ray.append(loss_ray)

        times_vanilla.append(time_vanilla)
        times_local_vanilla.append(time_local_vanilla)
        losses_vanilla.append(loss_vanilla)

    times_ray_mean = np.mean(times_ray)
    times_ray_sd = np.std(times_ray)

    times_local_ray_mean = np.mean(times_local_ray)
    times_local_ray_sd = np.std(times_local_ray)

    times_vanilla_mean = np.mean(times_vanilla)
    times_vanilla_sd = np.std(times_vanilla)

    times_local_vanilla_mean = np.mean(times_local_vanilla)
    times_local_vanilla_sd = np.std(times_local_vanilla)

    result = {
        "torch_mnist_ray_num_runs": num_runs,
        "torch_mnist_ray_time_s_all": times_ray,
        "torch_mnist_ray_time_s_mean": times_ray_mean,
        "torch_mnist_ray_time_s_sd": times_ray_sd,
        "torch_mnist_ray_time_local_s_all": times_local_ray,
        "torch_mnist_ray_time_local_s_mean": times_local_ray_mean,
        "torch_mnist_ray_time_local_s_sd": times_local_ray_sd,
        "torch_mnist_ray_loss_mean": np.mean(losses_ray),
        "torch_mnist_ray_loss_sd": np.std(losses_ray),
        "torch_mnist_vanilla_time_s_all": times_vanilla,
        "torch_mnist_vanilla_time_s_mean": times_vanilla_mean,
        "torch_mnist_vanilla_time_s_sd": times_vanilla_sd,
        "torch_mnist_vanilla_local_time_s_all": times_local_vanilla,
        "torch_mnist_vanilla_local_time_s_mean": times_local_vanilla_mean,
        "torch_mnist_vanilla_local_time_s_sd": times_local_vanilla_sd,
        "torch_mnist_vanilla_loss_mean": np.mean(losses_vanilla),
        "torch_mnist_vanilla_loss_std": np.std(losses_vanilla),
    }

    print("Results:", result)
    test_output_json = os.environ.get("TEST_OUTPUT_JSON", "/tmp/result.json")
    with open(test_output_json, "wt") as f:
        json.dump(result, f)

    target_ratio = 1.15
    ratio = (
        (times_local_ray_mean / times_local_vanilla_mean)
        if times_local_vanilla_mean != 0.0
        else 1.0
    )
    if ratio > target_ratio:
        raise RuntimeError(
            f"Training on Ray took an average of {times_local_ray_mean:.2f} seconds, "
            f"which is more than {target_ratio:.2f}x of the average vanilla training "
            f"time of {times_local_vanilla_mean:.2f} seconds ({ratio:.2f}x). FAILED"
        )

    print(
        f"Training on Ray took an average of {times_local_ray_mean:.2f} seconds, "
        f"which is less than {target_ratio:.2f}x of the average vanilla training "
        f"time of {times_local_vanilla_mean:.2f} seconds ({ratio:.2f}x). PASSED"
    )
@cli.command(help="Run PyTorch vanilla worker")
@click.option("--num-epochs", type=int, default=4)
@click.option("--num-workers", type=int, default=4)
@click.option("--rank", type=int, default=0)
@click.option("--master-addr", type=str, default="")
@click.option("--master-port", type=int, default=0)
@click.option("--batch-size", type=int, default=64)
@click.option("--use-gpu", is_flag=True, default=False)
@click.option("--gpu-id", type=int, default=0)
def worker(
    num_epochs: int = 4,
    num_workers: int = 4,
    rank: int = 0,
    master_addr: str = "",
    master_port: int = 0,
    batch_size: int = 64,
    use_gpu: bool = False,
    gpu_id: int = 0,
):
    config = CONFIG.copy()
    config["epochs"] = num_epochs
    config["batch_size"] = batch_size

    # Then we kick off the training function on every worker.
    return train_torch_vanilla_worker(
        config=config,
        rank=rank,
        world_size=num_workers,
        master_addr=master_addr,
        master_port=master_port,
        use_gpu=use_gpu,
        gpu_id=gpu_id,
    )


def main():
    return cli()


if __name__ == "__main__":
    main()
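

# Example invocations (illustrative only; the exact resource flags depend on
# the cluster this runs on, and the helper modules `benchmark_util.py` and
# `_torch_prepare.py` are assumed to sit next to this file):
#
#   # Full comparison on an existing Ray cluster (`ray.init("auto")`):
#   #   python torch_benchmark.py run --num-runs 3 --num-epochs 4 \
#   #       --num-workers 4 --cpus-per-worker 8
#
#   # Quick sanity check on a single machine (starts a local Ray instance
#   # with 4 CPUs):
#   #   python torch_benchmark.py run --num-runs 1 --num-epochs 1 --local
#
# The `worker` subcommand is launched by train_torch_vanilla() on the worker
# nodes and is not meant to be invoked by hand.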