  1. """
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. """
  4. import torch
  5. import random
  6. class SparsityConfig:
  7. """Abstract Configuration class to store `sparsity configuration of a self attention layer`.
  8. It contains shared property of different block-sparse sparsity patterns. However, each class needs to extend it based on required property and functionality.
  9. """
  10. def __init__(self, num_heads, block=16, different_layout_per_head=False):
  11. """Initialize the Sparsity Pattern Config.
  12. For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
  13. Arguments:
  14. num_heads: required: an integer determining number of attention heads of the layer.
  15. block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`.
  16. different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability.
  17. """
  18. self.num_heads = num_heads
  19. self.block = block
  20. self.different_layout_per_head = different_layout_per_head
  21. self.num_layout_heads = num_heads if different_layout_per_head else 1
  22. def setup_layout(self, seq_len):
  23. """Create layout tensor for the given sequence length
  24. Arguments:
  25. seq_len: required: an integer determining number of attention heads of the layer.
  26. Return:
  27. layout: a tensor of dimension (num_heads, num_blocks, num_blocks) for sparsity layout of all head; initialized with zero
  28. """
  29. if (seq_len % self.block != 0):
  30. raise ValueError(
  31. f'Sequence Length, {seq_len}, needs to be dividable by Block size {self.block}!'
  32. )
  33. num_blocks = seq_len // self.block
  34. # TODO Currently we allocate layout per head; needs to be updated if heads share a single layout.
  35. layout = torch.zeros((self.num_heads, num_blocks, num_blocks), dtype=torch.int64)
  36. return layout
  37. def check_and_propagate_first_head_layout(self, layout):
  38. """If all heads require same sparsity layout, it propagate first head layout to all heads
  39. Arguments:
  40. layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step
  41. Return:
  42. layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head
  43. """
  44. if not self.different_layout_per_head:
  45. layout[1:self.num_heads, :, :] = layout[0, :, :]
  46. return layout
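
# Example usage of the base class (a minimal sketch; the 256-token sequence and
# 8 heads below are illustrative values, not defaults from this file). With the
# default block size of 16, a 256-token sequence maps onto a 16 x 16 grid of
# blocks per head, so the layout starts as an all-zero (8, 16, 16) tensor that
# the subclasses then fill in.
#
#   base = SparsityConfig(num_heads=8, block=16)
#   layout = base.setup_layout(seq_len=256)   # torch.Size([8, 16, 16]), all zeros
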

class DenseSparsityConfig(SparsityConfig):
    """Configuration class to store the `Dense` configuration.
    In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and comprehension.
    """
    def __init__(self, num_heads, block=16, different_layout_per_head=False):
        """Initialize the Dense Sparsity Pattern Config.
        In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and comprehension.
        Arguments:
            num_heads: required: an integer determining the number of attention heads of the layer.
            block: optional: an integer determining the block size. The current implementation of sparse self-attention is based on blocked sparse matrices, and this parameter defines the size of such blocks, `Block X Block`.
            different_layout_per_head: optional: included only for consistency with the other sparsity configs; it can be ignored for DenseSparsityConfig.
        """
        super().__init__(num_heads, block, different_layout_per_head)

    def make_layout(self, seq_len):
        """Set all blocks of the layout to 1, meaning the pattern is dense, not sparse.
        Arguments:
            seq_len: required: an integer determining the underlying sequence length; must be <= max sequence length.
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads; for dense, everything is 1.
        """
        layout = self.setup_layout(seq_len)
        layout[:, :, :] = 1
        return layout
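
# Example usage (a minimal sketch; the 128-token sequence and 4 heads are
# illustrative values):
#
#   dense = DenseSparsityConfig(num_heads=4, block=16)
#   layout = dense.make_layout(seq_len=128)   # torch.Size([4, 8, 8])
#   assert bool(layout.all())                 # every block attends to every block
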

class FixedSparsityConfig(SparsityConfig):
    """Configuration class to store the `Fixed` sparsity configuration.
    For more details about this sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
    This class extends the parent class `SparsityConfig` and customizes it for `Fixed` sparsity.
    """
    def __init__(self,
                 num_heads,
                 block=16,
                 different_layout_per_head=False,
                 num_local_blocks=4,
                 num_global_blocks=1,
                 attention='bidirectional',
                 horizontal_global_attention=False,
                 num_different_global_patterns=1):
        """Initialize the `Fixed` Sparsity Pattern Config.
        For a usage example, please see: TODO DeepSpeed Sparse Transformer Tutorial
        Arguments:
            num_heads: required: an integer determining the number of attention heads of the layer.
            block: optional: an integer determining the block size. The current implementation of sparse self-attention is based on blocked sparse matrices, and this parameter defines the size of such blocks, `Block X Block`.
            different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false, and this will be satisfied based on availability.
            num_local_blocks: optional: an integer determining the number of blocks in the local attention window.
            num_global_blocks: optional: an integer determining how many consecutive blocks in a local window are used as the representative of the window for global attention.
            attention: optional: a string determining the attention type. Attention can be `unidirectional`, as in autoregressive models, in which tokens attend only to tokens that appear before them in the context; in that case, the upper triangular part of the attention matrix is empty. Or it can be `bidirectional`, as in BERT, in which tokens can attend to any other tokens before or after them; then, the upper triangular part of the attention matrix mirrors the lower triangular part.
            horizontal_global_attention: optional: a boolean determining if blocks that are the global representative of a local window also attend to all other blocks. This is valid only if the attention type is `bidirectional`. Looking at the attention matrix, it means global attention includes not only the vertical blocks but also the horizontal blocks.
            num_different_global_patterns: optional: an integer determining the number of different global attention layouts. While global attention can be fixed by which block(s) are representative of any local window, since there are multiple heads, each head can use a different global representative. For example, with a 4-block local window and a global attention size of 1 block, we can have 4 different versions in which the first, second, third, or fourth block of each local window is the global representative of that window. This parameter determines how many such patterns we want. Of course, there is a limitation based on num_local_blocks and num_global_blocks.
        """
        super().__init__(num_heads, block, different_layout_per_head)

        self.num_local_blocks = num_local_blocks

        if (num_local_blocks % num_global_blocks != 0):
            raise ValueError(
                f'Number of blocks in a local window, {num_local_blocks}, must be divisible by the number of global blocks, {num_global_blocks}!'
            )
        self.num_global_blocks = num_global_blocks

        if (attention != 'unidirectional' and attention != 'bidirectional'):
            raise NotImplementedError(
                'only "uni/bi-directional" attentions are supported for now!')
        self.attention = attention

        if (attention != 'bidirectional' and horizontal_global_attention):
            raise ValueError(
                'only "bi-directional" attentions can support horizontal global attention!'
            )
        self.horizontal_global_attention = horizontal_global_attention

        if (num_different_global_patterns > 1 and not different_layout_per_head):
            raise ValueError(
                'Number of different layouts cannot be more than one when you have set a single layout for all heads! Set different_layout_per_head to True.'
            )
        if (num_different_global_patterns > (num_local_blocks // num_global_blocks)):
            raise ValueError(
                f'Number of layout versions (num_different_global_patterns), {num_different_global_patterns}, cannot be larger than the number of local window blocks divided by the number of global blocks, {num_local_blocks} / {num_global_blocks} = {num_local_blocks // num_global_blocks}!'
            )
        self.num_different_global_patterns = num_different_global_patterns

    def set_local_layout(self, h, layout):
        """Sets the local attention layout used by the given head in the sparse attention.
        Arguments:
            h: required: an integer determining the head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads in which the local layout is set
        """
        num_blocks = layout.shape[1]
        for i in range(0, num_blocks, self.num_local_blocks):
            end = min(i + self.num_local_blocks, num_blocks)
            for row in range(i, end):
                for col in range(
                        i,
                        (row + 1 if self.attention == 'unidirectional' else end)):
                    layout[h, row, col] = 1
        return layout

    def set_global_layout(self, h, layout):
        """Sets the global attention layout used by the given head in the sparse attention.
        Currently we set global blocks starting from the last block of a local window down to the first one. That means if a local window consists of 4 blocks and the global attention size is one block, we use block #4 in each local window as global. If we have a different layout per head, other heads will get #3, #2, and #1. And if we have more heads (with different layouts set) than the number of global attention patterns, multiple heads may have the same global attention.
        Note) if horizontal_global_attention is set, global blocks will be set both horizontally and vertically.
        Arguments:
            h: required: an integer determining the head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads in which the global layout is set
        """
        num_blocks = layout.shape[1]
        first_global_block_idx = self.num_local_blocks - (
            1 + h % self.num_different_global_patterns) * self.num_global_blocks

        # set global blocks in all complete local windows; a possibly short last local window is handled below
        end = num_blocks - (num_blocks % self.num_local_blocks)
        for i in range(first_global_block_idx, end, self.num_local_blocks):
            # vertical global attention
            first_row = 0 if self.attention == 'bidirectional' else i
            #(((i // self.num_local_blocks) + 1) * self.num_local_blocks)
            #if (first_row < num_blocks):
            layout[h, first_row:, i:i + self.num_global_blocks] = 1

            # horizontal global attention; only in bidirectional attention
            if (self.horizontal_global_attention):
                layout[h, i:i + self.num_global_blocks, :] = 1

        # set the last global blocks; handle a possibly short last local window
        if (end < num_blocks):
            start = min(end + first_global_block_idx,
                        num_blocks - self.num_global_blocks)
            end = start + self.num_global_blocks

            # vertical global attention
            first_row = 0 if self.attention == 'bidirectional' else start
            #(((start // self.num_local_blocks) + 1) * self.num_local_blocks)
            #if (first_row < num_blocks):
            layout[h, first_row:, start:end] = 1

            # horizontal global attention
            if (self.horizontal_global_attention):
                layout[h, start:end, :] = 1
        return layout

    def make_layout(self, seq_len):
        """Generates the `Fixed` sparsity layout used by each head in the sparse attention.
        Arguments:
            seq_len: required: an integer determining the underlying sequence length; must be divisible by the block size
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the `Fixed` sparsity layout of all heads
        """
        layout = self.setup_layout(seq_len)
        for h in range(0, self.num_layout_heads):
            layout = self.set_local_layout(h, layout)
            layout = self.set_global_layout(h, layout)

        layout = self.check_and_propagate_first_head_layout(layout)
        return layout
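
# Example usage (a minimal sketch; the 4 heads, 1024-token sequence, and
# unidirectional attention are illustrative choices):
#
#   fixed = FixedSparsityConfig(num_heads=4,
#                               block=16,
#                               num_local_blocks=4,
#                               num_global_blocks=1,
#                               attention='unidirectional')
#   layout = fixed.make_layout(seq_len=1024)   # torch.Size([4, 64, 64])
#   # Each 4-block local window is causally connected, and the last block of
#   # every window additionally acts as a global (vertical) column for the
#   # rows at or below it.
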

class VariableSparsityConfig(SparsityConfig):
    """Configuration class to store the `Variable` sparsity configuration.
    This layout is an extension of FixedSparsityConfig in which:
        - the user can set a random layout; the default value is zero, meaning no random blocks
        - the user can provide a list of local block sizes
        - the user can provide a list of global block indices.
    For more details about the `Fixed` sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
    This class extends the parent class `SparsityConfig` and customizes it for `Variable` sparsity.
    """
    def __init__(self,
                 num_heads,
                 block=16,
                 different_layout_per_head=False,
                 num_random_blocks=0,
                 local_window_blocks=[4],
                 global_block_indices=[0],
                 global_block_end_indices=None,
                 attention='bidirectional',
                 horizontal_global_attention=False):
        """Initialize the `Variable` Sparsity Pattern Config.
        For a usage example, please see: TODO DeepSpeed Sparse Transformer Tutorial
        Arguments:
            num_heads: required: an integer determining the number of attention heads of the layer.
            block: optional: an integer determining the block size. The current implementation of sparse self-attention is based on blocked sparse matrices, and this parameter defines the size of such blocks, `Block X Block`.
            different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false. Currently this sparsity config can only assign a single layout to all heads; it needs to be extended to support a different layout per head.
            num_random_blocks: optional: an integer determining the number of random blocks in each block row.
            local_window_blocks: optional: a list of integers determining the number of blocks in each local attention window. The first number determines the number of blocks in the first local window, the second the second window, ..., and the last number determines the number of blocks in the remaining local windows.
            global_block_indices: optional: a list of integers determining which blocks are considered global attention. The given indices determine the blocks that all other token blocks attend to and that attend to all other token blocks. The default value is only index 0. Note that if the global_block_end_indices parameter is set, this parameter is used as the starting index of each global window.
            global_block_end_indices: optional: a list of integers determining the end indices of the global window blocks. By default this is not used. If it is set, it must have the same size as the global_block_indices parameter; combining these two parameters, for each index i, blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are considered global attention.
            attention: optional: a string determining the attention type. Attention can be `unidirectional`, as in autoregressive models, in which tokens attend only to tokens that appear before them in the context; in that case, the upper triangular part of the attention matrix is empty. Or it can be `bidirectional`, as in BERT, in which tokens can attend to any other tokens before or after them; then, the upper triangular part of the attention matrix mirrors the lower triangular part.
            horizontal_global_attention: optional: a boolean determining if blocks that are the global representative of a local window also attend to all other blocks. This is valid only if the attention type is `bidirectional`. Looking at the attention matrix, it means global attention includes not only the vertical blocks but also the horizontal blocks.
        """
        super().__init__(num_heads, block, different_layout_per_head)

        self.num_random_blocks = num_random_blocks
        self.local_window_blocks = local_window_blocks
        self.global_block_indices = global_block_indices

        if (global_block_end_indices is not None):
            if (len(global_block_indices) != len(global_block_end_indices)):
                raise ValueError(
                    f'Global block start indices length, {len(global_block_indices)}, must be same as global block end indices length, {len(global_block_end_indices)}!'
                )
            for start_idx, end_idx in zip(global_block_indices, global_block_end_indices):
                if start_idx >= end_idx:
                    raise ValueError(
                        f'Global block start index, {start_idx}, must be smaller than global block end index, {end_idx}!'
                    )
        self.global_block_end_indices = global_block_end_indices

        if (attention != 'unidirectional' and attention != 'bidirectional'):
            raise NotImplementedError(
                'only "uni/bi-directional" attentions are supported for now!')
        self.attention = attention

        if (attention != 'bidirectional' and horizontal_global_attention):
            raise ValueError(
                'only "bi-directional" attentions can support horizontal global attention!'
            )
        self.horizontal_global_attention = horizontal_global_attention

    def set_random_layout(self, h, layout):
        """Sets the random attention layout used by the given head in the sparse attention.
        Note) By default, it assumes there will be a unique random block layout for all heads; unless the `different_layout_per_head` parameter is set, in which case each head can have a different random layout.
        Arguments:
            h: required: an integer determining the head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads in which the random layout is set
        """
        num_blocks = layout.shape[1]
        if (num_blocks < self.num_random_blocks):
            raise ValueError(
                f'Number of random blocks, {self.num_random_blocks}, must be smaller than the overall number of blocks in a row, {num_blocks}!'
            )
        for row in range(0, num_blocks):
            rnd_cols = random.sample(range(0, num_blocks), self.num_random_blocks)
            layout[h, row, rnd_cols] = 1
        return layout

    def set_local_layout(self, h, layout):
        """Sets the local attention layout used by the given head in the sparse attention.
        Arguments:
            h: required: an integer determining the head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads in which the local layout is set
        """
        num_blocks = layout.shape[1]
        start_block_idx = 0
        end_block_idx = 0
        for block_size in self.local_window_blocks:
            end_block_idx += block_size
            end_block_idx = min(end_block_idx, num_blocks)
            for row in range(start_block_idx, end_block_idx):
                for col in range(
                        start_block_idx,
                        (row + 1 if self.attention == 'unidirectional' else end_block_idx)):
                    layout[h, row, col] = 1
            start_block_idx += block_size

        # if there is any remaining unattended part, use the last local window block size as the local window for the remaining applicable local windows
        for i in range(start_block_idx, num_blocks, block_size):
            end_block_idx = min(i + block_size, num_blocks)
            for row in range(i, end_block_idx):
                for col in range(
                        i,
                        (row + 1 if self.attention == 'unidirectional' else end_block_idx)):
                    layout[h, row, col] = 1
        return layout

    def set_global_layout(self, h, layout):
        """Sets the global attention layout used by the given head in the sparse attention.
        Arguments:
            h: required: an integer determining the head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads in which the global layout is set
        """
        num_blocks = layout.shape[1]
        if (self.global_block_end_indices is None):
            for idx in self.global_block_indices:
                # if the global block idx is within the range of the sequence blocks
                if (idx < num_blocks):
                    # global rows
                    if (self.horizontal_global_attention):
                        layout[h, idx, :] = 1

                    # global columns
                    first_row = 0 if self.attention == 'bidirectional' else idx
                    layout[h, first_row:, idx] = 1
        else:
            for start_idx, end_idx in zip(self.global_block_indices, self.global_block_end_indices):
                # if the global block idx is within the range of the sequence blocks
                if (start_idx < num_blocks):
                    end_idx = min(end_idx, num_blocks)

                    # global rows
                    if (self.horizontal_global_attention):
                        layout[h, start_idx:end_idx, :] = 1

                    # global columns
                    first_row = 0 if self.attention == 'bidirectional' else start_idx
                    layout[h, first_row:, start_idx:end_idx] = 1
        return layout

    def make_layout(self, seq_len):
        """Generates the `Variable` sparsity layout used by each head in the sparse attention.
        Arguments:
            seq_len: required: an integer determining the underlying sequence length; must be divisible by the block size
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the `Variable` sparsity layout of all heads
        """
        layout = self.setup_layout(seq_len)
        for h in range(0, self.num_layout_heads):
            layout = self.set_random_layout(h, layout)
            layout = self.set_local_layout(h, layout)
            layout = self.set_global_layout(h, layout)

        layout = self.check_and_propagate_first_head_layout(layout)
        return layout
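
# Example usage (a minimal sketch; the window sizes, global index, and
# 512-token sequence below are illustrative):
#
#   variable = VariableSparsityConfig(num_heads=4,
#                                     block=16,
#                                     num_random_blocks=1,
#                                     local_window_blocks=[2, 4],
#                                     global_block_indices=[0],
#                                     attention='bidirectional')
#   layout = variable.make_layout(seq_len=512)   # torch.Size([4, 32, 32])
#   # The first local window spans 2 blocks, every following window spans 4
#   # blocks, block 0 serves as a global column (all block rows attend to it),
#   # and each block row also gets 1 random block.
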

class BigBirdSparsityConfig(SparsityConfig):
    """Configuration class to store the `BigBird` sparsity configuration.
    For more details about this sparsity config, please see `Big Bird: Transformers for Longer Sequences`: https://arxiv.org/pdf/2007.14062.pdf
    This class extends the parent class `SparsityConfig` and customizes it for `BigBird` sparsity.
    """
    def __init__(self,
                 num_heads,
                 block=16,
                 different_layout_per_head=False,
                 num_random_blocks=1,
                 num_sliding_window_blocks=3,
                 num_global_blocks=1):
        """Initialize the BigBird Sparsity Pattern Config.
        For a usage example, please see: TODO DeepSpeed Sparse Transformer Tutorial
        Arguments:
            num_heads: required: an integer determining the number of attention heads of the layer.
            block: optional: an integer determining the block size. The current implementation of sparse self-attention is based on blocked sparse matrices, and this parameter defines the size of such blocks, `Block X Block`.
            different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false, and this will be satisfied based on availability.
            num_random_blocks: optional: an integer determining the number of random blocks in each block row.
            num_sliding_window_blocks: optional: an integer determining the number of blocks in the sliding local attention window.
            num_global_blocks: optional: an integer determining how many consecutive blocks, starting from index 0, are considered global attention. Global block tokens will be attended to by all other block tokens and will attend to all other block tokens as well.
        """
        super().__init__(num_heads, block, different_layout_per_head)

        self.num_random_blocks = num_random_blocks
        self.num_sliding_window_blocks = num_sliding_window_blocks
        self.num_global_blocks = num_global_blocks

    def set_random_layout(self, h, layout):
        """Sets the random attention layout used by the given head in the sparse attention.
        Note) By default, it assumes there will be a unique random block layout for all heads; unless the `different_layout_per_head` parameter is set, in which case each head can have a different random layout.
        Arguments:
            h: required: an integer determining the head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads in which the random layout is set
        """
        num_blocks = layout.shape[1]
        if (num_blocks < self.num_random_blocks):
            raise ValueError(
                f'Number of random blocks, {self.num_random_blocks}, must be smaller than the overall number of blocks in a row, {num_blocks}!'
            )
        for row in range(0, num_blocks):
            rnd_cols = random.sample(range(0, num_blocks), self.num_random_blocks)
            layout[h, row, rnd_cols] = 1
        return layout

    def set_sliding_window_layout(self, h, layout):
        """Sets the sliding local attention layout used by the given head in the sparse attention.
        Arguments:
            h: required: an integer determining the head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads in which the local sliding window layout is set
        """
        num_blocks = layout.shape[1]
        if (num_blocks < self.num_sliding_window_blocks):
            raise ValueError(
                f'Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than the overall number of blocks in a row, {num_blocks}!'
            )
        w = self.num_sliding_window_blocks // 2
        for row in range(0, num_blocks):
            start = max(0, row - w)
            end = min(row + w + 1, num_blocks)
            layout[h, row, start:end] = 1
        return layout

    def set_global_layout_itc(self, h, layout):
        """Sets the global attention layout used by the given head in the sparse attention.
        Arguments:
            h: required: an integer determining the head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads in which the global layout is set
        """
        num_blocks = layout.shape[1]
        if (num_blocks < self.num_global_blocks):
            raise ValueError(
                f'Number of global blocks, {self.num_global_blocks}, must be smaller than the overall number of blocks in a row, {num_blocks}!'
            )
        # global rows
        layout[h, 0:self.num_global_blocks, :] = 1

        # global columns
        layout[h, :, 0:self.num_global_blocks] = 1
        return layout

    def make_layout(self, seq_len):
        """Generates the `BigBird` sparsity layout used by each head in the sparse attention.
        Arguments:
            seq_len: required: an integer determining the underlying sequence length; must be divisible by the block size
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the `BigBird` sparsity layout of all heads
        """
        layout = self.setup_layout(seq_len)
        for h in range(0, self.num_layout_heads):
            layout = self.set_random_layout(h, layout)
            layout = self.set_sliding_window_layout(h, layout)
            layout = self.set_global_layout_itc(h, layout)

        layout = self.check_and_propagate_first_head_layout(layout)
        return layout
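
# Example usage (a minimal sketch; the 8 heads and 2048-token sequence are
# illustrative):
#
#   bigbird = BigBirdSparsityConfig(num_heads=8,
#                                   block=16,
#                                   num_random_blocks=1,
#                                   num_sliding_window_blocks=3,
#                                   num_global_blocks=1)
#   layout = bigbird.make_layout(seq_len=2048)   # torch.Size([8, 128, 128])
#   # Each block row attends to 1 random block, a 3-block sliding window
#   # centered on itself, and the first (global) block; the global block's row
#   # is fully dense.
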

class BSLongformerSparsityConfig(SparsityConfig):
    """Configuration class to store an edited `Longformer` sparsity configuration.
    Note) this is a block-sparse version of the Longformer, which is slightly different from the original Longformer, whose sparsity is element-wise.
    For more details about this sparsity config, please see `Longformer: The Long-Document Transformer`: https://arxiv.org/pdf/2004.05150.pdf
    This class extends the parent class `SparsityConfig` and customizes it for `Longformer` sparsity.
    """
    def __init__(self,
                 num_heads,
                 block=16,
                 different_layout_per_head=False,
                 num_sliding_window_blocks=3,
                 global_block_indices=[0],
                 global_block_end_indices=None):
        """Initialize the edited `Longformer` Sparsity Pattern Config.
        For a usage example, please see: TODO DeepSpeed Sparse Transformer Tutorial
        Arguments:
            num_heads: required: an integer determining the number of attention heads of the layer.
            block: optional: an integer determining the block size. The current implementation of sparse self-attention is based on blocked sparse matrices, and this parameter defines the size of such blocks, `Block X Block`.
            different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false, and this will be satisfied based on availability.
            num_sliding_window_blocks: optional: an integer determining the number of blocks in the sliding local attention window.
            global_block_indices: optional: a list of integers determining which blocks are considered global attention. The given indices determine the blocks that all other token blocks attend to and that attend to all other token blocks. The default value is only index 0. Note that if the global_block_end_indices parameter is set, this parameter is used as the starting index of each global window.
            global_block_end_indices: optional: a list of integers determining the end indices of the global window blocks. By default this is not used. If it is set, it must have the same size as the global_block_indices parameter; combining these two parameters, for each index i, blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are considered global attention.
        """
        super().__init__(num_heads, block, different_layout_per_head)

        self.num_sliding_window_blocks = num_sliding_window_blocks
        self.global_block_indices = global_block_indices

        if (global_block_end_indices is not None):
            if (len(global_block_indices) != len(global_block_end_indices)):
                raise ValueError(
                    f'Global block start indices length, {len(global_block_indices)}, must be same as global block end indices length, {len(global_block_end_indices)}!'
                )
            for start_idx, end_idx in zip(global_block_indices, global_block_end_indices):
                if start_idx >= end_idx:
                    raise ValueError(
                        f'Global block start index, {start_idx}, must be smaller than global block end index, {end_idx}!'
                    )
        self.global_block_end_indices = global_block_end_indices

    def set_sliding_window_layout(self, h, layout):
        """Sets the sliding local attention layout used by the given head in the sparse attention.
        Arguments:
            h: required: an integer determining the head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads in which the local sliding window layout is set
        """
        num_blocks = layout.shape[1]
        if (num_blocks < self.num_sliding_window_blocks):
            raise ValueError(
                f'Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than the overall number of blocks in a row, {num_blocks}!'
            )
        w = self.num_sliding_window_blocks // 2
        for row in range(0, num_blocks):
            start = max(0, row - w)
            end = min(row + w + 1, num_blocks)
            layout[h, row, start:end] = 1
        return layout

    def set_global_layout(self, h, layout):
        """Sets the global attention layout used by the given head in the sparse attention.
        Arguments:
            h: required: an integer determining the head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity layout of all heads in which the global layout is set
        """
        num_blocks = layout.shape[1]
        if (self.global_block_end_indices is None):
            for idx in self.global_block_indices:
                # if the global block idx is within the range of the sequence blocks
                if (idx < num_blocks):
                    # global rows
                    layout[h, idx, :] = 1

                    # global columns
                    layout[h, :, idx] = 1
        else:
            for start_idx, end_idx in zip(self.global_block_indices, self.global_block_end_indices):
                # if the global block idx is within the range of the sequence blocks
                if (start_idx < num_blocks):
                    end_idx = min(end_idx, num_blocks)

                    # global rows
                    layout[h, start_idx:end_idx, :] = 1

                    # global columns
                    layout[h, :, start_idx:end_idx] = 1
        return layout

    def make_layout(self, seq_len):
        """Generates the edited `Longformer` sparsity layout used by each head in the sparse attention.
        Arguments:
            seq_len: required: an integer determining the underlying sequence length; must be divisible by the block size
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the `BSLongformer` sparsity layout of all heads
        """
        layout = self.setup_layout(seq_len)
        for h in range(0, self.num_layout_heads):
            layout = self.set_sliding_window_layout(h, layout)
            layout = self.set_global_layout(h, layout)

        layout = self.check_and_propagate_first_head_layout(layout)
        return layout
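
# Example usage (a minimal sketch; global block index 0 and the 1024-token
# sequence are illustrative):
#
#   longformer = BSLongformerSparsityConfig(num_heads=8,
#                                           block=16,
#                                           num_sliding_window_blocks=3,
#                                           global_block_indices=[0])
#   layout = longformer.make_layout(seq_len=1024)   # torch.Size([8, 64, 64])
#   # Every block row covers a 3-block sliding window, and block 0 is global:
#   # its entire row and column in the block layout are set to 1.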