1 /* { dg-do run { target { ! ia32 } } } */
2 /* { dg-require-effective-target amx_bf16 } */
3 /* { dg-options "-O2 -mamx-bf16" } */
4 #include <immintrin.h>
5
6 #define AMX_BF16
7 #define DO_TEST test_amx_bf16_dpbf16ps
8 void test_amx_bf16_dpbf16ps ();
9 #include "amx-check.h"
10
11 /* Transformation functions between bf16/float */
/* Convert F to bf16 by keeping only the high 16 bits of its 32-bit
   representation.  The low 16 bits are simply dropped (truncation,
   no rounding).  */
static uint16_t make_bf16 (float f)
{
  union
  {
    uint32_t bits;
    float value;
  } pun = { .value = f };

  /* After shifting a 32-bit value right by 16 only 16 bits remain,
     so no extra masking is needed before narrowing.  */
  return (uint16_t) (pun.bits >> 16);
}
23
/* Widen the bf16 value BF to a float: BF becomes the high 16 bits of
   the 32-bit representation and the low 16 bits are zero.  */
static float make_f32 (uint16_t bf)
{
  union
  {
    uint32_t bits;
    float value;
  } pun;

  pun.bits = bf;
  pun.bits <<= 16;
  return pun.value;
}
34
/* Fill BUF with 16 rows of 32 bf16 values (1024 bytes in total).
   The element at (row, col) is the bf16 form of
   16.1 * row + 3.4 * col, so consecutive 16-bit slots form the bf16
   pairs consumed by the dot-product test.  */
void init_bf16_max_tile_buffer (uint8_t *buf)
{
  enum { TILE_ROWS = 16, BF16_PER_ROW = 32 };
  uint16_t *out = (uint16_t *) buf;
  int row, col;

  for (row = 0; row < TILE_ROWS; row++)
    {
      uint16_t *line = out + row * BF16_PER_ROW;

      for (col = 0; col < BF16_PER_ROW; col++)
        line[col] = make_bf16 (16.1f * row + 3.4f * col);
    }
}
48
49 void calc_matrix_dpbf16ps (__tile *dst, __tile *src1, __tile *src2)
50 {
51 uint16_t *src1_buf = (uint16_t *)src1->buf;
52 uint16_t *src2_buf = (uint16_t *)src2->buf;
53 float *dst_buf = (float *)dst->buf;
54
55 int M = src1->rows;
56 int N = src1->colsb / 4;
57 int K = src2->colsb / 4;
58 int i, j, k, t;
59
60 for (i = 0; i < M; i++)
61 for (j = 0; j < N; j++)
62 for (k = 0; k < K; k++)
63 for (t = 0; t < 2; t+=2)
64 {
65 dst_buf[i * N + k] +=
66 (make_f32(src1_buf[i * 2 * N + 2 * j + t]) *
67 make_f32(src2_buf[j * 2 * K + 2 * k + t])) +
68 (make_f32(src1_buf[i * 2 * N + 2 * j + t + 1]) *
69 make_f32(src2_buf[j * 2 * K + 2 * k + t + 1]));
70 }
71
72 }
73
/* Test body (selected through DO_TEST): run the bf16 dot-product tile
   instruction on tile registers 1, 2 and 3 and compare the stored
   result with the software model in calc_matrix_dpbf16ps.  Aborts on
   mismatch.  */
void test_amx_bf16_dpbf16ps ()
{
  __tilecfg_u cfg;
  __tile dst, dst_ref, src1, src2;
  /* 1024 bytes = the 16 x 32 bf16 values written by
     init_bf16_max_tile_buffer.  The same image is used for all three
     tiles.  */
  uint8_t tmp_dst_buf[1024];

  init_bf16_max_tile_buffer (tmp_dst_buf);

  /* Helpers from amx-check.h.  Presumably they set up the tile
     configuration and load TMP_DST_BUF into tile registers 1..3 while
     mirroring it into dst/src1/src2 -- confirm against that header.  */
  init_tile_config (&cfg);
  init_tile_reg_and_src_with_buffer (1, dst, tmp_dst_buf);
  init_tile_reg_and_src_with_buffer (2, src1, tmp_dst_buf);
  init_tile_reg_and_src_with_buffer (3, src2, tmp_dst_buf);

  /* Software result, accumulated in place into dst.  */
  calc_matrix_dpbf16ps (&dst, &src1, &src2);

  /* Hardware result: tile 1 += tile 2 x tile 3, then tile 1 is stored
     to dst_ref.buf.  Note that, despite its name, dst_ref holds the
     hardware output while dst holds the software reference.  */
  _tile_dpbf16ps (1, 2, 3);
  _tile_stored (1, dst_ref.buf, _STRIDE);

  /* Only dst_ref.buf is written in this function; NOTE(review): this
     relies on check_float_tile_register not reading any other field of
     its first argument -- confirm in amx-check.h.  */
  if (!check_float_tile_register (&dst_ref, &dst))
    abort();
}