matmul_ext.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import triton
import os
from filelock import FileLock
import deepspeed.ops.transformer.inference.triton.triton_matmul_kernel as triton_matmul_kernel
import pickle
from io import open
import deepspeed
from pathlib import Path
import atexit


# -----------------------------------------------------------------------------
# util class/functions for triton
def _default_cache_dir():
    return os.path.join(Path.home(), ".triton", "autotune")


def bias_add_activation(C, bias=None, activation=""):
    if bias is not None:
        C += bias
    # activation
    if activation == "relu":
        relu = torch.nn.ReLU()
        C = relu(C)
    elif activation == "leaky_relu":
        leaky_relu = torch.nn.LeakyReLU(0.01)
        C = leaky_relu(C)
    elif activation == "gelu":
        # sigmoid-based GELU approximation: gelu(x) ~= x * sigmoid(1.702 * x)
        sigmoid = torch.nn.Sigmoid()
        C = sigmoid(1.702 * C) * C
    elif activation == "sigmoid":
        sigmoid = torch.nn.Sigmoid()
        C = sigmoid(C)
    return C


class AutotuneCacheManager:
    """
    Cache manager for autotune
    """

    def __init__(self, key):
        self.key = key
        self.file_path = None
        self.lock_path = None
        # if caching is enabled, get the lock and bin path
        self.cache_dir = os.environ.get('TRITON_CACHE_DIR', _default_cache_dir())
        if self.cache_dir:
            os.makedirs(self.cache_dir, exist_ok=True)
        if self.cache_dir:
            self.file_path = os.path.join(self.cache_dir, self.key + ".pickle")
            self.lock_path = self.file_path + ".lock"

    def has_file(self):
        return self.file_path and os.path.exists(self.file_path)

    def put(self, table):
        if self.file_path:
            assert self.lock_path is not None
            with FileLock(self.lock_path):
                with open(self.file_path + ".tmp", 'wb') as handle:
                    pickle.dump(table, handle)
                os.rename(self.file_path + ".tmp", self.file_path)

    def load(self):
        # guard against file_path being None when caching is disabled
        if self.file_path and os.path.exists(self.file_path):
            with open(self.file_path, 'rb') as handle:
                loaded_dict = pickle.load(handle)
            return loaded_dict
        else:
            return None
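
# Illustrative sketch (not part of the module): the cache manager round-trips a
# pickled autotune table keyed by kernel name; files live under TRITON_CACHE_DIR
# (default ~/.triton/autotune). The key and table below are hypothetical.
#
#   cache = AutotuneCacheManager("example_2d_kernel")
#   cache.put({"example_config_key": "example_config"})
#   assert cache.load() == {"example_config_key": "example_config"}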


# -----------------------------------------------------------------------------
# triton matmul class
class MatmulExt(torch.autograd.Function):
    """
    a wrapper class that can call different triton matmul kernels depending on the input parameters
    """

    @staticmethod
    def forward(A, B, bias=None, activation="", use_triton=True, update_autotune_table=False):
        """
        A: input, activation matrix A
        B: input, weight matrix B
        """
        matmul = None
        quantize_activation = False
        Batch = 0

        if len(A.shape) == 3:  # if A is 3d-tensor where batch index is given as 0-axis
            assert A.is_contiguous(), "matrix A must be contiguous"
            Batch, M, K = A.shape
            A = A.view(-1, K)

        # fp16 activation and fp16 weight matmul into fp16 output
        matmul = fp16_matmul
        C = matmul.forward(A, B, use_triton=use_triton, bias=bias, activation=activation)

        if matmul and update_autotune_table:
            matmul._update_autotune_table()

        if Batch > 0:
            C = C.view(Batch, M, -1)

        return C
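
# Illustrative sketch (not part of the module): MatmulExt.forward flattens a 3D
# activation tensor to 2D, runs the fp16 triton kernel, and restores the batch
# dimension. Shapes are hypothetical; assumes deepspeed.HAS_TRITON and a CUDA device.
#
#   A = torch.randn(8, 128, 512, dtype=torch.float16, device="cuda")  # (batch, seq, K)
#   B = torch.randn(512, 1024, dtype=torch.float16, device="cuda")    # (K, N)
#   C = MatmulExt.forward(A, B, bias=None, activation="gelu")         # -> (8, 128, 1024)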


class TritonMatmul(torch.autograd.Function):
    """
    triton matmul kernel superclass
    """

    def __init__(self):
        pass

    @staticmethod
    def _ref_forward(A, B, ref_dtype=torch.float32):
        C = torch.matmul(A.type(ref_dtype), B.type(ref_dtype))
        return C

    @staticmethod
    def _read_autotune_table(cache_key, triton_kernel):
        cache_manager = AutotuneCacheManager(cache_key)
        table = cache_manager.load()
        if table:
            triton_kernel.cache = table

    @staticmethod
    def _write_autotune_table(cache_key, triton_kernel):
        cache_manager = AutotuneCacheManager(cache_key)
        cache_manager.put(triton_kernel.cache)

    @staticmethod
    def _update_autotune_table(cache_key, triton_kernel):
        cache_manager = AutotuneCacheManager(cache_key)
        autotune_table = cache_manager.load()
        if autotune_table is None:
            autotune_table = dict()
        autotune_table.update(triton_kernel.cache)  # always overwrite with the new autotune results
        cache_manager = AutotuneCacheManager(cache_key)
        cache_manager.put(autotune_table)

    @staticmethod
    def forward(
            A,
            B,
            ref_dtype=torch.float32,  # fp32 only
            bias=None,
            activation=""):
        C = torch.matmul(A.type(ref_dtype), B.type(ref_dtype))
        C = bias_add_activation(C, bias, activation)
        return C


class Fp16Matmul(TritonMatmul):
    """
    fp16 matrix multiplication kernel
    dtypes: fp16 x fp16 = fp16
    """

    _2d_kernel = triton_matmul_kernel._fp_matmul
    _4d_kernel = triton_matmul_kernel.matmul_4d_kernel
    _cache_stride = 32

    def __init__(self, read_cache=True):
        super().__init__()
        if read_cache:
            __class__._read_autotune_table()

    def skip_autotune(self):
        __class__._2d_kernel.configs = [__class__._2d_kernel.configs[0]]
        __class__._4d_kernel.configs = [__class__._4d_kernel.configs[0]]

    @staticmethod
    def forward(A, B, use_triton=True, bias=None, activation=""):
        if use_triton:
            device = A.device
            # handle non-contiguous inputs if necessary
            if A.stride(0) > 1 and A.stride(1) > 1:
                A = A.contiguous()
            if B.stride(0) > 1 and B.stride(1) > 1:
                B = B.contiguous()
            # checks constraints
            assert A.shape[1] == B.shape[0], "incompatible dimensions"
            M, K = A.shape
            _, N = B.shape
            # allocates output
            C = torch.empty((M, N), device=device, dtype=A.dtype)
            # accumulator types
            ACC_TYPE = triton.language.float32 if A.dtype in [torch.float16, torch.bfloat16, torch.float32
                                                              ] else triton.language.int32
            # launch kernel
            grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
            __class__._2d_kernel[grid](A,
                                       B,
                                       C,
                                       M,
                                       N,
                                       K,
                                       bias,
                                       A.stride(0),
                                       A.stride(1),
                                       B.stride(0),
                                       B.stride(1),
                                       C.stride(0),
                                       C.stride(1),
                                       M // __class__._cache_stride,
                                       N // __class__._cache_stride,
                                       K // __class__._cache_stride,
                                       GROUP_M=8,
                                       ACC_TYPE=ACC_TYPE,
                                       BIAS_ADD=(0 if bias is None else 1),
                                       ACTIVATION=activation)
        else:
            C = torch.matmul(A, B)
        return C
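
    # Illustrative note (not part of the launch above): the grid lambda tiles the
    # (M, N) output into BLOCK_M x BLOCK_N tiles and adds a SPLIT_K axis for
    # splitting the K reduction. With hypothetical M=1024, N=4096,
    # BLOCK_M=BLOCK_N=128 and SPLIT_K=1, the launch grid would be
    # (8 * 32, 1) = (256, 1) program instances.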

    @staticmethod
    def _matmul_4d(a, b):
        assert a.shape[-1] == b.shape[-2], "incompatible dimensions"
        assert a.is_contiguous(), "matrix A must be contiguous"
        assert b.is_contiguous(), "matrix B must be contiguous"

        B, H, M, K = a.shape
        B, H, K, N = b.shape

        assert K > 1, "inner-product dimension K should be larger than 1"

        c = torch.empty((B, H, M, N), device=a.device, dtype=a.dtype)

        grid = lambda META: (
            triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
            H,
            B,
        )

        __class__._4d_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            M // __class__._cache_stride,
            N // __class__._cache_stride,
            K // __class__._cache_stride,
            a.stride(0),
            a.stride(1),
            a.stride(2),
            a.stride(3),
            b.stride(0),
            b.stride(1),
            b.stride(2),
            b.stride(3),
            c.stride(0),
            c.stride(1),
            c.stride(2),
            c.stride(3),
            scale=-1.0,
            MASK=False,
        )
        return c
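
    # Illustrative sketch (not part of the module): _matmul_4d performs a batched
    # (B, H, M, K) x (B, H, K, N) -> (B, H, M, N) product. Shapes below are
    # hypothetical and assume a CUDA device.
    #
    #   a = torch.randn(2, 16, 128, 64, dtype=torch.float16, device="cuda")
    #   b = torch.randn(2, 16, 64, 128, dtype=torch.float16, device="cuda")
    #   c = Fp16Matmul._matmul_4d(a, b)   # -> (2, 16, 128, 128)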

    @staticmethod
    def _score_4d_matmul(input, head_size, input_mask, scale=-1.0):
        assert input.is_contiguous(), "matrix input must be contiguous"

        batches = input.shape[0]
        d_model = input.shape[-1] // 3
        num_of_heads = d_model // head_size

        q = input[:, :, :d_model]
        k = input[:, :, d_model:d_model * 2]

        q = q.view(batches, -1, num_of_heads, head_size)
        k = k.view(batches, -1, num_of_heads, head_size)

        # checks constraints
        assert q.shape == k.shape, "incompatible dimensions"
        B, M, H, K = q.shape
        B, N, H, K = k.shape

        assert K > 1, "inner-product dimension K should be larger than 1"

        # allocates output
        output = torch.empty((B, H, M, N), device=q.device, dtype=q.dtype)
        grid = lambda META: (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
            H,
            B,
        )

        __class__._4d_kernel[grid](
            q,
            k,
            output,
            M,
            N,
            K,
            M // __class__._cache_stride,
            N // __class__._cache_stride,
            K // __class__._cache_stride,
            q.stride(0),
            q.stride(2),
            q.stride(1),
            q.stride(3),
            k.stride(0),
            k.stride(2),
            k.stride(3),
            k.stride(1),
            output.stride(0),
            output.stride(1),
            output.stride(2),
            output.stride(3),
            scale=scale,
            MASK=False,
        )
        return output
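
    # Illustrative note (not part of the launch above): q and k are sliced out of a
    # fused QKV tensor of shape (batch, seq, 3 * d_model). Passing their strides
    # with the sequence and head axes swapped lets the kernel read the
    # (B, M, H, K) views as (B, H, M, K) without materializing a transpose, and
    # additionally swapping k's last two strides yields Q @ K^T for the scores.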

    @staticmethod
    def _context_4d_matmul(prob, input, head_size):
        assert prob.is_contiguous(), "matrix prob must be contiguous"
        assert input.is_contiguous(), "matrix input must be contiguous"

        batches = input.shape[0]
        d_model = input.shape[-1] // 3
        num_of_heads = d_model // head_size

        v = input[:, :, d_model * 2:]
        v = v.view(batches, -1, num_of_heads, head_size)

        # checks constraints
        assert (prob.shape[0] == v.shape[0] and prob.shape[1] == v.shape[2] and prob.shape[2] == v.shape[1]
                and prob.shape[3] == v.shape[1]), "incompatible dimensions"
        B, H, M, K = prob.shape
        B, K, H, N = v.shape

        assert K > 1, "inner-product dimension K should be larger than 1"

        # allocates output
        output = torch.empty((B, M, H, N), device=v.device, dtype=v.dtype)
        grid = lambda META: (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
            H,
            B,
        )

        __class__._4d_kernel[grid](
            prob,
            v,
            output,
            M,
            N,
            K,
            M // __class__._cache_stride,
            N // __class__._cache_stride,
            K // __class__._cache_stride,
            prob.stride(0),
            prob.stride(1),
            prob.stride(2),
            prob.stride(3),
            v.stride(0),
            v.stride(2),
            v.stride(1),
            v.stride(3),
            # Here we also transpose the output when writing to memory.
            output.stride(0),
            output.stride(2),
            output.stride(1),
            output.stride(3),
            scale=-1,
            MASK=False,
        )
        return output.view(batches, -1, d_model)
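
    # Illustrative sketch (not part of the module): the two helpers above are meant
    # to be chained on a fused QKV projection. Shapes and the softmax step below
    # are hypothetical; assumes a CUDA device.
    #
    #   qkv = torch.randn(2, 128, 3 * 1024, dtype=torch.float16, device="cuda")
    #   scores = Fp16Matmul._score_4d_matmul(qkv, head_size=64, input_mask=None)   # (2, 16, 128, 128)
    #   prob = torch.softmax(scores.float(), dim=-1).half().contiguous()
    #   context = Fp16Matmul._context_4d_matmul(prob, qkv, head_size=64)           # (2, 128, 1024)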

    @staticmethod
    def _ref_forward(A, B, ref_dtype=torch.float32, bias=None, activation=""):
        C = torch.matmul(A.type(ref_dtype), B.type(ref_dtype))
        C = bias_add_activation(C, bias, activation)
        return C

    @staticmethod
    def _check_parity(A,
                      B,
                      output_dtype,
                      SA=None,
                      SB=None,
                      qblock_size=None,
                      ref_dtype=torch.float32,
                      tol=0.01,
                      use_triton=True,
                      bias=None,
                      activation=""):
        torch_output = __class__._ref_forward(A, B, ref_dtype=ref_dtype, bias=bias, activation=activation)
        triton_output = __class__.forward(A, B, use_triton=use_triton, bias=bias, activation=activation)
        assert torch.allclose(triton_output.cpu().type(torch_output.dtype), torch_output.cpu(), rtol=tol)
        print(f"{__class__.__name__}: PASSed the parity check")
        return triton_output, torch_output

    @staticmethod
    def _read_autotune_table():
        TritonMatmul._read_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel)
        TritonMatmul._read_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel)

    @staticmethod
    def _write_autotune_table():
        TritonMatmul._write_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel)
        TritonMatmul._write_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel)

    @staticmethod
    def _update_autotune_table():
        TritonMatmul._update_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel)
        TritonMatmul._update_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel)


# -----------------------------------------------------------------------------
# mapping
if deepspeed.HAS_TRITON:
    fp16_matmul = Fp16Matmul()
    matmul = MatmulExt.forward
    matmul_4d = fp16_matmul._matmul_4d
    score_4d_matmul = fp16_matmul._score_4d_matmul
    context_4d_matmul = fp16_matmul._context_4d_matmul
else:
    fp16_matmul = None
    matmul = None
    matmul_4d = None
    score_4d_matmul = None
    context_4d_matmul = None


@atexit.register
def matmul_ext_update_autotune_table():
    if deepspeed.HAS_TRITON:
        fp16_matmul._update_autotune_table()
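
# Illustrative sketch (not part of the module): callers typically use the
# module-level names bound above (matmul, matmul_4d, score_4d_matmul,
# context_4d_matmul) rather than the classes; the atexit hook persists any new
# autotune results on interpreter shutdown. Shapes are hypothetical and assume
# deepspeed.HAS_TRITON plus a CUDA device.
#
#   A = torch.randn(4, 64, 256, dtype=torch.float16, device="cuda")
#   W = torch.randn(256, 256, dtype=torch.float16, device="cuda")
#   out = matmul(A, W, bias=None, activation="")   # -> (4, 64, 256)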