# test_inference.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import time
import torch
import pytest
import itertools
import deepspeed
from deepspeed.git_version_info import torch_info
from unit.common import DistributedTest
from packaging import version as pkg_version
from deepspeed.ops.op_builder import OpBuilder
from transformers import pipeline, AutoTokenizer
from transformers.models.t5.modeling_t5 import T5Block
from transformers.models.roberta.modeling_roberta import RobertaLayer
from huggingface_hub import HfApi
from deepspeed.model_implementations import DeepSpeedTransformerInference
from torch import nn
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder

rocm_version = OpBuilder.installed_rocm_version()
if rocm_version != (0, 0):
    pytest.skip("skip inference tests on rocm for now", allow_module_level=True)

_bert_models = [
    "bert-base-cased",
    "bert-base-uncased",
    "bert-large-cased",
    "bert-large-uncased",
    "bert-base-multilingual-cased",
    "bert-base-multilingual-uncased",
    "deepset/minilm-uncased-squad2",
    "cross-encoder/ms-marco-MiniLM-L-12-v2",
    "dslim/bert-base-NER",
    "bert-large-uncased-whole-word-masking-finetuned-squad",
    "distilbert-base-cased-distilled-squad",
]
_roberta_models = [
    "roberta-large",
    "roberta-base",
    "deepset/roberta-base-squad2",
    "j-hartmann/emotion-english-distilroberta-base",
    "Jean-Baptiste/roberta-large-ner-english",
]
_gpt_models = [
    "gpt2",
    "distilgpt2",
    "Norod78/hebrew-bad_wiki-gpt_neo-tiny",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m-deduped",
    "bigscience/bloom-560m",
]
_opt_models = [
    "facebook/opt-125m",  # 125m, 1.7B, ..., 175B variants have the same model architecture.
    "facebook/opt-350m",  # 350m applies layer norm after attention layer which is different than other variants.
]
_test_models = set(_bert_models + _roberta_models + _gpt_models + _opt_models)
_test_tasks = [
    "fill-mask", "question-answering", "text-classification", "token-classification", "text-generation",
    "text2text-generation", "summarization", "translation"
]

# Get a list of all models and mapping from task to supported models
_hf_models = list(HfApi().list_models())
_hf_model_names = [m.modelId for m in _hf_models]
_hf_task_to_models = {task: [m.modelId for m in _hf_models if m.pipeline_tag == task] for task in _test_tasks}

# Get all combinations of task:model to test
_model_w_tasks = [(m, t) for m, t in itertools.product(*[_test_models, _test_tasks]) if m in _hf_task_to_models[t]]

# Assign to pytest variables for testing
pytest.model_w_tasks = _model_w_tasks
pytest.mt_names = [f"{m}-{t}" for m, t in pytest.model_w_tasks]
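
# Note: _model_w_tasks pairs each test model with every task whose HuggingFace
# pipeline tag matches it; pytest.mt_names provides the matching
# "<model>-<task>" ids used when parametrizing the fixtures below.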


@pytest.fixture(scope="module", autouse=True)
def verify_models():
    # Verify all test models are registered in HF
    _test_models_not_found = [m for m in _test_models if m not in _hf_model_names]
    if _test_models_not_found:
        pytest.fail(f"Model(s) not found in HuggingFace: {_test_models_not_found}")

    # Verify all models are assigned to at least one task
    _models_to_be_tested = set(m for m, t in _model_w_tasks)
    _missing_task_models = _models_to_be_tested.difference(_test_models)
    if _missing_task_models:
        pytest.fail(f"Model(s) do not have an assigned task: {_missing_task_models}")
  82. """ Fixtures for inference config """
  83. @pytest.fixture(params=pytest.model_w_tasks, ids=pytest.mt_names)
  84. def model_w_task(request):
  85. return request.param
  86. @pytest.fixture(params=[torch.float, torch.half], ids=["fp32", "fp16"])
  87. def dtype(request):
  88. return request.param
  89. @pytest.fixture(params=[True, False], ids=["CG", "noCG"])
  90. def enable_cuda_graph(request):
  91. return request.param
  92. @pytest.fixture(params=[True, False], ids=["Triton", "noTriton"])
  93. def enable_triton(request):
  94. return request.param
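
# Together these fixtures span the test matrix: each (model, task) pair is run
# at fp32 and fp16, with and without CUDA Graph, and with and without Triton
# kernels. Invalid combinations are skipped via validate_test() below.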
  95. """ Fixtures for running query """
  96. @pytest.fixture
  97. def query(model_w_task):
  98. model, task = model_w_task
  99. angle_bracket_mask_models = ["roberta", "camembert", "esm", "ibert", "luke", "mpnet", "yoso", "mpnet"]
  100. if task == "fill-mask":
  101. if any(map(lambda x: x in model, angle_bracket_mask_models)):
  102. return "Hello I'm a <mask> model."
  103. else:
  104. return "Hell I'm a [MASK] model."
  105. elif task == "question-answering":
  106. return {
  107. "question": "What's my name?",
  108. "context": "My name is Clara and I live in Berkeley",
  109. }
  110. elif task == "text-classification":
  111. return "DeepSpeed is the greatest"
  112. elif task == "token-classification":
  113. return "My name is jean-baptiste and I live in montreal."
  114. elif task == "text-generation":
  115. return "DeepSpeed is the greatest"
  116. elif task == "text2text-generation":
  117. return "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
  118. elif task == "translation" or task == "summarization":
  119. return "Hello, my dog is cute"
  120. else:
  121. NotImplementedError(f'query for task "{task}" is not implemented')


@pytest.fixture
def inf_kwargs(model_w_task):
    model, task = model_w_task
    if task == "text-generation":
        if model == "EleutherAI/gpt-j-6b":
            # This model on V100 is hitting memory problems that limit the number of output tokens
            return {"do_sample": False, "temperature": 1.0, "max_length": 12}
        return {"do_sample": False, "temperature": 1.0, "max_length": 20}
    else:
        return {}
  132. """ Assertion fixture for verifying model outputs """
  133. def fill_mask_assert(x, y):
  134. return set(res["token_str"] for res in x) == set(res["token_str"] for res in y)
  135. def question_answering_assert(x, y):
  136. return x["answer"] == y["answer"]
  137. def text_classification_assert(x, y):
  138. return set(res["label"] for res in x) == set(res["label"] for res in y)
  139. def token_classification_assert(x, y):
  140. return set(ent["word"] for ent in x) == set(ent["word"] for ent in y)
  141. def text_generation_assert(x, y):
  142. return set(res["generated_text"] for res in x) == set(res["generated_text"] for res in y)
  143. def text2text_generation_assert(x, y):
  144. return set(res["generated_text"] for res in x) == set(res["generated_text"] for res in y)
  145. def translation_assert(x, y):
  146. return set(res["translation_text"] for res in x) == set(res["translation_text"] for res in y)
  147. def summarization_assert(x, y):
  148. return set(res["summary_text"] for res in x) == set(res["summary_text"] for res in y)
  149. @pytest.fixture
  150. def assert_fn(model_w_task):
  151. model, task = model_w_task
  152. assert_fn_dict = {
  153. "fill-mask": fill_mask_assert,
  154. "question-answering": question_answering_assert,
  155. "text-classification": text_classification_assert,
  156. "token-classification": token_classification_assert,
  157. "text-generation": text_generation_assert,
  158. "text2text-generation": text2text_generation_assert,
  159. "translation": translation_assert,
  160. "summarization": summarization_assert
  161. }
  162. assert_fn = assert_fn_dict.get(task, None)
  163. if assert_fn is None:
  164. NotImplementedError(f'assert_fn for task "{task}" is not implemented')
  165. return assert_fn


# Used to verify DeepSpeed kernel injection worked with a model
def check_injection(model):

    def verify_injection(module):
        for child in module.children():
            if isinstance(child, nn.ModuleList):
                assert isinstance(child[0], DeepSpeedTransformerInference), \
                    "DeepSpeed-Inference Transformer kernels have not been injected in the model"
                break
            else:
                verify_injection(child)

    verify_injection(model)


# Verify that test is valid
def validate_test(model_w_task, dtype, enable_cuda_graph, enable_triton):
    model, task = model_w_task
    msg = ""
    if enable_cuda_graph and (torch_info["cuda_version"] == "0.0"):
        msg = "CUDA not detected, cannot use CUDA Graph"
    elif enable_cuda_graph and pkg_version.parse(torch.__version__) < pkg_version.parse("1.10"):
        msg = "CUDA Graph is only available in torch versions >= 1.10"
    elif "gpt-j-6b" in model:
        if dtype != torch.half:
            msg = f"Not enough GPU memory to run {model} with dtype {dtype}"
        elif enable_cuda_graph:
            msg = f"Not enough GPU memory to run {model} with CUDA Graph enabled"
    elif "gpt-neox-20b" in model:  # TODO: remove this when neox issues resolved
        msg = "Skipping gpt-neox-20b for now"
    elif ("gpt-neox-20b" in model) and (dtype != torch.half):
        msg = f"Not enough GPU memory to run {model} with dtype {dtype}"
    elif ("bloom" in model) and (dtype != torch.half):
        msg = f"Bloom models only support half precision, cannot use dtype {dtype}"
    elif ("bert" not in model.lower()) and enable_cuda_graph:
        msg = "Non bert/roberta models do not support CUDA Graph"
    elif enable_triton and not (dtype in [torch.half]):
        msg = "Triton is for fp16"
    elif enable_triton and not deepspeed.HAS_TRITON:
        msg = "Triton needs to be installed for the test"
    elif ("bert" not in model.lower()) and enable_triton:
        msg = "Triton kernels do not support non bert/roberta models yet"

    # These should be removed once we fix several inference tests failing
    if model in ["EleutherAI/pythia-70m-deduped", "distilbert-base-cased-distilled-squad", "EleutherAI/gpt-j-6b"]:
        msg = "Test is currently broken"
    return msg
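
# The test classes below generally follow the same pattern: run the HuggingFace
# pipeline as a baseline, wrap pipe.model with deepspeed.init_inference, rerun
# the same query, and compare the two outputs with the task-specific assert_fn.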


@pytest.mark.inference
class TestModelTask(DistributedTest):
    world_size = 1

    def test(
        self,
        model_w_task,
        dtype,
        enable_cuda_graph,
        enable_triton,
        query,
        inf_kwargs,
        assert_fn,
        perf_meas=True,
    ):
        invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph, enable_triton)
        if invalid_test_msg:
            pytest.skip(invalid_test_msg)

        model, task = model_w_task
        local_rank = int(os.getenv("LOCAL_RANK", "0"))

        # Load the model on CPU first to avoid OOM for large models @fp32
        pipe = pipeline(task, model=model, device=torch.device("cpu"), framework="pt")
        if dtype == torch.half:
            pipe.model.half()

        # Switch device to GPU after converting to half
        device = torch.device(get_accelerator().device_name(local_rank))
        pipe.device = device
        pipe.model.to(device)

        # Warm-up queries for perf measurement
        #for i in range(10):
        #    _ = pipe(query, **inf_kwargs)
        get_accelerator().synchronize()
        start = time.time()
        bs_output = pipe(query, **inf_kwargs)
        get_accelerator().synchronize()
        bs_time = time.time() - start

        args = {
            'mp_size': 1,
            'dtype': dtype,
            'replace_with_kernel_inject': True,
            'enable_cuda_graph': enable_cuda_graph,
            'use_triton': enable_triton,
            'triton_autotune': False,
        }
        if pipe.tokenizer.model_max_length < deepspeed.ops.transformer.inference.config.DeepSpeedInferenceConfig(
        ).max_out_tokens:
            args.update({'max_out_tokens': pipe.tokenizer.model_max_length})
        pipe.model = deepspeed.init_inference(pipe.model, **args)
        check_injection(pipe.model)

        # Warm-up queries for perf measurement
        #for i in range(10):
        #    _ = pipe(query, **inf_kwargs)
        get_accelerator().synchronize()
        start = time.time()
        ds_output = pipe(query, **inf_kwargs)
        get_accelerator().synchronize()
        ds_time = time.time() - start

        if perf_meas:
            print(
                f"model={model}, task={task}, dtype={dtype}, cuda_graph={enable_cuda_graph}, triton={enable_triton}, bs_time={bs_time}, ds_time={ds_time}"
            )

        # facebook/opt* and some bigscience/bloom* models are not matching
        # baseline exactly, adding an exception to them for now
        if ("opt" in model) or ("bloom" in model):
            bs_output = pipe(query, **inf_kwargs)

        # These performance tests are only measuring the time for a single
        # inference request, we just want to check that performance isn't terrible
        #assert ds_time <= (bs_time * 1.1)

        assert assert_fn(bs_output, ds_output)


@pytest.mark.seq_inference
@pytest.mark.parametrize("model_w_task", [("EleutherAI/gpt-neo-1.3B", "text-generation"),
                                          ("EleutherAI/gpt-neox-20b", "text-generation"),
                                          ("bigscience/bloom-3b", "text-generation"),
                                          ("EleutherAI/gpt-j-6b", "text-generation")],
                         ids=["gpt-neo", "gpt-neox", "bloom", "gpt-j"])
class TestMPSize(DistributedTest):
    world_size = 2

    def test(
        self,
        model_w_task,
        dtype,
        query,
        inf_kwargs,
        assert_fn,
    ):
        invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
        if invalid_test_msg:
            pytest.skip(invalid_test_msg)

        if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
            pytest.skip("This op has not been implemented on this system.", allow_module_level=True)

        model, task = model_w_task
        local_rank = int(os.getenv("LOCAL_RANK", "0"))

        # We have to load these large models on CPU with pipeline because not
        # enough GPU memory
        pipe = pipeline(task, model=model, device=torch.device("cpu"), framework="pt")
        bs_output = pipe(query, **inf_kwargs)

        pipe.model = deepspeed.init_inference(pipe.model,
                                              mp_size=self.world_size,
                                              dtype=dtype,
                                              replace_with_kernel_inject=True)
        check_injection(pipe.model)
        # Switch device to GPU so that input tensors are not on CPU
        pipe.device = torch.device(get_accelerator().device_name(local_rank))
        ds_output = pipe(query, **inf_kwargs)

        print(local_rank, "baseline", bs_output)
        print(local_rank, "deepspeed", ds_output)
        assert assert_fn(bs_output, ds_output)


@pytest.mark.inference
@pytest.mark.parametrize("model_w_task", [("gpt2", "text-generation")], ids=["gpt2"])
class TestLowCpuMemUsage(DistributedTest):
    world_size = 1

    def test(
        self,
        model_w_task,
        query,
        inf_kwargs,
        assert_fn,
    ):
        model, task = model_w_task
        dtype = torch.float16
        if dtype not in get_accelerator().supported_dtypes():
            pytest.skip(f"Accelerator {get_accelerator().device_name()} does not support {dtype}.")

        local_rank = int(os.getenv("LOCAL_RANK", "0"))

        pipe = pipeline(task, model=model, model_kwargs={"low_cpu_mem_usage": True}, device=local_rank, framework="pt")
        bs_output = pipe(query, **inf_kwargs)

        pipe.model = deepspeed.init_inference(pipe.model,
                                              mp_size=self.world_size,
                                              dtype=dtype,
                                              replace_method="auto",
                                              replace_with_kernel_inject=True)
        ds_output = pipe(query, **inf_kwargs)

        assert assert_fn(bs_output, ds_output)


@pytest.mark.seq_inference
@pytest.mark.parametrize("model_w_task", [("tiiuae/falcon-7b", "text-generation")], ids=["falcon"])
class TestAutoTP(DistributedTest):
    world_size = 1

    def test(
        self,
        model_w_task,
        query,
        inf_kwargs,
        assert_fn,
    ):
        # TODO: enable this test for H100 tests
        pytest.skip("Not enough GPU memory for this on V100 runners")

        model, task = model_w_task
        dtype = torch.bfloat16
        local_rank = int(os.getenv("LOCAL_RANK", "0"))

        # We have to load these large models on CPU with pipeline because not
        # enough GPU memory
        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
        pipe = pipeline(task,
                        model=model,
                        tokenizer=tokenizer,
                        torch_dtype=dtype,
                        trust_remote_code=True,
                        device=torch.device("cpu"),
                        framework="pt")
        #bs_output = pipe(query, **inf_kwargs)

        pipe.model = deepspeed.init_inference(pipe.model, mp_size=self.world_size, replace_with_kernel_inject=False)
        # Switch device to GPU so that input tensors are not on CPU
        pipe.device = torch.device(get_accelerator().device_name(local_rank))
        ds_output = pipe(query, **inf_kwargs)

        #print(local_rank, "baseline", bs_output)
        print(local_rank, "deepspeed", ds_output)
        #assert assert_fn(bs_output, ds_output)
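
# injection_policy maps a transformer block class to the names of the linear
# (output) submodules whose results DeepSpeed should all-reduce when the block
# is split across tensor-parallel ranks.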


@pytest.mark.seq_inference
@pytest.mark.parametrize(
    "model_w_task, injection_policy",
    [
        (("google/t5-v1_1-small", "text2text-generation"), {
            T5Block: ('SelfAttention.o', 'EncDecAttention.o', 'DenseReluDense.wo')
        }),
        (("roberta-large", "fill-mask"), {
            RobertaLayer: ('output.dense')
        }),
    ],
    ids=["t5", "roberta"],
)
@pytest.mark.parametrize("dtype", [torch.float], ids=["fp32"])
class TestInjectionPolicy(DistributedTest):
    world_size = [1, 2]

    def test(
        self,
        model_w_task,
        injection_policy,
        query,
        inf_kwargs,
        assert_fn,
        dtype,
    ):
        invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
        if invalid_test_msg:
            pytest.skip(invalid_test_msg)

        model, task = model_w_task
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "2"))

        # We have to load these large models on CPU with pipeline because not
        # enough GPU memory
        pipe = pipeline(task, model=model, device=torch.device("cpu"), framework="pt")
        bs_output = pipe(query, **inf_kwargs)

        pipe.model = deepspeed.init_inference(pipe.model,
                                              mp_size=world_size,
                                              dtype=dtype,
                                              injection_policy=injection_policy)
        # Switch device to GPU so that input tensors are not on CPU
        pipe.device = torch.device(get_accelerator().device_name(local_rank))
        ds_output = pipe(query, **inf_kwargs)

        print(local_rank, "baseline", bs_output)
        print(local_rank, "deepspeed", ds_output)
        assert assert_fn(bs_output, ds_output)


@pytest.mark.seq_inference
@pytest.mark.parametrize(
    "model_w_task",
    [("Helsinki-NLP/opus-mt-en-de", "translation"), ("Salesforce/codegen-350M-mono", "text-generation")],
    ids=["marian", "codegen"],  #codegen has fusedqkv weight.
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
class TestAutoTensorParallelism(DistributedTest):
    world_size = [2]

    def test(
        self,
        model_w_task,
        query,
        inf_kwargs,
        assert_fn,
        dtype,
    ):
        invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
        if invalid_test_msg:
            pytest.skip(invalid_test_msg)

        if dtype not in get_accelerator().supported_dtypes():
            pytest.skip(f"Accelerator {get_accelerator().device_name()} does not support {dtype}.")

        # TODO: enable this test after torch 2.1 stable release
        if dtype == torch.bfloat16 and model_w_task[0] == "Salesforce/codegen-350M-mono":
            pytest.skip("Codegen model (bf16) needs torch version > 2.0.")

        model, task = model_w_task
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "2"))

        # We have to load these large models on CPU with pipeline because not
        # enough GPU memory
        pipe = pipeline(task, model=model, device=torch.device("cpu"), framework="pt")
        bs_output = pipe(query, **inf_kwargs)

        pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
        # Switch device to GPU so that input tensors are not on CPU
        pipe.device = torch.device(get_accelerator().device_name(local_rank))
        ds_output = pipe(query, **inf_kwargs)

        print(local_rank, "baseline", bs_output)
        print(local_rank, "deepspeed", ds_output)
        assert assert_fn(bs_output, ds_output)

    @pytest.mark.world_size(3)
    def test_odd_world_size(
        self,
        model_w_task,
        query,
        inf_kwargs,
        assert_fn,
        dtype,
    ):
        invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
        if invalid_test_msg:
            pytest.skip(invalid_test_msg)

        model, task = model_w_task
        if model == "Salesforce/codegen-350M-mono":
            pytest.skip("codegen is not supported with an odd world_size")

        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "3"))

        pipe = pipeline(task,
                        model=model,
                        device=torch.device(get_accelerator().device_name(local_rank)),
                        framework="pt")
        bs_output = pipe(query, **inf_kwargs)

        pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
        ds_output = pipe(query, **inf_kwargs)

        print(local_rank, "baseline", bs_output)
        print(local_rank, "deepspeed", ds_output)
        assert assert_fn(bs_output, ds_output)
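
# TestLMCorrectness runs lm-eval (lambada_standard) on the stock HF model and
# again after kernel injection, and requires the two perplexities to agree to
# within 0.01.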


@pytest.mark.nightly
@pytest.mark.parametrize(
    "model_family, model_name",
    (
        ["gpt2", "EleutherAI/gpt-neo-2.7B"],
        #["gpt2", "EleutherAI/gpt-j-6b"], # Causing OOM for this test
        ["gpt2", "gpt2-xl"],
    ),
)
@pytest.mark.parametrize("task", ["lambada_standard"])
class TestLMCorrectness(DistributedTest):
    world_size = 1
    exec_timeout = 1200  # Give these tests longer to complete

    def test(self, model_family, model_name, task):
        # imports here to avoid import errors when pytest collects tests
        import lm_eval
        import lm_eval.models
        import lm_eval.tasks
        import lm_eval.evaluator

        # The bootstrap_stderr function in lm_eval.metrics uses a
        # multiprocessing Pool to increase performance. Since we use a Pool for
        # our distributed tests and cannot nest Pools, we must redefine and
        # patch this function with a version that does not use Pool.
        def no_pool_bootstrap_stderr(f, xs, iters):
            from lm_eval.metrics import _bootstrap_internal
            from lm_eval.metrics import sample_stddev
            res = []
            chunk_size = min(1000, iters)
            for i in range(iters // chunk_size):
                res.extend(_bootstrap_internal(f, chunk_size)((i, xs)))
            return sample_stddev(res)

        lm_eval.metrics.bootstrap_stderr = no_pool_bootstrap_stderr

        local_rank = os.getenv("LOCAL_RANK", "0")
        device = torch.device(get_accelerator().device_name(local_rank))
        dtype = torch.float
        task_dict = lm_eval.tasks.get_task_dict([task])

        if 'gpt-j-6b' in model_name:
            dtype = torch.half
            lm = lm_eval.models.get_model(model_family).create_from_arg_string(f"pretrained={model_name}",
                                                                               {"device": "cpu"})
            setattr(lm, model_family, getattr(lm, model_family).half().to(device))
            lm._device = device
        else:
            lm = lm_eval.models.get_model(model_family).create_from_arg_string(
                f"pretrained={model_name}", {"device": get_accelerator().device_name()})

        get_accelerator().synchronize()
        start = time.time()
        bs_output = lm_eval.evaluator.evaluate(lm=lm, task_dict=task_dict)
        get_accelerator().synchronize()
        bs_time = time.time() - start

        getattr(lm, model_family).to("cpu")
        ds_model = deepspeed.init_inference(
            getattr(lm, model_family),
            mp_size=1,
            dtype=dtype,
            replace_with_kernel_inject=True,
            enable_cuda_graph=False,
        )
        check_injection(ds_model)
        setattr(lm, model_family, ds_model)

        get_accelerator().synchronize()
        start = time.time()
        ds_output = lm_eval.evaluator.evaluate(lm=lm, task_dict=task_dict)
        get_accelerator().synchronize()
        ds_time = time.time() - start

        ppl_diff = abs(bs_output["results"][task]["ppl"] - ds_output["results"][task]["ppl"])
        #assert ds_time <= bs_time
        assert ppl_diff < 0.01