# gpu_batch_prediction.py (2.9 KB)

  1. import click
  2. import time
  3. import json
  4. import os
  5. import numpy as np
  6. import torch
  7. from torchvision import transforms
  8. from torchvision.models import resnet18
  9. import ray
  10. from ray.train.torch import TorchCheckpoint, TorchPredictor
  11. from ray.train.batch_predictor import BatchPredictor
  12. from ray.data.preprocessors import TorchVisionPreprocessor
  13. @click.command(help="Run Batch prediction on Pytorch ResNet models.")
  14. @click.option("--data-size-gb", type=int, default=1)
  15. @click.option("--smoke-test", is_flag=True, default=False)
  16. def main(data_size_gb: int, smoke_test: bool = False):
  17. data_url = (
  18. f"s3://anonymous@air-example-data-2/{data_size_gb}G-image-data-synthetic-raw"
  19. )
  20. if smoke_test:
  21. # Only read one image
  22. data_url = [data_url + "/dog.jpg"]
  23. print("Running smoke test on CPU with a single example")
  24. else:
  25. print(
  26. f"Running GPU batch prediction with {data_size_gb}GB data from {data_url}"
  27. )
  28. start = time.time()
  29. dataset = ray.data.read_images(data_url, size=(256, 256))
  30. model = resnet18(pretrained=True)
  31. def to_tensor(batch: np.ndarray) -> torch.Tensor:
  32. tensor = torch.as_tensor(batch, dtype=torch.float)
  33. # (B, H, W, C) -> (B, C, H, W)
  34. tensor = tensor.permute(0, 3, 1, 2).contiguous()
  35. # [0., 255.] -> [0., 1.]
  36. tensor = tensor.div(255)
  37. return tensor
  38. transform = transforms.Compose(
  39. [
  40. transforms.Lambda(to_tensor),
  41. transforms.CenterCrop(224),
  42. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  43. ]
  44. )
  45. preprocessor = TorchVisionPreprocessor(
  46. columns=["image"], transform=transform, batched=True
  47. )
  48. ckpt = TorchCheckpoint.from_model(model=model, preprocessor=preprocessor)
  49. predictor = BatchPredictor.from_checkpoint(ckpt, TorchPredictor)
  50. predictions = predictor.predict(
  51. dataset,
  52. num_gpus_per_worker=int(not smoke_test),
  53. min_scoring_workers=1,
  54. max_scoring_workers=1 if smoke_test else int(ray.cluster_resources()["GPU"]),
  55. batch_size=512,
  56. )
  57. for _ in predictions.iter_batches():
  58. pass
  59. total_time_s = round(time.time() - start, 2)
  60. # For structured output integration with internal tooling
  61. results = {
  62. "data_size_gb": data_size_gb,
  63. }
  64. results["perf_metrics"] = [
  65. {
  66. "perf_metric_name": "total_time_s",
  67. "perf_metric_value": total_time_s,
  68. "perf_metric_type": "LATENCY",
  69. },
  70. {
  71. "perf_metric_name": "throughout_MB_s",
  72. "perf_metric_value": (data_size_gb * 1024 / total_time_s),
  73. "perf_metric_type": "THROUGHPUT",
  74. },
  75. ]
  76. test_output_json = os.environ.get("TEST_OUTPUT_JSON", "/tmp/release_test_out.json")
  77. with open(test_output_json, "wt") as f:
  78. json.dump(results, f)
  79. print(results)
# Script entry point: run the click command only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()