simd.h
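
// SIMD abstraction layer: defines a common set of SIMD_* macros that map to
// either 512-bit (AVX-512) or 256-bit AVX intrinsics, plus small templated
// helpers that operate on arrays of SIMD registers.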

#pragma once

#if (__x86_64__ || __i386__)
#include <cpuid.h>
#include <x86intrin.h>
#endif

#define TILE (128 * 1024 * 1024)
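
// The SIMD_* macros below resolve to AVX-512 or 256-bit AVX intrinsics,
// depending on whether the build defines __AVX512__ or __AVX256__.
// SIMD_WIDTH is the number of fp32 lanes per register. SIMD_LOAD2 and
// SIMD_STORE2 take an extra flag: when true, the memory operand holds fp16
// values that are converted to/from fp32 via the cvtph/cvtps intrinsics.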
#if defined(__AVX512__) or defined(__AVX256__)

#define ROUND_DOWN(size, step) ((size) & ~((step)-1))

#if defined(__AVX512__)
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_ADD(x, y) _mm512_add_ps(x, y)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_WIDTH 16
// _mm512_cvtph_ps expects a __m256i of packed fp16, so the raw __m256 load is
// reinterpreted with _mm256_castps_si256 before conversion.
#define SIMD_LOAD2(x, h) \
    ((h) ? _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x))) : _mm512_loadu_ps(x))
#define SIMD_STORE2(x, d, h)                                                                      \
    ((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
         : _mm512_storeu_ps(x, d))

#define INTV __m256i
#elif defined(__AVX256__)
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_ADD(x, y) _mm256_add_ps(x, y)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_WIDTH 8

#define SIMD_LOAD2(x, h) \
    ((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) : _mm256_loadu_ps(x))
#define SIMD_STORE2(x, d, h)                                                                \
    ((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
         : _mm256_storeu_ps(x, d))

#define INTV __m128i
#endif
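
// AVX_Data wraps one native SIMD register so the helpers below can treat the
// AVX-512 and 256-bit paths uniformly; each templated helper processes `span`
// registers per call in an unrolled loop.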
union AVX_Data {
#if defined(__AVX512__)
    __m512 data;
#elif defined(__AVX256__)
    __m256 data;
#endif
    // float data_f[16];
};

template <int span>
inline void simd_store(float* dst, AVX_Data* src, bool half_precision)
{
    size_t width = (half_precision ? SIMD_WIDTH / 2 : SIMD_WIDTH);
#pragma unroll
    for (size_t i = 0; i < span; ++i) { SIMD_STORE2(dst + width * i, src[i].data, half_precision); }
}

template <int span>
inline void simd_load(AVX_Data* dst, float* src, bool half_precision)
{
    // Stride in fp32 units: in half precision each register covers SIMD_WIDTH
    // fp16 values, i.e. SIMD_WIDTH / 2 floats, matching the layout written by
    // simd_store above.
    size_t width = (half_precision ? SIMD_WIDTH / 2 : SIMD_WIDTH);
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD2(src + width * i, half_precision); }
}

template <int span>
inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) {
        dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a[i].data);
    }
}

template <int span>
inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data src_a)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) {
        dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a.data);
    }
}

template <int span>
inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data* src_m_r, AVX_Data* src_a)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) {
        dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r[i].data, src_a[i].data);
    }
}

template <int span>
inline void simd_sqrt(AVX_Data* dst, AVX_Data* src)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_SQRT(src[i].data); }
}

template <int span>
inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r.data); }
}

template <int span>
inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r[i].data); }
}

template <int span>
inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r.data); }
}

template <int span>
inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r[i].data); }
}

template <int span>
inline void simd_div(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_DIV(src_a_l[i].data, src_a_r[i].data); }
}
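
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the original header): a minimal
// example of how the span-templated helpers are typically combined. The name
// simd_axpy_example is hypothetical, and it assumes fp32 data with n being an
// exact multiple of SIMD_WIDTH * span. Computes dst[i] = alpha * x[i] + y[i].
// ---------------------------------------------------------------------------
template <int span>
inline void simd_axpy_example(float* dst, float* x, float* y, float alpha, size_t n)
{
    AVX_Data alpha_v;
    alpha_v.data = SIMD_SET(alpha);  // broadcast alpha across one register
    for (size_t i = 0; i < n; i += SIMD_WIDTH * span) {
        AVX_Data x_v[span];
        AVX_Data y_v[span];
        simd_load<span>(x_v, x + i, false);      // load span registers of x
        simd_load<span>(y_v, y + i, false);      // load span registers of y
        simd_fma<span>(y_v, x_v, alpha_v, y_v);  // y_v = alpha * x_v + y_v
        simd_store<span>(dst + i, y_v, false);   // write the result
    }
}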
#endif