12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- import click
- import time
- import json
- import os
- import numpy as np
- import torch
- from torchvision import transforms
- from torchvision.models import resnet18
- import ray
- from ray.train.torch import TorchCheckpoint, TorchPredictor
- from ray.train.batch_predictor import BatchPredictor
- from ray.data.preprocessors import TorchVisionPreprocessor
@click.command(help="Run Batch prediction on Pytorch ResNet models.")
@click.option("--data-size-gb", type=int, default=1)
@click.option("--smoke-test", is_flag=True, default=False)
def main(data_size_gb: int, smoke_test: bool = False):
    """Benchmark Ray batch prediction with a pretrained ResNet18.

    Reads a synthetic image dataset from S3 (``data_size_gb`` GB of raw
    images), scores it with torchvision's ResNet18 through Ray's
    ``BatchPredictor``, and writes timing/throughput metrics as JSON to the
    path named by ``$TEST_OUTPUT_JSON`` (default ``/tmp/release_test_out.json``).

    Args:
        data_size_gb: Size of the synthetic S3 dataset to read, in GB.
        smoke_test: If set, read a single image and score on CPU with one
            worker instead of using every GPU in the cluster.
    """
    data_url = (
        f"s3://anonymous@air-example-data-2/{data_size_gb}G-image-data-synthetic-raw"
    )
    if smoke_test:
        # Only read one image
        data_url = [data_url + "/dog.jpg"]
        print("Running smoke test on CPU with a single example")
    else:
        print(
            f"Running GPU batch prediction with {data_size_gb}GB data from {data_url}"
        )

    start = time.time()
    dataset = ray.data.read_images(data_url, size=(256, 256))

    # NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13;
    # switch to `weights=ResNet18_Weights.DEFAULT` once the pinned torchvision
    # version allows — confirm against this repo's requirements.
    model = resnet18(pretrained=True)

    def to_tensor(batch: np.ndarray) -> torch.Tensor:
        """Convert a uint8 NHWC image batch to a float NCHW tensor in [0, 1]."""
        tensor = torch.as_tensor(batch, dtype=torch.float)
        # (B, H, W, C) -> (B, C, H, W)
        tensor = tensor.permute(0, 3, 1, 2).contiguous()
        # [0., 255.] -> [0., 1.]
        tensor = tensor.div(255)
        return tensor

    # Standard ImageNet preprocessing: crop to 224x224 and normalize with the
    # ImageNet channel means/stds that ResNet18 was trained with.
    transform = transforms.Compose(
        [
            transforms.Lambda(to_tensor),
            transforms.CenterCrop(224),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    preprocessor = TorchVisionPreprocessor(
        columns=["image"], transform=transform, batched=True
    )

    ckpt = TorchCheckpoint.from_model(model=model, preprocessor=preprocessor)
    predictor = BatchPredictor.from_checkpoint(ckpt, TorchPredictor)
    predictions = predictor.predict(
        dataset,
        # Smoke test: 0 GPUs per worker (CPU) and a single worker; otherwise
        # one GPU per worker, scaling out to every GPU in the cluster.
        num_gpus_per_worker=int(not smoke_test),
        min_scoring_workers=1,
        max_scoring_workers=1 if smoke_test else int(ray.cluster_resources()["GPU"]),
        batch_size=512,
    )

    # Drain the lazy prediction dataset so all scoring work actually executes
    # before the timer is stopped.
    for _ in predictions.iter_batches():
        pass

    total_time_s = round(time.time() - start, 2)

    # For structured output integration with internal tooling
    results = {
        "data_size_gb": data_size_gb,
    }
    results["perf_metrics"] = [
        {
            "perf_metric_name": "total_time_s",
            "perf_metric_value": total_time_s,
            "perf_metric_type": "LATENCY",
        },
        {
            # Fixed typo: was "throughout_MB_s". Any downstream dashboard that
            # keys on the old metric name must be updated in lockstep.
            "perf_metric_name": "throughput_MB_s",
            "perf_metric_value": (data_size_gb * 1024 / total_time_s),
            "perf_metric_type": "THROUGHPUT",
        },
    ]

    test_output_json = os.environ.get("TEST_OUTPUT_JSON", "/tmp/release_test_out.json")
    with open(test_output_json, "wt") as f:
        json.dump(results, f)

    print(results)
# Script entry point: delegate to the click-decorated CLI command.
if __name__ == "__main__":
    main()
|