# DeepSpeed note, some parts of code taken & adapted from commit c368a9fd1b2c9dee4cc94de9a6bb0be3d447be41
# https://github.com/ptillet/torch-blocksparse/blob/master/tests/test_softmax.py
# https://github.com/ptillet/torch-blocksparse/blob/master/tests/test_matmul.py
# https://github.com/ptillet/torch-blocksparse/blob/master/tests/utils

import pytest
import torch
import deepspeed

if not deepspeed.ops.__installed_ops__['sparse-attn']:
    pytest.skip("sparse-attn is not installed", allow_module_level=True)

def test_sparse_attention_module_availability():
    try:
        from deepspeed.ops import sparse_attention
    except ImportError:
        print("Sparse Attention Module is not installed!")
        return False
    return True


def test_matmul_module_availability():
    try:
        from deepspeed.ops.sparse_attention import MatMul
    except ImportError:
        print("Sparse MatMul Module is not installed!")
        return False
    return True


def test_softmax_module_availability():
    try:
        from deepspeed.ops.sparse_attention import Softmax
    except ImportError:
        print("Sparse Softmax Module is not installed!")
        return False
    return True


def test_sparsityconfig_module_availability():
    try:
        from deepspeed.ops.sparse_attention import SparsityConfig
    except ImportError:
        print("SparsityConfig Module is not installed!")
        return False
    return True


def test_densesparsityconfig_module_availability():
    try:
        from deepspeed.ops.sparse_attention import DenseSparsityConfig
    except ImportError:
        print("DenseSparsityConfig Module is not installed!")
        return False
    return True


def test_fixedsparsityconfig_module_availability():
    try:
        from deepspeed.ops.sparse_attention import FixedSparsityConfig
    except ImportError:
        print("FixedSparsityConfig Module is not installed!")
        return False
    return True


def test_variablesparsityconfig_module_availability():
    try:
        from deepspeed.ops.sparse_attention import VariableSparsityConfig
    except ImportError:
        print("VariableSparsityConfig Module is not installed!")
        return False
    return True


def test_bigbirdsparsityconfig_module_availability():
    try:
        from deepspeed.ops.sparse_attention import BigBirdSparsityConfig
    except ImportError:
        print("BigBirdSparsityConfig Module is not installed!")
        return False
    return True


def test_bslongformersparsityconfig_module_availability():
    try:
        from deepspeed.ops.sparse_attention import BSLongformerSparsityConfig
    except ImportError:
        print("BSLongformerSparsityConfig Module is not installed!")
        return False
    return True


def test_sparseselfattention_module_availability():
    try:
        from deepspeed.ops.sparse_attention import SparseSelfAttention
    except ImportError:
        print("SparseSelfAttention Module is not installed!")
        return False
    return True


def test_bertsparseselfattention_module_availability():
    try:
        from deepspeed.ops.sparse_attention import BertSparseSelfAttention
    except ImportError:
        print("BertSparseSelfAttention Module is not installed!")
        return False
    return True


def test_sparseattentionutils_availability():
    try:
        from deepspeed.ops.sparse_attention import SparseAttentionUtils
    except ImportError:
        print("SparseAttentionUtils Module is not installed!")
        return False
    return True


def test_cpp_utils_availability():
    try:
        from deepspeed.ops.sparse_attention import cpp_utils
    except ImportError:
        print("Sparse Attention cpp_utils Module is not installed!")
        return False
    return True

def dense_to_sparse(w, mask, block):
    """Converts dense matrix with explicit zeros to sparse matrix
    """
    Z = w.size(0)
    ret = torch.empty((Z, mask.sum(), block, block), dtype=w.dtype, device=w.device)
    nnz = mask.nonzero()
    h, i, j = nnz[:, 0], nnz[:, 1], nnz[:, 2]
    for zz in range(Z):
        for idx, (hh, ii, jj) in enumerate(zip(h, i, j)):
            ret[zz, idx, :, :] = w[zz, hh, ii * block:(ii + 1) * block, jj * block:(jj + 1) * block]
    return ret


def sparse_to_dense(w, mask, block, zero=0):
    """Converts sparse matrix to dense matrix with explicit zeros
    """
    maskedw = w.clone()
    for bz, wz in enumerate(range(0, w.size(0))):
        for bh, wh in enumerate(range(0, w.size(1))):
            for bi, wi in enumerate(range(0, w.size(2), block)):
                for bj, wj in enumerate(range(0, w.size(3), block)):
                    if mask[bh, bi, bj] == 0:
                        maskedw[wz, wh, wi:wi + block, wj:wj + block] = zero
                    #maskedw[wz, wh, wi : wi+block, wj : wj+block] *= mask[bh, bi, bj]
    return maskedw

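# Shape note for the two helpers above: given a dense tensor of shape (Z, H, M, N)
# and a block-level mask/layout of shape (H, M // block, N // block),
# dense_to_sparse packs only the blocks whose mask entry is nonzero into a tensor
# of shape (Z, mask.sum(), block, block), while sparse_to_dense keeps the dense
# shape and overwrites the masked-out blocks with `zero`. The pair lets the tests
# compare the block-sparse kernels against a dense reference on the same values.
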
def allclose(x, y):
    assert x.dtype == y.dtype
    rtol, atol = {torch.float32: (1e-4, 1e-5), torch.float16: (1e-2, 1e-3)}[x.dtype]
    return torch.allclose(x, y, rtol=rtol, atol=atol)


def make_layout(rho, shape):
    probs = torch.Tensor([rho, 1 - rho])
    generator = torch.distributions.categorical.Categorical(probs)
    layout = generator.sample(shape)
    return layout

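# Note on make_layout: Categorical([rho, 1 - rho]) samples 0 with probability rho
# and 1 with probability 1 - rho, so each block is kept (layout entry 1) with
# probability 1 - rho; rho is the expected fraction of blocks that are dropped.
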
def run_softmax_reference(x, scale, dx, kp_mask, attn_mask, layout, block):
    x = sparse_to_dense(x, layout, block, zero=float('-inf'))
    x.retain_grad()
    if kp_mask is not None:
        bcattn_mask = attn_mask[None, None, :, :] + torch.zeros_like(x)
        x[bcattn_mask == 0] = float('-inf')
        y = torch.softmax(x * scale + kp_mask[:, None, None, :], -1)
    else:
        y = torch.softmax(x * scale, -1)
    y.backward(dx)
    dx = x.grad.clone()
    dx = dense_to_sparse(dx, layout, block)
    y = dense_to_sparse(y, layout, block)
    return y, dx


def run_softmax_sparse(x, scale, dx, kp_mask, attn_mask, layout, block):
    from deepspeed.ops.sparse_attention import Softmax
    sparse_softmax = Softmax(layout, block, bench=False)
    dx = dense_to_sparse(dx, layout, block)
    x = dense_to_sparse(x, layout, block)
    x.retain_grad()
    y = sparse_softmax(x,
                       scale=scale,
                       key_padding_mask=kp_mask,
                       key_padding_mask_mode='add',
                       attn_mask=attn_mask,
                       attn_mask_mode='mul')
    y.backward(dx)
    dx = x.grad.clone()
    x.grad.zero_()
    # return the softmax output for comparison against the reference
    return y, dx

def init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, dense_x=True, layout=None):
    if layout is None:
        layout = make_layout(rho, (H, M // block, N // block))
    if dense_x:
        x = torch.rand((Z, H, M, N), dtype=dtype, requires_grad=True, device='cuda')
    else:
        x = torch.rand((Z, layout.sum(), block, block),
                       dtype=dtype,
                       requires_grad=True,
                       device='cuda')
    dx = torch.rand_like(x)
    bool_attn_mask = torch.randint(low=0,
                                   high=2,
                                   size=(N, N),
                                   dtype=torch.bool,
                                   requires_grad=False,
                                   device='cuda')
    fp_attn_mask = bool_attn_mask.type(dtype)
    kp_mask = torch.randint(low=0,
                            high=2,
                            size=(Z, N),
                            dtype=dtype,
                            requires_grad=False,
                            device='cuda')
    kp_mask[kp_mask == 1.] = float('-inf')
    return layout, x, dx, bool_attn_mask, fp_attn_mask, kp_mask

def _skip_on_cuda_compatibility():
    # NOTE: this unconditional skip currently disables the softmax/matmul kernel
    # tests below; the compute-capability and CUDA-version checks that follow are
    # kept for when the skip is removed.
    pytest.skip("Skip these tests for now until we get our docker image fixed.")
    if torch.cuda.get_device_capability()[0] != 7:
        pytest.skip("needs compute capability 7; v100")
    cuda_major = int(torch.version.cuda.split('.')[0]) * 10
    cuda_minor = int(torch.version.cuda.split('.')[1])
    cuda_version = cuda_major + cuda_minor
    if cuda_version != 101 and cuda_version != 102:
        pytest.skip("requires cuda 10.1 or 10.2")

@pytest.mark.parametrize("block", [16, 32])
@pytest.mark.parametrize("width", [256, 576])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_softmax(block, width, dtype):
    _skip_on_cuda_compatibility()
    Z = 2
    H = 4
    scale = 0.4
    rho = 0.4
    M = N = width
    layout, x, dx, bool_attn_mask, fp_attn_mask, kp_mask = init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, layout=None)
    ref_y, ref_dx = run_softmax_reference(x, scale, dx, kp_mask, bool_attn_mask, layout, block)
    st_y, st_dx = run_softmax_sparse(x, scale, dx, kp_mask, fp_attn_mask, layout, block)
    assert allclose(ref_y, st_y)
    assert allclose(ref_dx, st_dx)

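# Mode naming used by the matmul tests below ('sdd', 'dsd', 'dds'): the three
# letters mark the output, first operand, and second operand as sparse (s) or
# dense (d). So 'sdd' multiplies two dense operands into a block-sparse output,
# 'dsd' takes a block-sparse first operand, and 'dds' a block-sparse second
# operand, which is why the helpers convert exactly that tensor with
# dense_to_sparse / sparse_to_dense.
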
def run_matmul_reference(x, w, mode, trans_a, trans_b, layout, block, dy):
    x = sparse_to_dense(x, layout, block) if mode == 'dsd' else x
    w = sparse_to_dense(w, layout, block) if mode == 'dds' else w
    x.retain_grad()
    w.retain_grad()
    xx = x.transpose(2, 3) if trans_a else x
    ww = w.transpose(2, 3) if trans_b else w
    y = torch.matmul(xx, ww)
    y = sparse_to_dense(y, layout, block) if mode == 'sdd' else y
    y.backward(dy)
    dx = x.grad.clone()
    dw = w.grad.clone()
    x.grad.zero_()
    w.grad.zero_()
    y = dense_to_sparse(y, layout, block) if mode == 'sdd' else y
    dx = dense_to_sparse(dx, layout, block) if mode == 'dsd' else dx
    dw = dense_to_sparse(dw, layout, block) if mode == 'dds' else dw
    return y, dx, dw

def run_matmul_sparse(x, w, mode, trans_a, trans_b, layout, block, dy):
    from deepspeed.ops.sparse_attention import MatMul
    x = dense_to_sparse(x, layout, block) if mode == 'dsd' else x
    w = dense_to_sparse(w, layout, block) if mode == 'dds' else w
    dy = dense_to_sparse(dy, layout, block) if mode == 'sdd' else dy
    op = MatMul(layout, block, mode, trans_a=trans_a, trans_b=trans_b)
    x.retain_grad()
    w.retain_grad()
    y = op(x, w)
    y.backward(dy)
    dx = x.grad.clone()
    dw = w.grad.clone()
    x.grad.zero_()
    return y, dx, dw

def init_matmul_inputs(Z, H, M, N, K, rho, mode, trans_a, trans_b, block, dtype, layout):
    torch.manual_seed(1)
    AS0 = K if trans_a else M
    AS1 = M if trans_a else K
    BS0 = N if trans_b else K
    BS1 = K if trans_b else N
    shape = {'sdd': (M, N), 'dsd': (AS0, AS1), 'dds': (BS0, BS1)}[mode]
    x = torch.rand((Z, H, AS0, AS1), dtype=dtype, requires_grad=True, device='cuda')
    w = torch.rand((Z, H, BS0, BS1), dtype=dtype, requires_grad=True, device='cuda')
    dy = torch.rand((Z, H, M, N), dtype=dtype, device='cuda')
    if layout is None:
        layout = make_layout(rho, (H, shape[0] // block, shape[1] // block))
    else:
        assert list(layout.shape) == [H, shape[0] // block, shape[1] // block]
    x.retain_grad()
    w.retain_grad()
    return x, w, dy, shape, layout

testdata = [
    (16, dtype, mode, trans_a, trans_b)
    for dtype in [torch.float16, torch.float32]
    for mode in ['sdd', 'dsd', 'dds']
    for trans_a in [False, True]
    for trans_b in [False, True]
] + [
    (block, torch.float16, mode, False, False)
    for block in [16, 32, 64]
    for mode in ['sdd', 'dsd', 'dds']
]

@pytest.mark.parametrize("block, dtype, mode, trans_a, trans_b", testdata)
def test_matmul(block, dtype, mode, trans_a, trans_b):
    _skip_on_cuda_compatibility()
    Z = 3
    H = 2
    M = 128
    N = 256
    K = 192
    rho = 0.5
    x, w, dy, shape, layout = init_matmul_inputs(Z, H, M, N, K, rho, mode, trans_a, trans_b, block, dtype, layout=None)
    ref_y, ref_dx, ref_dw = run_matmul_reference(x.clone(), w.clone(), mode, trans_a, trans_b, layout, block, dy)
    st_y, st_dx, st_dw = run_matmul_sparse(x.clone(), w.clone(), mode, trans_a, trans_b, layout, block, dy)
    assert allclose(ref_y, st_y)
    assert allclose(ref_dx, st_dx)
    assert allclose(ref_dw, st_dw)