simd.h

#pragma once

#if (__x86_64__ || __i386__)
#include <cpuid.h>
#include <x86intrin.h>
#endif

// Chunk size used by callers when processing large buffers in fixed-size tiles.
#define TILE (128 * 1024 * 1024)

#if defined(__AVX512__) or defined(__AVX256__)

// Round `size` down to the nearest multiple of `step` (step must be a power of two).
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))

#if defined(__AVX512__)
// AVX-512 wrappers: one register holds 16 packed single-precision floats.
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_ADD(x, y) _mm512_add_ps(x, y)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_WIDTH 16

// Load/store variants that convert to/from FP16 when `h` is true.
#define SIMD_LOAD2(x, h) \
    ((h) ? _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)x)) : _mm512_loadu_ps(x))
#define SIMD_STORE2(x, d, h)                                                                       \
    ((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
         : _mm512_storeu_ps(x, d))

#define INTV __m256i

#elif defined(__AVX256__)
// AVX2 wrappers: one register holds 8 packed single-precision floats.
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_ADD(x, y) _mm256_add_ps(x, y)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_WIDTH 8

// Load/store variants that convert to/from FP16 when `h` is true.
#define SIMD_LOAD2(x, h) \
    ((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) : _mm256_loadu_ps(x))
#define SIMD_STORE2(x, d, h)                                                                \
    ((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
         : _mm256_storeu_ps(x, d))

#define INTV __m128i

#endif

// One SIMD register's worth of packed single-precision values.
union AVX_Data {
#if defined(__AVX512__)
    __m512 data;
#elif defined(__AVX256__)
    __m256 data;
#endif
    // float data_f[16];
};

// Store `span` registers to dst, optionally converting to half precision.
template <int span>
inline void simd_store(float* dst, AVX_Data* src, bool half_precision)
{
#pragma omp parallel for
    for (size_t i = 0; i < span; ++i) {
        SIMD_STORE2(dst + SIMD_WIDTH * i, src[i].data, half_precision);
    }
}

// Load `span` registers from src, optionally converting from half precision.
template <int span>
inline void simd_load(AVX_Data* dst, float* src, bool half_precision)
{
#pragma omp parallel for
    for (size_t i = 0; i < span; ++i) {
        dst[i].data = SIMD_LOAD2(src + SIMD_WIDTH * i, half_precision);
    }
}

// dst[i] = src_m_l[i] * src_m_r + src_a[i]  (one register reused as the multiplier)
template <int span>
inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a)
{
#pragma omp parallel for
    for (size_t i = 0; i < span; ++i) {
        dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a[i].data);
    }
}

// dst[i] = src_m_l[i] * src_m_r + src_a  (one register reused as multiplier and addend)
template <int span>
inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data src_a)
{
#pragma omp parallel for
    for (size_t i = 0; i < span; ++i) {
        dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a.data);
    }
}

// dst[i] = src_m_l[i] * src_m_r[i] + src_a[i]  (fully elementwise)
template <int span>
inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data* src_m_r, AVX_Data* src_a)
{
#pragma omp parallel for
    for (size_t i = 0; i < span; ++i) {
        dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r[i].data, src_a[i].data);
    }
}

// dst[i] = sqrt(src[i])
template <int span>
inline void simd_sqrt(AVX_Data* dst, AVX_Data* src)
{
#pragma omp parallel for
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_SQRT(src[i].data); }
}

// dst[i] = src_a_l[i] + src_a_r  (one register reused on the right)
template <int span>
inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
{
#pragma omp parallel for
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r.data); }
}

// dst[i] = src_a_l[i] + src_a_r[i]
template <int span>
inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma omp parallel for
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r[i].data); }
}

// dst[i] = src_a_l[i] * src_a_r  (one register reused on the right)
template <int span>
inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
{
#pragma omp parallel for
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r.data); }
}

// dst[i] = src_a_l[i] * src_a_r[i]
template <int span>
inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma omp parallel for
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r[i].data); }
}

// dst[i] = src_a_l[i] / src_a_r[i]
template <int span>
inline void simd_div(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma omp parallel for
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_DIV(src_a_l[i].data, src_a_r[i].data); }
}

#endif
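
// --- Usage sketch (illustrative; not part of the original header) ---
// A minimal example of how these wrappers compose: computing dst = a * scale + b
// over `span` SIMD registers. The function and variable names below
// (example_scaled_add, a_regs, b_regs, out, scale_reg) are hypothetical.
// When half_precision is true, the float* arguments are assumed to point at
// FP16 buffers, matching the SIMD_LOAD2/SIMD_STORE2 convention above.
#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
inline void example_scaled_add(float* dst, float* a, float scale, float* b, bool half_precision)
{
    AVX_Data a_regs[span], b_regs[span], out[span];
    AVX_Data scale_reg;
    scale_reg.data = SIMD_SET(scale);                // broadcast the scalar into one register
    simd_load<span>(a_regs, a, half_precision);      // reads span * SIMD_WIDTH values
    simd_load<span>(b_regs, b, half_precision);
    simd_fma<span>(out, a_regs, scale_reg, b_regs);  // out[i] = a[i] * scale + b[i]
    simd_store<span>(dst, out, half_precision);      // writes span * SIMD_WIDTH values
}
#endif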