# test_sparse_attention.py

# DeepSpeed note: some parts of this code were taken and adapted from commit c368a9fd1b2c9dee4cc94de9a6bb0be3d447be41
# https://github.com/ptillet/torch-blocksparse/blob/master/tests/test_softmax.py
# https://github.com/ptillet/torch-blocksparse/blob/master/tests/test_matmul.py
# https://github.com/ptillet/torch-blocksparse/blob/master/tests/utils

import pytest
import torch
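

# The availability tests below only verify that the sparse attention modules can
# be imported. Note that they print a message and return True/False rather than
# asserting, so pytest still reports them as passed when an import fails.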


def test_sparse_attention_module_availability():
    try:
        from deepspeed.ops import sparse_attention
    except ImportError:
        print("Sparse Attention Module is not installed!")
        return False
    return True


def test_matmul_module_availability():
    try:
        from deepspeed.ops.sparse_attention import MatMul
    except ImportError:
        print("Sparse MatMul Module is not installed!")
        return False
    return True


def test_softmax_module_availability():
    try:
        from deepspeed.ops.sparse_attention import Softmax
    except ImportError:
        print("Sparse Softmax Module is not installed!")
        return False
    return True


def test_sparsityconfig_module_availability():
    try:
        from deepspeed.ops.sparse_attention import SparsityConfig
    except ImportError:
        print("SparsityConfig Module is not installed!")
        return False
    return True


def test_densesparsityconfig_module_availability():
    try:
        from deepspeed.ops.sparse_attention import DenseSparsityConfig
    except ImportError:
        print("DenseSparsityConfig Module is not installed!")
        return False
    return True


def test_fixedsparsityconfig_module_availability():
    try:
        from deepspeed.ops.sparse_attention import FixedSparsityConfig
    except ImportError:
        print("FixedSparsityConfig Module is not installed!")
        return False
    return True


def test_variablesparsityconfig_module_availability():
    try:
        from deepspeed.ops.sparse_attention import VariableSparsityConfig
    except ImportError:
        print("VariableSparsityConfig Module is not installed!")
        return False
    return True


def test_bigbirdsparsityconfig_module_availability():
    try:
        from deepspeed.ops.sparse_attention import BigBirdSparsityConfig
    except ImportError:
        print("BigBirdSparsityConfig Module is not installed!")
        return False
    return True


def test_bslongformersparsityconfig_module_availability():
    try:
        from deepspeed.ops.sparse_attention import BSLongformerSparsityConfig
    except ImportError:
        print("BSLongformerSparsityConfig Module is not installed!")
        return False
    return True


def test_sparseselfattention_module_availability():
    try:
        from deepspeed.ops.sparse_attention import SparseSelfAttention
    except ImportError:
        print("SparseSelfAttention Module is not installed!")
        return False
    return True


def test_bertsparseselfattention_module_availability():
    try:
        from deepspeed.ops.sparse_attention import BertSparseSelfAttention
    except ImportError:
        print("BertSparseSelfAttention Module is not installed!")
        return False
    return True


def test_sparseattentionutils_availability():
    try:
        from deepspeed.ops.sparse_attention import SparseAttentionUtils
    except ImportError:
        print("SparseAttentionUtils Module is not installed!")
        return False
    return True


def test_cpp_utils_availability():
    try:
        from deepspeed.ops.sparse_attention import cpp_utils
    except ImportError:
        print("Sparse Attention cpp_utils Module is not installed!")
        return False
    return True


def dense_to_sparse(w, mask, block):
    """Converts a dense matrix with explicit zeros to a block-sparse matrix."""
    Z = w.size(0)
    ret = torch.empty((Z, mask.sum(), block, block), dtype=w.dtype, device=w.device)
    nnz = mask.nonzero()
    h, i, j = nnz[:, 0], nnz[:, 1], nnz[:, 2]
    for zz in range(Z):
        for idx, (hh, ii, jj) in enumerate(zip(h, i, j)):
            ret[zz, idx, :, :] = w[zz, hh, ii * block:(ii + 1) * block, jj * block:(jj + 1) * block]
    return ret


def sparse_to_dense(w, mask, block, zero=0):
    """Converts a block-sparse matrix to a dense matrix with explicit zeros."""
    maskedw = w.clone()
    for bz, wz in enumerate(range(0, w.size(0))):
        for bh, wh in enumerate(range(0, w.size(1))):
            for bi, wi in enumerate(range(0, w.size(2), block)):
                for bj, wj in enumerate(range(0, w.size(3), block)):
                    if mask[bh, bi, bj] == 0:
                        maskedw[wz, wh, wi:wi + block, wj:wj + block] = zero
                    # maskedw[wz, wh, wi:wi + block, wj:wj + block] *= mask[bh, bi, bj]
    return maskedw


def allclose(x, y):
    assert x.dtype == y.dtype
    rtol, atol = {torch.float32: (1e-4, 1e-5), torch.float16: (1e-2, 1e-3)}[x.dtype]
    return torch.allclose(x, y, rtol=rtol, atol=atol)


def make_layout(rho, shape):
    probs = torch.Tensor([rho, 1 - rho])
    generator = torch.distributions.categorical.Categorical(probs)
    layout = generator.sample(shape)
    return layout
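

# Illustrative sketch (not part of the original test suite) showing how the
# helpers above fit together: make_layout samples a random 0/1 block layout of
# shape (H, M // block, N // block); sparse_to_dense zeroes the masked-off
# blocks of a dense tensor; dense_to_sparse packs the remaining active blocks
# into a (Z, nnz, block, block) tensor. The helper name _layout_roundtrip_example
# is ours and is not collected by pytest; call it manually if desired. Runs on CPU.
def _layout_roundtrip_example(Z=2, H=3, M=64, N=64, block=16, rho=0.5):
    layout = make_layout(rho, (H, M // block, N // block))
    w = torch.rand((Z, H, M, N), dtype=torch.float32)
    masked = sparse_to_dense(w, layout, block)       # dense, masked blocks set to 0
    packed = dense_to_sparse(masked, layout, block)  # (Z, layout.sum(), block, block)
    # Packing the masked dense tensor equals packing the original, because
    # dense_to_sparse only ever reads the active blocks.
    assert allclose(packed, dense_to_sparse(w, layout, block))
    return packed.shape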


def run_softmax_reference(x, scale, dx, kp_mask, attn_mask, layout, block):
    x = sparse_to_dense(x, layout, block, zero=float('-inf'))
    x.retain_grad()
    if kp_mask is not None:
        bcattn_mask = attn_mask[None, None, :, :] + torch.zeros_like(x)
        x[bcattn_mask == 0] = float('-inf')
        y = torch.softmax(x * scale + kp_mask[:, None, None, :], -1)
    else:
        y = torch.softmax(x * scale, -1)
    y.backward(dx)
    dx = x.grad.clone()
    dx = dense_to_sparse(dx, layout, block)
    y = dense_to_sparse(y, layout, block)
    return y, dx


def run_softmax_sparse(x, scale, dx, kp_mask, attn_mask, layout, block):
    from deepspeed.ops.sparse_attention import Softmax
    sparse_softmax = Softmax(layout, block, bench=False)
    dx = dense_to_sparse(dx, layout, block)
    x = dense_to_sparse(x, layout, block)
    x.retain_grad()
    y = sparse_softmax(x,
                       scale=scale,
                       key_padding_mask=kp_mask,
                       key_padding_mask_mode='add',
                       attn_mask=attn_mask,
                       attn_mask_mode='mul')
    y.backward(dx)
    dx = x.grad.clone()
    x.grad.zero_()
    # The sparse softmax kernel operates in place, so x (not y) holds the forward output here.
    return x, dx


def init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, dense_x=True, layout=None):
    if layout is None:
        layout = make_layout(rho, (H, M // block, N // block))
    if dense_x:
        x = torch.rand((Z, H, M, N), dtype=dtype, requires_grad=True, device='cuda')
    else:
        x = torch.rand((Z, layout.sum(), block, block),
                       dtype=dtype,
                       requires_grad=True,
                       device='cuda')
    dx = torch.rand_like(x)
    bool_attn_mask = torch.randint(low=0,
                                   high=2,
                                   size=(N, N),
                                   dtype=torch.bool,
                                   requires_grad=False,
                                   device='cuda')
    fp_attn_mask = bool_attn_mask.type(dtype)
    kp_mask = torch.randint(low=0,
                            high=2,
                            size=(Z, N),
                            dtype=dtype,
                            requires_grad=False,
                            device='cuda')
    # Keys marked with 1 are treated as padding: set to -inf so they are masked
    # out when the key padding mask is added to the logits.
    kp_mask[kp_mask == 1.] = float('-inf')
    return layout, x, dx, bool_attn_mask, fp_attn_mask, kp_mask


def _skip_on_cuda_compatability():
    # NOTE: the unconditional skip below disables these tests for now; the
    # capability and CUDA-version checks after it are currently unreachable.
    pytest.skip("Skip these tests for now until we get our docker image fixed.")
    if torch.cuda.get_device_capability()[0] != 7:
        pytest.skip("needs compute capability 7; v100")
    cuda_major = int(torch.version.cuda.split('.')[0]) * 10
    cuda_minor = int(torch.version.cuda.split('.')[1])
    cuda_version = cuda_major + cuda_minor
    if cuda_version != 101 and cuda_version != 102:
        pytest.skip("requires cuda 10.1 or 10.2")


@pytest.mark.parametrize("block", [16, 32])
@pytest.mark.parametrize("width", [256, 576])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_softmax(block, width, dtype):
    _skip_on_cuda_compatability()
    Z = 2
    H = 4
    scale = 0.4
    rho = 0.4
    M = N = width
    layout, x, dx, bool_attn_mask, fp_attn_mask, kp_mask = init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, layout=None)
    ref_y, ref_dx = run_softmax_reference(x, scale, dx, kp_mask, bool_attn_mask, layout, block)
    st_y, st_dx = run_softmax_sparse(x, scale, dx, kp_mask, fp_attn_mask, layout, block)
    assert allclose(ref_y, st_y)
    assert allclose(ref_dx, st_dx)
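

# In the MatMul tests below, mode encodes which operand is block-sparse
# (d = dense, s = sparse, read as output / first input / second input):
#   'sdd' - sparse output, dense inputs
#   'dsd' - dense output, sparse first input
#   'dds' - dense output, sparse second input
# This reading follows how the helpers below apply dense_to_sparse /
# sparse_to_dense to each operand per mode.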


def run_matmul_reference(x, w, mode, trans_a, trans_b, layout, block, dy):
    x = sparse_to_dense(x, layout, block) if mode == 'dsd' else x
    w = sparse_to_dense(w, layout, block) if mode == 'dds' else w
    x.retain_grad()
    w.retain_grad()
    xx = x.transpose(2, 3) if trans_a else x
    ww = w.transpose(2, 3) if trans_b else w
    y = torch.matmul(xx, ww)
    y = sparse_to_dense(y, layout, block) if mode == 'sdd' else y
    y.backward(dy)
    dx = x.grad.clone()
    dw = w.grad.clone()
    x.grad.zero_()
    w.grad.zero_()
    y = dense_to_sparse(y, layout, block) if mode == 'sdd' else y
    dx = dense_to_sparse(dx, layout, block) if mode == 'dsd' else dx
    dw = dense_to_sparse(dw, layout, block) if mode == 'dds' else dw
    return y, dx, dw


def run_matmul_sparse(x, w, mode, trans_a, trans_b, layout, block, dy):
    from deepspeed.ops.sparse_attention import MatMul
    x = dense_to_sparse(x, layout, block) if mode == 'dsd' else x
    w = dense_to_sparse(w, layout, block) if mode == 'dds' else w
    dy = dense_to_sparse(dy, layout, block) if mode == 'sdd' else dy
    op = MatMul(layout, block, mode, trans_a=trans_a, trans_b=trans_b)
    x.retain_grad()
    w.retain_grad()
    y = op(x, w)
    y.backward(dy)
    dx = x.grad.clone()
    dw = w.grad.clone()
    x.grad.zero_()
    return y, dx, dw


def init_matmul_inputs(Z, H, M, N, K, rho, mode, trans_a, trans_b, block, dtype, layout):
    torch.manual_seed(1)
    AS0 = K if trans_a else M
    AS1 = M if trans_a else K
    BS0 = N if trans_b else K
    BS1 = K if trans_b else N
    shape = {'sdd': (M, N), 'dsd': (AS0, AS1), 'dds': (BS0, BS1)}[mode]
    x = torch.rand((Z, H, AS0, AS1), dtype=dtype, requires_grad=True, device='cuda')
    w = torch.rand((Z, H, BS0, BS1), dtype=dtype, requires_grad=True, device='cuda')
    dy = torch.rand((Z, H, M, N), dtype=dtype, device='cuda')
    if layout is None:
        layout = make_layout(rho, (H, shape[0] // block, shape[1] // block))
    else:
        assert list(layout.shape) == [H, shape[0] // block, shape[1] // block]
    x.retain_grad()
    w.retain_grad()
    return x, w, dy, shape, layout


testdata = [
    (16, dtype, mode, trans_a, trans_b)
    for dtype in [torch.float16, torch.float32]
    for mode in ['sdd', 'dsd', 'dds']
    for trans_a in [False, True]
    for trans_b in [False, True]
] + [
    (block, torch.float16, mode, False, False)
    for block in [16, 32, 64]
    for mode in ['sdd', 'dsd', 'dds']
]


@pytest.mark.parametrize("block, dtype, mode, trans_a, trans_b", testdata)
def test_matmul(block, dtype, mode, trans_a, trans_b):
    _skip_on_cuda_compatability()
    Z = 3
    H = 2
    M = 128
    N = 256
    K = 192
    rho = 0.5
    x, w, dy, shape, layout = init_matmul_inputs(Z, H, M, N, K, rho, mode, trans_a, trans_b, block, dtype, layout=None)
    ref_y, ref_dx, ref_dw = run_matmul_reference(x.clone(), w.clone(), mode, trans_a, trans_b, layout, block, dy)
    st_y, st_dx, st_dw = run_matmul_sparse(x.clone(), w.clone(), mode, trans_a, trans_b, layout, block, dy)
    assert allclose(ref_y, st_y)
    assert allclose(ref_dx, st_dx)
    assert allclose(ref_dw, st_dw)