memory_access_utils.h

/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#pragma once

#include <cuda.h>
#include "ds_kernel_utils.h"

/////////////////////////////// Memory Access Utils ///////////////////////////////

namespace mem_access {

enum class LoadPolicy {
    CacheAll,       // Cache at all levels
    CacheGlobal,    // Cache at L2 only
    CacheStreaming  // Cache with evict first policy
};

enum class StorePolicy {
    Writeback,      // Cache in L1, write-back on eviction
    CacheGlobal,    // Bypass L1, write-back on eviction
    CacheStreaming  // Allocate cache line with evict first policy
};

template <int AccessSize, LoadPolicy policy = LoadPolicy::CacheAll>
__device__ __forceinline__ void load_global(void* dst, const void* src);

template <int AccessSize, LoadPolicy policy = LoadPolicy::CacheAll>
__device__ __forceinline__ void load_global(void* dst, const void* src, bool do_access);

// Shared accesses have no cache policy
template <int AccessSize>
__device__ __forceinline__ void load_shared(void* dst, const void* src);

template <int AccessSize>
__device__ __forceinline__ void load_shared(void* dst, const void* src, bool do_access);

template <int AccessSize, StorePolicy policy = StorePolicy::Writeback>
__device__ __forceinline__ void store_global(void* dst, const void* src);

// Shared accesses have no cache policy
template <int AccessSize>
__device__ __forceinline__ void store_shared(void* dst, const void* src);

#ifdef ASYNC_COPY_AVAILABLE
template <int AccessSize>
__device__ __forceinline__ void memcpy_async(void* shr, const void* gbl);

template <int AccessSize>
__device__ __forceinline__ void memcpy_async_nop(void* shr, const void* gbl, bool predicate);

template <int AccessSize>
__device__ __forceinline__ void memcpy_async_zero(void* shr, const void* gbl, bool predicate);

__device__ __forceinline__ void memcpy_async_fence();

template <int stages>
__device__ __forceinline__ void memcpy_async_wait();

template <int stages>
__device__ __forceinline__ void tail_complete_wait(int remaining_stages);
#endif
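
// Usage sketch (illustrative only): a grid-stride copy kernel built on the 16-byte
// load/store wrappers declared above. The kernel name, the CacheGlobal choice, and the
// assumption that `count` is a multiple of four floats are hypothetical, not something
// this header prescribes.
//
//   __global__ void vector_copy(float* dst, const float* src, int count)
//   {
//       // Each thread moves one float4 (16 B) per iteration.
//       for (int i = 4 * (blockIdx.x * blockDim.x + threadIdx.x); i < count;
//            i += 4 * gridDim.x * blockDim.x) {
//           float4 buf;
//           mem_access::load_global<16, mem_access::LoadPolicy::CacheGlobal>(&buf, src + i);
//           mem_access::store_global<16>(dst + i, &buf);
//       }
//   }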

// Util for tracking pipeline buffers
// TODO: Evaluate whether this should also be guarded by ASYNC_COPY_AVAILABLE
template <int max>
class BufferTracker {
public:
    int current_state;

    __device__ __forceinline__ BufferTracker() : current_state(0) {}

    __device__ __forceinline__ int get()
    {
        int return_val = current_state++;
        current_state = (current_state == max ? 0 : current_state);
        return return_val;
    }
};
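
// BufferTracker cycles through buffer indices 0 .. max-1 in round-robin order, which is
// how a multi-stage shared-memory pipeline typically picks its next staging buffer. A
// minimal sketch (STAGES and the loop bound are assumptions for illustration):
//
//   constexpr int STAGES = 2;
//   mem_access::BufferTracker<STAGES> tracker;
//   for (int iter = 0; iter < total_iters; ++iter) {
//       int buf = tracker.get();  // yields 0, 1, 0, 1, ...
//       // fill staging buffer `buf`, consume the previously filled one, ...
//   }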

__device__ __forceinline__ uint32_t lane_id()
{
#ifdef PTX_AVAILABLE
    unsigned int lane_id;
    asm volatile("mov.u32 %0, %%laneid;" : "=r"(lane_id));
    return lane_id;
#else
    return threadIdx.x & (warpSize - 1);  // Portable
#endif
}

/////////// Load Global ///////////
template <>
__device__ __forceinline__ void load_global<16>(void* dst, const void* src)
{
    uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile("ld.global.ca.v4.u32 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
                 : "l"(src));
#else
    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
    data[0] = src_cast[0];
#endif
}

template <>
__device__ __forceinline__ void load_global<16>(void* dst, const void* src, bool do_access)
{
    uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %5, 0;\n"
        "\tmov.b32 %0, 0;\n"
        "\tmov.b32 %1, 0;\n"
        "\tmov.b32 %2, 0;\n"
        "\tmov.b32 %3, 0;\n"
        "\t@p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
        "}\n"
        : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
        : "l"(src), "r"((int)do_access));
#else
    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0].x = 0;
        data[0].y = 0;
        data[0].z = 0;
        data[0].w = 0;
    }
#endif
}

template <>
__device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* dst, const void* src)
{
    uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile("ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
                 : "l"(src));
#else
    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
    data[0] = src_cast[0];
#endif
}

template <>
__device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* dst,
                                                                         const void* src,
                                                                         bool do_access)
{
    uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %5, 0;\n"
        "\tmov.b32 %0, 0;\n"
        "\tmov.b32 %1, 0;\n"
        "\tmov.b32 %2, 0;\n"
        "\tmov.b32 %3, 0;\n"
        "\t@p ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n"
        "}\n"
        : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
        : "l"(src), "r"((int)do_access));
#else
    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0].x = 0;
        data[0].y = 0;
        data[0].z = 0;
        data[0].w = 0;
    }
#endif
}

template <>
__device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void* dst,
                                                                            const void* src)
{
    uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
                 : "l"(src));
#else
    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
    data[0] = src_cast[0];
#endif
}

template <>
__device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void* dst,
                                                                            const void* src,
                                                                            bool do_access)
{
    uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %5, 0;\n"
        "\tmov.b32 %0, 0;\n"
        "\tmov.b32 %1, 0;\n"
        "\tmov.b32 %2, 0;\n"
        "\tmov.b32 %3, 0;\n"
  186. "\t@p ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n"
  187. "}\n"
  188. : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
  189. : "l"(src), "r"((int)do_access));
  190. #else
  191. const uint4* src_cast = reinterpret_cast<const uint4*>(src);
  192. if (do_access) {
  193. data[0] = src_cast[0];
  194. } else {
  195. data[0].x = 0;
  196. data[0].y = 0;
  197. data[0].z = 0;
  198. data[0].w = 0;
  199. }
  200. #endif
  201. }
  202. template <>
  203. __device__ __forceinline__ void load_global<8>(void* dst, const void* src)
  204. {
  205. uint2* data = reinterpret_cast<uint2*>(dst);
  206. #ifdef PTX_AVAILABLE
  207. asm volatile("ld.global.ca.v2.u32 {%0, %1}, [%2];\n"
  208. : "=r"(data[0].x), "=r"(data[0].y)
  209. : "l"(src));
  210. #else
  211. const uint2* src_cast = reinterpret_cast<const uint2*>(src);
  212. data[0] = src_cast[0];
  213. #endif
  214. }
  215. template <>
  216. __device__ __forceinline__ void load_global<8>(void* dst, const void* src, bool do_access)
  217. {
  218. uint2* data = reinterpret_cast<uint2*>(dst);
  219. #ifdef PTX_AVAILABLE
  220. asm volatile(
  221. "{\n"
  222. "\t.reg .pred p;\n"
  223. "\tsetp.ne.b32 p, %3, 0;\n"
  224. "\tmov.b32 %0, 0;\n"
  225. "\tmov.b32 %1, 0;\n"
  226. "\t@p ld.global.v2.u32 {%0, %1}, [%2];\n"
  227. "}\n"
  228. : "=r"(data[0].x), "=r"(data[0].y)
  229. : "l"(src), "r"((int)do_access));
  230. #else
  231. const uint2* src_cast = reinterpret_cast<const uint2*>(src);
  232. if (do_access) {
  233. data[0] = src_cast[0];
  234. } else {
  235. data[0].x = 0;
  236. data[0].y = 0;
  237. }
  238. #endif
  239. }
  240. template <>
  241. __device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* dst, const void* src)
  242. {
  243. uint2* data = reinterpret_cast<uint2*>(dst);
  244. #ifdef PTX_AVAILABLE
  245. asm volatile("ld.global.cg.v2.u32 {%0, %1}, [%2];\n"
  246. : "=r"(data[0].x), "=r"(data[0].y)
  247. : "l"(src));
  248. #else
  249. const uint2* src_cast = reinterpret_cast<const uint2*>(src);
  250. data[0] = src_cast[0];
  251. #endif
  252. }
  253. template <>
  254. __device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* dst,
  255. const void* src,
  256. bool do_access)
  257. {
  258. uint2* data = reinterpret_cast<uint2*>(dst);
  259. #ifdef PTX_AVAILABLE
  260. asm volatile(
  261. "{\n"
  262. "\t.reg .pred p;\n"
  263. "\tsetp.ne.b32 p, %3, 0;\n"
  264. "\tmov.b32 %0, 0;\n"
  265. "\tmov.b32 %1, 0;\n"
  266. "\t@p ld.global.cg.v2.u32 {%0, %1}, [%2];\n"
  267. "}\n"
  268. : "=r"(data[0].x), "=r"(data[0].y)
  269. : "l"(src), "r"((int)do_access));
  270. #else
  271. const uint2* src_cast = reinterpret_cast<const uint2*>(src);
  272. if (do_access) {
  273. data[0] = src_cast[0];
  274. } else {
  275. data[0].x = 0;
  276. data[0].y = 0;
  277. }
  278. #endif
  279. }
  280. template <>
  281. __device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void* dst,
  282. const void* src)
  283. {
  284. uint2* data = reinterpret_cast<uint2*>(dst);
  285. #ifdef PTX_AVAILABLE
  286. asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n"
  287. : "=r"(data[0].x), "=r"(data[0].y)
  288. : "l"(src));
  289. #else
  290. const uint2* src_cast = reinterpret_cast<const uint2*>(src);
  291. data[0] = src_cast[0];
  292. #endif
  293. }
  294. template <>
  295. __device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void* dst,
  296. const void* src,
  297. bool do_access)
  298. {
  299. uint2* data = reinterpret_cast<uint2*>(dst);
  300. #ifdef PTX_AVAILABLE
  301. asm volatile(
  302. "{\n"
  303. "\t.reg .pred p;\n"
  304. "\tsetp.ne.b32 p, %3, 0;\n"
  305. "\tmov.b32 %0, 0;\n"
  306. "\tmov.b32 %1, 0;\n"
  307. "\t@p ld.global.cs.v2.u32 {%0, %1}, [%2];\n"
  308. "}\n"
  309. : "=r"(data[0].x), "=r"(data[0].y)
  310. : "l"(src), "r"((int)do_access));
  311. #else
  312. const uint2* src_cast = reinterpret_cast<const uint2*>(src);
  313. if (do_access) {
  314. data[0] = src_cast[0];
  315. } else {
  316. data[0].x = 0;
  317. data[0].y = 0;
  318. }
  319. #endif
  320. }
  321. template <>
  322. __device__ __forceinline__ void load_global<4>(void* dst, const void* src)
  323. {
  324. int32_t* data = reinterpret_cast<int32_t*>(dst);
  325. #ifdef PTX_AVAILABLE
  326. asm volatile("ld.global.ca.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src));
  327. #else
  328. const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
  329. data[0] = src_cast[0];
  330. #endif
  331. }
  332. template <>
  333. __device__ __forceinline__ void load_global<4>(void* dst, const void* src, bool do_access)
  334. {
  335. int32_t* data = reinterpret_cast<int32_t*>(dst);
  336. #ifdef PTX_AVAILABLE
  337. asm volatile(
  338. "{\n"
  339. "\t.reg .pred p;\n"
  340. "\tsetp.ne.b32 p, %2, 0;\n"
  341. "\tmov.b32 %0, 0;\n"
  342. "\t@p ld.global.u32 {%0}, [%1];\n"
  343. "}\n"
  344. : "=r"(data[0])
  345. : "l"(src), "r"((int)do_access));
  346. #else
  347. const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
  348. if (do_access) {
  349. data[0] = src_cast[0];
  350. } else {
  351. data[0] = 0;
  352. }
  353. #endif
  354. }
  355. template <>
  356. __device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* dst, const void* src)
  357. {
  358. int32_t* data = reinterpret_cast<int32_t*>(dst);
  359. #ifdef PTX_AVAILABLE
  360. asm volatile("ld.global.cg.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src));
  361. #else
  362. const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
  363. data[0] = src_cast[0];
  364. #endif
  365. }
  366. template <>
  367. __device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* dst,
  368. const void* src,
  369. bool do_access)
  370. {
  371. int32_t* data = reinterpret_cast<int32_t*>(dst);
  372. #ifdef PTX_AVAILABLE
  373. asm volatile(
  374. "{\n"
  375. "\t.reg .pred p;\n"
  376. "\tsetp.ne.b32 p, %2, 0;\n"
  377. "\tmov.b32 %0, 0;\n"
  378. "\t@p ld.global.cg.u32 {%0}, [%1];\n"
  379. "}\n"
  380. : "=r"(data[0])
  381. : "l"(src), "r"((int)do_access));
  382. #else
  383. const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
  384. if (do_access) {
  385. data[0] = src_cast[0];
  386. } else {
  387. data[0] = 0;
  388. }
  389. #endif
  390. }
  391. template <>
  392. __device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void* dst,
  393. const void* src)
  394. {
  395. int32_t* data = reinterpret_cast<int32_t*>(dst);
  396. #ifdef PTX_AVAILABLE
  397. asm volatile("ld.global.cs.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src));
  398. #else
  399. const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
  400. data[0] = src_cast[0];
  401. #endif
  402. }
  403. template <>
  404. __device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void* dst,
  405. const void* src,
  406. bool do_access)
  407. {
  408. int32_t* data = reinterpret_cast<int32_t*>(dst);
  409. #ifdef PTX_AVAILABLE
  410. asm volatile(
  411. "{\n"
  412. "\t.reg .pred p;\n"
  413. "\tsetp.ne.b32 p, %2, 0;\n"
  414. "\tmov.b32 %0, 0;\n"
  415. "\t@p ld.global.cs.u32 {%0}, [%1];\n"
  416. "}\n"
  417. : "=r"(data[0])
  418. : "l"(src), "r"((int)do_access));
  419. #else
  420. const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
  421. if (do_access) {
  422. data[0] = src_cast[0];
  423. } else {
  424. data[0] = 0;
  425. }
  426. #endif
  427. }
  428. /////////// Load Shared ///////////
  429. namespace internal {
  430. #ifdef PTX_AVAILABLE
  431. __device__ __forceinline__ unsigned convert_to_shared(const void* ptr)
  432. {
  433. #if __CUDACC_VER_MAJOR__ >= 11
  434. // In CUDA 11 we have a builtin intrinsic
  435. return __cvta_generic_to_shared(ptr);
  436. #else
  437. unsigned ret_val;
  438. asm volatile(
  439. "{\n"
  440. "\t.reg .u64 p1;\n"
  441. "\tcvta.to.shared.u64 p1, %1\n"
  442. "\tcvt.u32.u64 %0, p1;\n"
  443. "}\n"
  444. : "=r"(ret_val)
  445. : "l"(ptr));
  446. return ret_val;
  447. #endif
  448. }
  449. #endif
  450. } // namespace internal
  451. template <>
  452. __device__ __forceinline__ void load_shared<16>(void* dst, const void* src)
  453. {
  454. uint4* data = reinterpret_cast<uint4*>(dst);
  455. #ifdef PTX_AVAILABLE
  456. unsigned src_shr = internal::convert_to_shared(src);
  457. asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n"
  458. : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
  459. : "r"(src_shr));
  460. #else
  461. const uint4* src_cast = reinterpret_cast<const uint4*>(src);
  462. data[0] = src_cast[0];
  463. #endif
  464. }
  465. template <>
  466. __device__ __forceinline__ void load_shared<16>(void* dst, const void* src, bool do_access)
  467. {
  468. uint4* data = reinterpret_cast<uint4*>(dst);
  469. #ifdef PTX_AVAILABLE
  470. unsigned src_shr = internal::convert_to_shared(src);
  471. asm volatile(
  472. "{\n"
  473. "\t.reg .pred p;\n"
  474. "\tsetp.ne.b32 p, %5, 0;\n"
  475. "\tmov.b32 %0, 0;\n"
  476. "\tmov.b32 %1, 0;\n"
  477. "\tmov.b32 %2, 0;\n"
  478. "\tmov.b32 %3, 0;\n"
  479. "\t@p ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n"
  480. "}\n"
  481. : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
  482. : "r"(src_shr), "r"((int)do_access));
  483. #else
  484. const uint4* src_cast = reinterpret_cast<const uint4*>(src);
  485. if (do_access) {
  486. data[0] = src_cast[0];
  487. } else {
  488. data[0].x = 0;
  489. data[0].y = 0;
  490. data[0].z = 0;
  491. data[0].w = 0;
  492. }
  493. #endif
  494. }
  495. template <>
  496. __device__ __forceinline__ void load_shared<8>(void* dst, const void* src)
  497. {
  498. uint2* data = reinterpret_cast<uint2*>(dst);
  499. #ifdef PTX_AVAILABLE
  500. unsigned src_shr = internal::convert_to_shared(src);
  501. asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];\n"
  502. : "=r"(data[0].x), "=r"(data[0].y)
  503. : "r"(src_shr));
  504. #else
  505. const uint2* src_cast = reinterpret_cast<const uint2*>(src);
  506. data[0] = src_cast[0];
  507. #endif
  508. }
  509. template <>
  510. __device__ __forceinline__ void load_shared<8>(void* dst, const void* src, bool do_access)
  511. {
  512. uint2* data = reinterpret_cast<uint2*>(dst);
  513. #ifdef PTX_AVAILABLE
  514. unsigned src_shr = internal::convert_to_shared(src);
  515. asm volatile(
  516. "{\n"
  517. "\t.reg .pred p;\n"
  518. "\tsetp.ne.b32 p, %3, 0;\n"
  519. "\tmov.b32 %0, 0;\n"
  520. "\tmov.b32 %1, 0;\n"
  521. "\t@p ld.shared.v2.u32 {%0, %1}, [%2];\n"
  522. "}\n"
  523. : "=r"(data[0].x), "=r"(data[0].y)
  524. : "r"(src_shr), "r"((int)do_access));
  525. #else
  526. const uint2* src_cast = reinterpret_cast<const uint2*>(src);
  527. if (do_access) {
  528. data[0] = src_cast[0];
  529. } else {
  530. data[0].x = 0;
  531. data[0].y = 0;
  532. }
  533. #endif
  534. }
  535. template <>
  536. __device__ __forceinline__ void load_shared<4>(void* dst, const void* src)
  537. {
  538. int32_t* data = reinterpret_cast<int32_t*>(dst);
  539. #ifdef PTX_AVAILABLE
  540. unsigned src_shr = internal::convert_to_shared(src);
  541. asm volatile("ld.shared.u32 {%0}, [%1];\n" : "=r"(*data) : "r"(src_shr));
  542. #else
  543. const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
  544. data[0] = src_cast[0];
  545. #endif
  546. }
  547. template <>
  548. __device__ __forceinline__ void load_shared<4>(void* dst, const void* src, bool do_access)
  549. {
  550. int32_t* data = reinterpret_cast<int32_t*>(dst);
  551. #ifdef PTX_AVAILABLE
  552. unsigned src_shr = internal::convert_to_shared(src);
  553. asm volatile(
  554. "{\n"
  555. "\t.reg .pred p;\n"
  556. "\tsetp.ne.b32 p, %2, 0;\n"
  557. "\tmov.b32 %0, 0;\n"
  558. "\t@p ld.shared.u32 %0, [%1];\n"
  559. "}\n"
  560. : "=r"(data[0])
  561. : "r"(src_shr), "r"((int)do_access));
  562. #else
  563. const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
  564. if (do_access) {
  565. data[0] = src_cast[0];
  566. } else {
  567. data[0] = 0;
  568. }
  569. #endif
  570. }
  571. /////////// Store Global ///////////
  572. template <>
  573. __device__ __forceinline__ void store_global<16>(void* dst, const void* src)
  574. {
  575. const uint4* data = reinterpret_cast<const uint4*>(src);
  576. #ifdef PTX_AVAILABLE
  577. asm volatile("st.global.wb.v4.u32 [%0], {%1, %2, %3, %4};\n"
  578. :
  579. : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w)
  580. : "memory");
  581. #else
  582. uint4* dst_cast = reinterpret_cast<uint4*>(dst);
  583. dst_cast[0] = data[0];
  584. #endif
  585. }
  586. template <>
  587. __device__ __forceinline__ void store_global<16, StorePolicy::CacheGlobal>(void* dst,
  588. const void* src)
  589. {
  590. const uint4* data = reinterpret_cast<const uint4*>(src);
  591. #ifdef PTX_AVAILABLE
  592. asm volatile("st.global.cg.v4.u32 [%0], {%1, %2, %3, %4};\n"
  593. :
  594. : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w)
  595. : "memory");
  596. #else
  597. uint4* dst_cast = reinterpret_cast<uint4*>(dst);
  598. dst_cast[0] = data[0];
  599. #endif
  600. }
  601. template <>
  602. __device__ __forceinline__ void store_global<16, StorePolicy::CacheStreaming>(void* dst,
  603. const void* src)
  604. {
  605. const uint4* data = reinterpret_cast<const uint4*>(src);
  606. #ifdef PTX_AVAILABLE
  607. asm volatile("st.global.cs.v4.u32 [%0], {%1, %2, %3, %4};\n"
  608. :
  609. : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w)
  610. : "memory");
  611. #else
  612. uint4* dst_cast = reinterpret_cast<uint4*>(dst);
  613. dst_cast[0] = data[0];
  614. #endif
  615. }
  616. template <>
  617. __device__ __forceinline__ void store_global<8>(void* dst, const void* src)
  618. {
  619. const uint2* data = reinterpret_cast<const uint2*>(src);
  620. #ifdef PTX_AVAILABLE
  621. asm volatile("st.global.wb.v2.u32 [%0], {%1, %2};\n"
  622. :
  623. : "l"(dst), "r"(data[0].x), "r"(data[0].y));
  624. #else
  625. uint2* dst_cast = reinterpret_cast<uint2*>(dst);
  626. dst_cast[0] = data[0];
  627. #endif
  628. }
  629. template <>
  630. __device__ __forceinline__ void store_global<8, StorePolicy::CacheGlobal>(void* dst,
  631. const void* src)
  632. {
  633. const uint2* data = reinterpret_cast<const uint2*>(src);
  634. #ifdef PTX_AVAILABLE
  635. asm volatile("st.global.cg.v2.u32 [%0], {%1, %2};\n"
  636. :
  637. : "l"(dst), "r"(data[0].x), "r"(data[0].y));
  638. #else
  639. uint2* dst_cast = reinterpret_cast<uint2*>(dst);
  640. dst_cast[0] = data[0];
  641. #endif
  642. }
  643. template <>
  644. __device__ __forceinline__ void store_global<8, StorePolicy::CacheStreaming>(void* dst,
  645. const void* src)
  646. {
  647. const uint2* data = reinterpret_cast<const uint2*>(src);
  648. #ifdef PTX_AVAILABLE
  649. asm volatile("st.global.cs.v2.u32 [%0], {%1, %2};\n"
  650. :
  651. : "l"(dst), "r"(data[0].x), "r"(data[0].y));
  652. #else
  653. uint2* dst_cast = reinterpret_cast<uint2*>(dst);
  654. dst_cast[0] = data[0];
  655. #endif
  656. }
  657. template <>
  658. __device__ __forceinline__ void store_global<4>(void* dst, const void* src)
  659. {
  660. const int32_t* data = reinterpret_cast<const int32_t*>(src);
  661. #ifdef PTX_AVAILABLE
  662. asm volatile("st.global.wb.u32 [%0], %1;\n" : : "l"(dst), "r"(*data));
  663. #else
  664. int32_t* dst_cast = reinterpret_cast<int32_t*>(dst);
  665. dst_cast[0] = data[0];
  666. #endif
  667. }
  668. template <>
  669. __device__ __forceinline__ void store_global<4, StorePolicy::CacheGlobal>(void* dst,
  670. const void* src)
  671. {
  672. const int32_t* data = reinterpret_cast<const int32_t*>(src);
  673. #ifdef PTX_AVAILABLE
  674. asm volatile("st.global.cg.u32 [%0], %1;\n" : : "l"(dst), "r"(*data));
  675. #else
  676. int32_t* dst_cast = reinterpret_cast<int32_t*>(dst);
  677. dst_cast[0] = data[0];
  678. #endif
  679. }
  680. template <>
  681. __device__ __forceinline__ void store_global<4, StorePolicy::CacheStreaming>(void* dst,
  682. const void* src)
  683. {
  684. const int32_t* data = reinterpret_cast<const int32_t*>(src);
  685. #ifdef PTX_AVAILABLE
  686. asm volatile("st.global.cs.u32 [%0], %1;\n" : : "l"(dst), "r"(*data));
  687. #else
  688. int32_t* dst_cast = reinterpret_cast<int32_t*>(dst);
  689. dst_cast[0] = data[0];
  690. #endif
  691. }
  692. /////////// Store Shared ///////////
  693. template <>
  694. __device__ __forceinline__ void store_shared<16>(void* dst, const void* src)
  695. {
  696. const uint4* data = reinterpret_cast<const uint4*>(src);
  697. #ifdef PTX_AVAILABLE
  698. unsigned dst_int = internal::convert_to_shared(dst);
  699. asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n"
  700. :
  701. : "r"(dst_int), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w));
  702. #else
  703. uint4* dst_cast = reinterpret_cast<uint4*>(dst);
  704. dst_cast[0] = data[0];
  705. #endif
  706. }
  707. template <>
  708. __device__ __forceinline__ void store_shared<8>(void* dst, const void* src)
  709. {
  710. const uint2* data = reinterpret_cast<const uint2*>(src);
  711. #ifdef PTX_AVAILABLE
  712. unsigned dst_int = internal::convert_to_shared(dst);
  713. asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n"
  714. :
  715. : "r"(dst_int), "r"(data[0].x), "r"(data[0].y));
  716. #else
  717. uint2* dst_cast = reinterpret_cast<uint2*>(dst);
  718. dst_cast[0] = data[0];
  719. #endif
  720. }
  721. template <>
  722. __device__ __forceinline__ void store_shared<4>(void* dst, const void* src)
  723. {
  724. const int32_t* data = reinterpret_cast<const int32_t*>(src);
  725. #ifdef PTX_AVAILABLE
  726. unsigned dst_int = internal::convert_to_shared(dst);
  727. asm volatile("st.shared.u32 [%0], %1;\n" : : "r"(dst_int), "r"(*data));
  728. #else
  729. int32_t* dst_cast = reinterpret_cast<int32_t*>(dst);
  730. dst_cast[0] = data[0];
  731. #endif
  732. }
  733. /////////// Asynchronous Memory Copy ///////////
  734. #ifdef ASYNC_COPY_AVAILABLE
  735. template <int AccessSize>
  736. __device__ __forceinline__ void memcpy_async(void* shr, const void* gbl)
  737. {
  738. static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16));
  739. unsigned shr_int = internal::convert_to_shared(shr);
  740. asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n"
  741. :
  742. : "r"(shr_int), "l"(gbl), "n"(AccessSize));
  743. }
  744. template <int AccessSize>
  745. __device__ __forceinline__ void memcpy_async_nop(void* shr, const void* gbl, bool predicate)
  746. {
  747. static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16));
  748. unsigned shr_int = internal::convert_to_shared(shr);
  749. asm volatile(
  750. "{\n"
  751. " .reg .pred p;\n"
  752. " setp.ne.b32 p, %0, 0;\n"
  753. " @p cp.async.ca.shared.global [%1], [%2], %3;\n"
  754. "}\n"
  755. :
  756. : "r"((int)predicate), "r"(shr_int), "l"(gbl), "n"(AccessSize));
  757. }
  758. template <int AccessSize>
  759. __device__ __forceinline__ void memcpy_async_zero(void* shr, const void* gbl, bool predicate)
  760. {
  761. static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16));
  762. unsigned shr_int = internal::convert_to_shared(shr);
  763. int bytes_to_copy = (predicate ? AccessSize : 0);
  764. asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n"
  765. :
  766. : "r"(shr_int), "l"(gbl), "n"(AccessSize), "r"(bytes_to_copy));
  767. }
  768. template <int AccessSize>
  769. __device__ __forceinline__ void memcpy_async_zero_nop(void* shr,
  770. const void* gbl,
  771. bool zero_predicate,
  772. bool nop_predicate)
  773. {
  774. static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16));
  775. unsigned shr_int = internal::convert_to_shared(shr);
  776. int bytes_to_copy = (zero_predicate ? AccessSize : 0);
  777. asm volatile(
  778. "{\n"
  779. " .reg .pred p;\n"
  780. " setp.ne.b32 p, %0, 0;\n"
  781. " @p cp.async.ca.shared.global [%1], [%2], %3, %4;\n"
  782. "}\n"
  783. :
  784. : "r"((int)nop_predicate), "r"(shr_int), "l"(gbl), "n"(AccessSize), "r"(bytes_to_copy));
  785. }
  786. // Cache global variants. Separate interface to require deliberate use of them.
  787. __device__ __forceinline__ void memcpy_async_cg(void* shr, const void* gbl)
  788. {
  789. unsigned shr_int = internal::convert_to_shared(shr);
  790. asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" : : "r"(shr_int), "l"(gbl));
  791. }
  792. __device__ __forceinline__ void memcpy_async_nop_cg(void* shr, const void* gbl, bool predicate)
  793. {
  794. unsigned shr_int = internal::convert_to_shared(shr);
  795. asm volatile(
  796. "{\n"
  797. " .reg .pred p;\n"
  798. " setp.ne.b32 p, %0, 0;\n"
  799. " @p cp.async.cg.shared.global [%1], [%2], 16;\n"
  800. "}\n"
  801. :
  802. : "r"((int)predicate), "r"(shr_int), "l"(gbl));
  803. }
  804. __device__ __forceinline__ void memcpy_async_zero_cg(void* shr, const void* gbl, bool predicate)
  805. {
  806. unsigned shr_int = internal::convert_to_shared(shr);
  807. int bytes_to_copy = (predicate ? 16 : 0);
  808. asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n"
  809. :
  810. : "r"(shr_int), "l"(gbl), "r"(bytes_to_copy));
  811. }
  812. __device__ __forceinline__ void memcpy_async_zero_nop_cg(void* shr,
  813. const void* gbl,
  814. bool zero_predicate,
  815. bool nop_predicate)
  816. {
  817. unsigned shr_int = internal::convert_to_shared(shr);
  818. int bytes_to_copy = (zero_predicate ? 16 : 0);
  819. asm volatile(
  820. "{\n"
  821. " .reg .pred p;\n"
  822. " setp.ne.b32 p, %0, 0;\n"
  823. " @p cp.async.cg.shared.global [%1], [%2], 16, %3;\n"
  824. "}\n"
  825. :
  826. : "r"((int)nop_predicate), "r"(shr_int), "l"(gbl), "r"(bytes_to_copy));
  827. }
  828. __device__ __forceinline__ void memcpy_async_fence() { asm volatile("cp.async.commit_group;\n"); }
  829. template <int stages>
  830. __device__ __forceinline__ void memcpy_async_wait()
  831. {
  832. static_assert(stages <= 8);
  833. asm volatile("cp.async.wait_group %0;\n" : : "n"(stages));
  834. }
  835. // TODO: The tail complete should be a known compile time artifact, should try and induce this
  836. // without all of the branches from the call-site. This is a hacky solution.
  837. template <>
  838. __device__ __forceinline__ void tail_complete_wait<1>(int remaining_stages)
  839. {
  840. if (remaining_stages == 0) memcpy_async_wait<0>();
  841. }
  842. template <>
  843. __device__ __forceinline__ void tail_complete_wait<2>(int remaining_stages)
  844. {
  845. if (remaining_stages == 1)
  846. memcpy_async_wait<1>();
  847. else if (remaining_stages == 0)
  848. memcpy_async_wait<0>();
  849. }
  850. template <>
  851. __device__ __forceinline__ void tail_complete_wait<3>(int remaining_stages)
  852. {
  853. if (remaining_stages == 2)
  854. memcpy_async_wait<2>();
  855. else if (remaining_stages == 1)
  856. memcpy_async_wait<1>();
  857. else if (remaining_stages == 0)
  858. memcpy_async_wait<0>();
  859. }
  860. template <>
  861. __device__ __forceinline__ void tail_complete_wait<4>(int remaining_stages)
  862. {
  863. if (remaining_stages == 3)
  864. memcpy_async_wait<3>();
  865. else if (remaining_stages == 2)
  866. memcpy_async_wait<2>();
  867. else if (remaining_stages == 1)
  868. memcpy_async_wait<1>();
  869. else if (remaining_stages == 0)
  870. memcpy_async_wait<0>();
  871. }
  872. template <>
  873. __device__ __forceinline__ void tail_complete_wait<5>(int remaining_stages)
  874. {
  875. if (remaining_stages == 4)
  876. memcpy_async_wait<4>();
  877. else if (remaining_stages == 3)
  878. memcpy_async_wait<3>();
  879. else if (remaining_stages == 2)
  880. memcpy_async_wait<2>();
  881. else if (remaining_stages == 1)
  882. memcpy_async_wait<1>();
  883. else if (remaining_stages == 0)
  884. memcpy_async_wait<0>();
  885. }
  886. template <>
  887. __device__ __forceinline__ void tail_complete_wait<6>(int remaining_stages)
  888. {
  889. if (remaining_stages == 5)
  890. memcpy_async_wait<5>();
  891. else if (remaining_stages == 4)
  892. memcpy_async_wait<4>();
  893. else if (remaining_stages == 3)
  894. memcpy_async_wait<3>();
  895. else if (remaining_stages == 2)
  896. memcpy_async_wait<2>();
  897. else if (remaining_stages == 1)
  898. memcpy_async_wait<1>();
  899. else if (remaining_stages == 0)
  900. memcpy_async_wait<0>();
  901. }
  902. #endif
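
// Pipelining sketch (illustrative only): the intended pattern is to stage cp.async copies
// into shared memory, commit them with memcpy_async_fence(), and block with
// memcpy_async_wait<N>() before reading the staged data. The kernel below assumes a
// hypothetical 256-thread block and a two-buffer pipeline; those specifics are not
// prescribed by this header.
//
//   #ifdef ASYNC_COPY_AVAILABLE
//   __global__ void pipelined_consume(const float* src, int chunks)
//   {
//       __shared__ float4 stage[2][256];  // two staging buffers, one float4 per thread
//       // Prefetch chunk 0 into buffer 0 and commit it as a copy group.
//       mem_access::memcpy_async<16>(&stage[0][threadIdx.x], src + 4 * threadIdx.x);
//       mem_access::memcpy_async_fence();
//       for (int c = 0; c < chunks; ++c) {
//           const int cur = c & 1;
//           if (c + 1 < chunks) {
//               // Issue the next chunk into the other buffer, then wait until only that
//               // newest group is still in flight (i.e. chunk c has landed).
//               mem_access::memcpy_async<16>(&stage[cur ^ 1][threadIdx.x],
//                                            src + 4 * ((c + 1) * 256 + threadIdx.x));
//               mem_access::memcpy_async_fence();
//               mem_access::memcpy_async_wait<1>();
//           } else {
//               mem_access::memcpy_async_wait<0>();  // drain the final copy
//           }
//           __syncthreads();
//           // ... consume stage[cur] ...
//           __syncthreads();
//       }
//   }
//   #endif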

}  // namespace mem_access