1  #ifndef AMX_HELPER_H_INCLUDED
       2  #define AMX_HELPER_H_INCLUDED
       3  #if defined(AMX_FP16) || defined(AMX_COMPLEX)
       4  #include <immintrin.h>
       5  #include <xmmintrin.h>
       6  #endif
       7  #include "amx-check.h"
       8  
/* Two views of the same 16-bit storage: a half-precision float and its
   raw unsigned bit pattern.  Writing one member and reading the other
   reinterprets the bits without any numeric conversion.  */
typedef union
{
  _Float16 f16;	/* The bits viewed as a half-precision value.  */
  uint16_t u;	/* The same bits viewed as an unsigned 16-bit integer.  */
} union16f_uw;
      14  
      15  #if defined(AMX_FP16) || defined(AMX_COMPLEX)
      16  /* Transformation functions between fp16/float */
      17  static uint16_t make_f32_fp16 (float f)
      18  {
      19    union16f_uw tmp;
      20    __m128 b = _mm_set_ss (f);
      21    __m128h a;
      22    tmp.f16 = _mm_cvtsh_h (_mm_cvtss_sh (a, b));
      23    return tmp.u;
      24  }
      25  
      26  static float make_fp16_f32 (uint16_t fp)
      27  {
      28    union16f_uw tmp;
      29    tmp.u = fp;
      30    __m128h b = _mm_set_sh (tmp.f16);
      31    __m128 a;
      32    return _mm_cvtss_f32 (_mm_cvtsh_ss (a, b));
      33  }
      34  
      35  /* Init tile buffer with fp16 pairs */
      36  void init_fp16_max_tile_buffer (uint8_t* buf)
      37  {
      38    int i, j;
      39    uint16_t* ptr = (uint16_t *) buf;
      40  
      41    for (i = 0; i < 16; i++)
      42      for (j = 0; j < 32; j++)
      43      {
      44        float f = 2.5f * i + 1.25f * j;
      45        ptr[i * 32 + j] = make_f32_fp16 (f);
      46      }
      47  }
      48  
      49  /* Init tile fp16 pair buffer with zero */
      50  void init_fp16_max_tile_zero_buffer (uint8_t* buf)
      51  {
      52    int i, j;
      53    uint16_t* ptr = (uint16_t *) buf;
      54  
      55    for (i = 0; i < 16; i++)
      56      for (j = 0; j < 32; j++)
      57        ptr[i * 32 + j] = make_f32_fp16 (0.0f);
      58  }
      59  #endif
      60  
      61  #endif