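"""End-to-end DeepSpeed tests for BingBertSquad.

Each test launches run_squad_deepspeed.sh against a pretrained checkpoint,
then scores the generated predictions on the SQuAD v1.1 dev set and compares
exact-match/F1 with previously recorded baseline numbers.
"""
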
import subprocess as sp
import os
from math import isclose
import sys
import pytest
import json

sys.path.append("../../../DeepSpeedExamples/BingBertSquad")
import evaluate as eval

squad_dir = "/data/BingBertSquad"
base_dir = "../../../DeepSpeedExamples/BingBertSquad"

script_file_name = "run_squad_deepspeed.sh"
model_file_name = "training_state_checkpoint_162.tar"
eval_file_name = "dev-v1.1.json"
pred_file_name = "predictions.json"

num_gpus = "4"
timeout_sec = 5 * 60 * 60  # 5 hours

eval_version = "1.1"


def create_config_file(tmpdir, zeroenabled=False):
    """Write a DeepSpeed JSON config into tmpdir and return its path.

    train_batch_size (24) = train_micro_batch_size_per_gpu (6) * num_gpus (4)
    with no gradient accumulation.
    """
    config_dict = {
        "train_batch_size": 24,
        "train_micro_batch_size_per_gpu": 6,
        "steps_per_print": 10,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 3e-5,
                "weight_decay": 0.0,
                "bias_correction": False
            }
        },
        "gradient_clipping": 1.0,
        "fp16": {
            "enabled": True
        }
    }
    # ZeRO is toggled per test; the config uses the plain boolean form
    # rather than a stage dict.
    config_dict["zero_optimization"] = zeroenabled

    config_path = os.path.join(tmpdir, 'temp_config.json')
    with open(config_path, 'w') as fd:
        json.dump(config_dict, fd)
    return config_path
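

# For reference, the ZeRO run's config serializes roughly as below
# (illustrative; json.dump actually emits it on a single line):
#
#   {
#     "train_batch_size": 24,
#     "train_micro_batch_size_per_gpu": 6,
#     "steps_per_print": 10,
#     "optimizer": {
#       "type": "Adam",
#       "params": {"lr": 3e-05, "weight_decay": 0.0, "bias_correction": false}
#     },
#     "gradient_clipping": 1.0,
#     "fp16": {"enabled": true},
#     "zero_optimization": true
#   }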


def test_e2e_squad_deepspeed_base(tmpdir):
    """Fine-tune without ZeRO and compare EM/F1 to the recorded baseline."""
    config_file = create_config_file(tmpdir)
    # base run results => {"exact_match": 83.9829706717124, "f1": 90.71138132004097}
    expected_exact_match = 83.98
    expected_f1 = 90.71

    model_file = os.path.join(squad_dir, model_file_name)
    eval_file = os.path.join(squad_dir, eval_file_name)
    output_dir = os.path.join(tmpdir, "output")
    pred_file = os.path.join(output_dir, pred_file_name)

    proc = sp.Popen(
        [
            "bash", script_file_name, num_gpus, model_file, squad_dir,
            output_dir, config_file
        ],
        cwd=base_dir)

    try:
        proc.communicate(timeout=timeout_sec)
    except sp.TimeoutExpired:
        proc.kill()
        pytest.fail("Error: Timeout")

    # Popen.communicate() never raises CalledProcessError, so check the exit
    # status and the prediction file directly instead of catching it.
    if proc.returncode != 0 or not os.path.exists(pred_file):
        pytest.fail("Error: Run Failed")

    eval_result = eval.evaluate(eval_version, eval_file, pred_file)
    print("evaluation result: ", json.dumps(eval_result))

    assert isclose(eval_result["exact_match"], expected_exact_match, abs_tol=1e-2)
    assert isclose(eval_result["f1"], expected_f1, abs_tol=1e-2)
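

# math.isclose(a, b, abs_tol=1e-2) accepts |a - b| <= 0.01 (the default
# rel_tol=1e-9 is negligible at these magnitudes), so a measured F1 of
# 90.71138... matches the expected 90.71, while a regression to 90.69
# would fail the assertion.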


def test_e2e_squad_deepspeed_zero(tmpdir):
    """Fine-tune with ZeRO enabled and compare EM/F1 to the recorded baseline."""
    config_file = create_config_file(tmpdir, zeroenabled=True)
    # ZeRO run results => {"exact_match": 84.1438032166509, "f1": 90.89776136505441}
    expected_exact_match = 84.14
    expected_f1 = 90.89

    model_file = os.path.join(squad_dir, model_file_name)
    eval_file = os.path.join(squad_dir, eval_file_name)
    output_dir = os.path.join(tmpdir, "output")
    pred_file = os.path.join(output_dir, pred_file_name)

    proc = sp.Popen(
        [
            "bash", script_file_name, num_gpus, model_file, squad_dir,
            output_dir, config_file
        ],
        cwd=base_dir)

    try:
        proc.communicate(timeout=timeout_sec)
    except sp.TimeoutExpired:
        proc.kill()
        pytest.fail("Error: Timeout")

    # As in the base test, check the exit status and prediction file directly.
    if proc.returncode != 0 or not os.path.exists(pred_file):
        pytest.fail("Error: Run Failed")

    eval_result = eval.evaluate(eval_version, eval_file, pred_file)
    print("evaluation result: ", json.dumps(eval_result))

    assert isclose(eval_result["exact_match"], expected_exact_match, abs_tol=1e-2)
    assert isclose(eval_result["f1"], expected_f1, abs_tol=1e-2)
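

# Typical invocation (assumes /data/BingBertSquad contains dev-v1.1.json and
# training_state_checkpoint_162.tar, and that DeepSpeedExamples is checked out
# at the relative path used in the sys.path.append above):
#   pytest -s test_e2e_squad.py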