# train_batch_inference_benchmark.py
  1. import json
  2. import numpy as np
  3. import os
  4. import pandas as pd
  5. import time
  6. from typing import Dict
  7. import xgboost as xgb
  8. import ray
  9. from ray import data
  10. from ray.train.lightgbm import LightGBMTrainer
  11. from ray.train.xgboost import XGBoostTrainer
  12. from ray.train import RunConfig, ScalingConfig
  13. _TRAINING_TIME_THRESHOLD = 600
  14. _PREDICTION_TIME_THRESHOLD = 450
  15. _EXPERIMENT_PARAMS = {
  16. "smoke_test": {
  17. "data": (
  18. "https://air-example-data-2.s3.us-west-2.amazonaws.com/"
  19. "10G-xgboost-data.parquet/8034b2644a1d426d9be3bbfa78673dfa_000000.parquet"
  20. ),
  21. "num_workers": 1,
  22. "cpus_per_worker": 1,
  23. },
  24. "10G": {
  25. "data": "s3://air-example-data-2/10G-xgboost-data.parquet/",
  26. "num_workers": 1,
  27. "cpus_per_worker": 12,
  28. },
  29. "100G": {
  30. "data": "s3://air-example-data-2/100G-xgboost-data.parquet/",
  31. "num_workers": 10,
  32. "cpus_per_worker": 12,
  33. },
  34. }
  35. class BasePredictor:
  36. def __init__(self, trainer_cls, result: ray.train.Result):
  37. self.model = trainer_cls.get_model(result.checkpoint)
  38. def __call__(self, data):
  39. raise NotImplementedError
  40. class XGBoostPredictor(BasePredictor):
  41. def __call__(self, data: pd.DataFrame) -> Dict[str, np.ndarray]:
  42. dmatrix = xgb.DMatrix(data)
  43. return {"predictions": self.model.predict(dmatrix)}
  44. class LightGBMPredictor(BasePredictor):
  45. def __call__(self, data: pd.DataFrame) -> Dict[str, np.ndarray]:
  46. return {"predictions": self.model.predict(data)}
  47. _FRAMEWORK_PARAMS = {
  48. "xgboost": {
  49. "trainer_cls": XGBoostTrainer,
  50. "predictor_cls": XGBoostPredictor,
  51. "params": {
  52. "objective": "binary:logistic",
  53. "eval_metric": ["logloss", "error"],
  54. },
  55. },
  56. "lightgbm": {
  57. "trainer_cls": LightGBMTrainer,
  58. "predictor_cls": LightGBMPredictor,
  59. "params": {
  60. "objective": "binary",
  61. "metric": ["binary_logloss", "binary_error"],
  62. },
  63. },
  64. }
  65. def train(
  66. framework: str, data_path: str, num_workers: int, cpus_per_worker: int
  67. ) -> ray.train.Result:
  68. ds = data.read_parquet(data_path)
  69. framework_params = _FRAMEWORK_PARAMS[framework]
  70. trainer_cls = framework_params["trainer_cls"]
  71. trainer = trainer_cls(
  72. params=framework_params["params"],
  73. scaling_config=ScalingConfig(
  74. num_workers=num_workers,
  75. resources_per_worker={"CPU": cpus_per_worker},
  76. trainer_resources={"CPU": 0},
  77. ),
  78. label_column="labels",
  79. datasets={"train": ds},
  80. run_config=RunConfig(
  81. storage_path="/mnt/cluster_storage", name=f"{framework}_benchmark"
  82. ),
  83. )
  84. result = trainer.fit()
  85. return result
  86. def predict(framework: str, result: ray.train.Result, data_path: str):
  87. framework_params = _FRAMEWORK_PARAMS[framework]
  88. predictor_cls = framework_params["predictor_cls"]
  89. ds = data.read_parquet(data_path)
  90. ds = ds.drop_columns(["labels"])
  91. concurrency = int(ray.cluster_resources()["CPU"] // 2)
  92. result = ds.map_batches(
  93. predictor_cls,
  94. # Improve prediction throughput with larger batch size than default 4096
  95. batch_size=8192,
  96. concurrency=concurrency,
  97. fn_constructor_kwargs={
  98. "trainer_cls": framework_params["trainer_cls"],
  99. "result": result,
  100. },
  101. batch_format="pandas",
  102. )
  103. for _ in result.iter_batches():
  104. pass
  105. def main(args):
  106. framework = args.framework
  107. experiment = args.size if not args.smoke_test else "smoke_test"
  108. experiment_params = _EXPERIMENT_PARAMS[experiment]
  109. data_path, num_workers, cpus_per_worker = (
  110. experiment_params["data"],
  111. experiment_params["num_workers"],
  112. experiment_params["cpus_per_worker"],
  113. )
  114. print(f"Running {framework} training benchmark...")
  115. training_start = time.perf_counter()
  116. result = train(framework, data_path, num_workers, cpus_per_worker)
  117. training_time = time.perf_counter() - training_start
  118. print(f"Running {framework} prediction benchmark...")
  119. prediction_start = time.perf_counter()
  120. predict(framework, result, data_path)
  121. prediction_time = time.perf_counter() - prediction_start
  122. times = {"training_time": training_time, "prediction_time": prediction_time}
  123. print("Training result:\n", result)
  124. print("Training/prediction times:", times)
  125. test_output_json = os.environ.get("TEST_OUTPUT_JSON", "/tmp/result.json")
  126. with open(test_output_json, "wt") as f:
  127. json.dump(times, f)
  128. if not args.disable_check:
  129. if training_time > _TRAINING_TIME_THRESHOLD:
  130. raise RuntimeError(
  131. f"Training is taking {training_time} seconds, "
  132. f"which is longer than expected ({_TRAINING_TIME_THRESHOLD} seconds)."
  133. )
  134. if prediction_time > _PREDICTION_TIME_THRESHOLD:
  135. raise RuntimeError(
  136. f"Batch prediction is taking {prediction_time} seconds, "
  137. f"which is longer than expected ({_PREDICTION_TIME_THRESHOLD} seconds)."
  138. )
  139. if __name__ == "__main__":
  140. import argparse
  141. parser = argparse.ArgumentParser()
  142. parser.add_argument(
  143. "framework", type=str, choices=["xgboost", "lightgbm"], default="xgboost"
  144. )
  145. parser.add_argument("--size", type=str, choices=["10G", "100G"], default="100G")
  146. # Add a flag for disabling the timeout error.
  147. # Use case: running the benchmark as a documented example, in infra settings
  148. # different from the formal benchmark's EC2 setup.
  149. parser.add_argument(
  150. "--disable-check",
  151. action="store_true",
  152. help="disable runtime error on benchmark timeout",
  153. )
  154. parser.add_argument("--smoke-test", action="store_true")
  155. args = parser.parse_args()
  156. main(args)