(root)/
gcc-13.2.0/
gcc/
testsuite/
gcc.target/
i386/
amxbf16-dpbf16ps-2.c
       1  /* { dg-do run { target { ! ia32 } } } */
       2  /* { dg-require-effective-target amx_bf16 } */
       3  /* { dg-options "-O2 -mamx-bf16" } */
       4  #include <immintrin.h>
       5  
       6  #define AMX_BF16
       7  #define DO_TEST test_amx_bf16_dpbf16ps
       8  void test_amx_bf16_dpbf16ps ();
       9  #include "amx-check.h"
      10  
      11  /* Transformation functions between bf16/float */
      12  static uint16_t make_bf16 (float f)
      13  {
      14    union
      15    {
      16      float f;
      17      uint32_t u;
      18    } fu;
      19    fu.f = f;
      20    fu.u = (fu.u >> 16) & 0xffff;
      21    return (uint16_t) fu.u;
      22  }
      23  
      24  static float make_f32 (uint16_t bf)
      25  {
      26    union
      27    {
      28      float f;
      29      uint32_t u;
      30    } fu;
      31    fu.u = (uint32_t) bf << 16;
      32    return fu.f;
      33  }
      34  
      35  /* Init tile buffer with bf16 pairs */
      36  void init_bf16_max_tile_buffer (uint8_t *buf)
      37  { 
      38    int i, j;
      39    uint16_t *ptr = (uint16_t *)buf;
      40  
      41    for(i = 0; i < 16; i++)
      42      for(j = 0; j < 32; j++)
      43        {	
      44  	float f = 16.1f * i + 3.4f * j;
      45  	ptr[i * 32 + j] = make_bf16(f);
      46        }
      47  }
      48  
      49  void calc_matrix_dpbf16ps (__tile *dst, __tile *src1, __tile *src2)
      50  {
      51    uint16_t *src1_buf = (uint16_t *)src1->buf;
      52    uint16_t *src2_buf = (uint16_t *)src2->buf;
      53    float *dst_buf = (float *)dst->buf;
      54    
      55    int M = src1->rows;
      56    int N = src1->colsb / 4;
      57    int K = src2->colsb / 4;
      58    int i, j, k, t;
      59  
      60    for (i = 0; i < M; i++)
      61      for (j = 0; j < N; j++)
      62        for (k = 0; k < K; k++)
      63  	for (t = 0; t < 2; t+=2)
      64  	  {    
      65  	    dst_buf[i * N + k] += 
      66  	      (make_f32(src1_buf[i * 2 * N + 2 * j + t]) *
      67  	      make_f32(src2_buf[j * 2 * K + 2 * k + t])) +
      68  	      (make_f32(src1_buf[i * 2 * N + 2 * j + t + 1]) *
      69  	      make_f32(src2_buf[j * 2 * K + 2 * k + t + 1]));
      70  	  }
      71  
      72  }
      73  
      74  void test_amx_bf16_dpbf16ps ()
      75  {
      76    __tilecfg_u cfg;
      77    __tile dst, dst_ref, src1, src2;
      78    uint8_t tmp_dst_buf[1024];
      79  
      80    init_bf16_max_tile_buffer (tmp_dst_buf);
      81    
      82    init_tile_config (&cfg);
      83    init_tile_reg_and_src_with_buffer (1, dst, tmp_dst_buf);
      84    init_tile_reg_and_src_with_buffer (2, src1, tmp_dst_buf);
      85    init_tile_reg_and_src_with_buffer (3, src2, tmp_dst_buf);
      86  
      87    calc_matrix_dpbf16ps (&dst, &src1, &src2);
      88    
      89    _tile_dpbf16ps (1, 2, 3);
      90    _tile_stored (1, dst_ref.buf, _STRIDE);
      91  
      92    if (!check_float_tile_register (&dst_ref, &dst))
      93          abort();
      94  }