1 #ifndef AMX_CHECK_H_INCLUDED
2 #define AMX_CHECK_H_INCLUDED
3
4 #include <stdlib.h>
5 #include <string.h>
6 #include <stdint.h>
7 #include <unistd.h>
8 #ifdef __linux__
9 #include <sys/syscall.h>
10 #endif
11 #ifdef DEBUG
12 #include <stdio.h>
13 #endif
14 #include "cpuid.h"
15
/* Feature numbers of the two AMX tile state components.  The data
   component number is what gets passed to ARCH_REQ_XCOMP_PERM.  */
#define XFEATURE_XTILECFG 17
#define XFEATURE_XTILEDATA 18

/* Single-bit masks built from the feature numbers above, plus their
   union; these are tested against the permission bitmask.  */
#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG)
#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA)
#define XFEATURE_MASK_XTILE \
  (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA)

/* arch_prctl sub-commands: read the permission bitmask, and request
   permission for one feature number.  */
#define ARCH_GET_XCOMP_PERM 0x1022
#define ARCH_REQ_XCOMP_PERM 0x1023
24
/* TODO: The tmm emulation is temporary for the current AMX
   implementation, which has no tmm regclass; it should be
   changed in the future.  */

/* Field-wise view of the 64-byte tile configuration image.  Each of
   the eight entries in colsb[] and rows[] describes one tmm
   register.  */
typedef struct __tile_config
{
  uint8_t palette_id;
  uint8_t start_row;
  uint8_t reserved_0[14];
  uint16_t colsb[8];      /* Bytes per row of each tmm register.  */
  uint16_t reserved_1[8];
  uint8_t rows[8];        /* Number of rows of each tmm register.  */
  uint8_t reserved_2[8];
} __tilecfg;
38
39 typedef union __union_tile_config
40 {
41 __tilecfg s;
42 uint8_t a[64];
43 } __tilecfg_u;
44
/* Memory-side copy of one tile: the data bytes plus the shape that
   was in effect when the copy was initialised.  */
typedef struct __tile
{
  uint8_t buf[1024];  /* Large enough for the maximum tile size.  */
  int rows;           /* Number of rows in use.  */
  int colsb;          /* Bytes per row in use.  */
} __tile;
52
/* Maximum tile shape: number of rows, and bytes per row.  */
#define MAX_ROWS 16
#define MAX_COLS 64

/* Stride (row width in bytes) used for tile load/store.  */
#define _STRIDE 64
59
60 #ifdef __linux__
61 /* We need syscall to use amx functions */
62 int request_perm_xtile_data()
63 {
64 unsigned long bitmask;
65
66 if (syscall (SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) ||
67 syscall (SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask))
68 return 0;
69
70 return (bitmask & XFEATURE_MASK_XTILE) != 0;
71 }
72 #endif
73
74 /* Initialize tile config by setting all tmm size to 16x64 */
75 void init_tile_config (__tilecfg_u *dst)
76 {
77 int i;
78
79 dst->s.palette_id = 1;
80 dst->s.start_row = 0;
81
82 for (i = 0; i < 14; i++)
83 dst->s.reserved_0[i] = 0;
84
85 for (i = 0; i < 8; i++)
86 {
87 dst->s.colsb[i] = _STRIDE;
88 dst->s.rows[i] = 16;
89 dst->s.reserved_1[i] = 0;
90 dst->s.reserved_2[i] = 0;
91 }
92
93 _tile_loadconfig (dst->a);
94 }
95
96 /* Init __tile variable that going to be store to register
97 w/o extra buffer. If buffer exists, it should be the same
98 size matrix as corresponding tmm register.
99 Should execute init_tile_config first */
100 void init_tile_src (const int tmm_num, __tile *src, uint8_t *buffer)
101 {
102 int rows, colsb, i, j;
103 __tilecfg_u tmp;
104
105 _tile_storeconfig (tmp.a);
106
107 src->rows = rows = tmp.s.rows[tmm_num];
108 src->colsb = colsb = tmp.s.colsb[tmm_num];
109
110 for (i = 0; i < rows; i++)
111 for (j = 0; j < colsb; j++)
112 {
113 if(buffer)
114 src->buf[i * colsb + j] = buffer[i * colsb + j];
115 else
116 src->buf[i * colsb + j] = (i + 11 * j) % 256;
117 }
118
119 }
120
/* Initialise the __tile lvalue SRC with the default pattern and load
   it into tile register TMM_NUM.

   The body is wrapped in do { } while (0) so that the expansion,
   followed by the caller's semicolon, is exactly one statement and
   can be used safely as the body of an if/else.

   NOTE(review): tmm_num is forwarded to _tile_loadd without added
   parentheses on purpose; the intrinsic presumably needs a bare
   register number - confirm before changing.  */
#define init_tile_reg_and_src(tmm_num, src) \
  do \
    { \
      init_tile_src (tmm_num, &(src), NULL); \
      _tile_loadd (tmm_num, (src).buf, _STRIDE); \
    } \
  while (0)
127
/* Initialise the __tile lvalue SRC from BUFFER and load it into tile
   register TMM_NUM.

   The body is wrapped in do { } while (0) so that the expansion,
   followed by the caller's semicolon, is exactly one statement and
   can be used safely as the body of an if/else.

   NOTE(review): tmm_num is forwarded to _tile_loadd without added
   parentheses on purpose; the intrinsic presumably needs a bare
   register number - confirm before changing.  */
#define init_tile_reg_and_src_with_buffer(tmm_num, src, buffer) \
  do \
    { \
      init_tile_src (tmm_num, &(src), (buffer)); \
      _tile_loadd (tmm_num, (src).buf, _STRIDE); \
    } \
  while (0)
133
134 /* Zero __tile src. It should be init first. */
135 void zero_tile_src (__tile *src)
136 {
137 int i, j;
138
139 for (i = 0; i < src->rows; i++)
140 for (j = 0; j < src->colsb; j++)
141 src->buf[i * src->colsb + j] = 0;
142 }
143
144 /* Compare tile config value with __tilecfg_u dst */
145 int check_tile_config (__tilecfg_u *src, __tilecfg_u *dst)
146 {
147 size_t size = sizeof(__tilecfg);
148 uint8_t *pa_src = (uint8_t *) src->a;
149 uint8_t *pa_dst = (uint8_t *) dst->a;
150
151 for (int i = 0; i < size; i++)
152 if (pa_src[i] != pa_dst[i])
153 return 0;
154
155 return 1;
156 }
157
158 /* Compare tile register value with __tile variable */
159 int check_tile_register (__tile* ref, __tile* target)
160 {
161 /* Tile register should be stored from tmm to
162 memory and compare with emulation results. */
163 int rows = target->rows;
164 int colsb = target->colsb;
165 int i, j;
166
167 for (i = 0; i < rows; i++)
168 for (j = 0; j < colsb; j++)
169 if (ref->buf[i * colsb + j] != target->buf[i * colsb + j])
170 return 0;
171
172 return 1;
173 }
174
175 /* Compare float tile register value with __tile variable */
176 int check_float_tile_register (__tile* ref, __tile* target)
177 {
178 /* Tile register should be stored from tmm to
179 memory and compare with emulation results. */
180 int rows = target->rows;
181 int colsb = target->colsb / 4;
182 int i, j;
183 uint32_t *ref_buf = (uint32_t *) ref->buf;
184 uint32_t *target_buf = (uint32_t *) target->buf;
185
186 for (i = 0; i < rows; i++)
187 for (j = 0; j < colsb; j++)
188 if (abs(ref_buf[i * colsb + j] - target_buf[i * colsb + j]) > 1)
189 return 0;
190
191 return 1;
192 }
193
/* Default test entry point.  Unless DO_TEST was already defined,
   DO_TEST becomes do_test, a non-inlined wrapper that calls the
   test_amx function declared here.  */
#ifndef DO_TEST
static void test_amx (void);

static void __attribute__ ((noinline))
do_test (void)
{
  test_amx ();
}

#define DO_TEST do_test
#endif
204
/* Run DO_TEST only when the CPU reports every AMX feature selected
   at compile time (and, on Linux, the tile-data permission request
   succeeds); otherwise skip.  With DEBUG defined the outcome is
   printed.  Always returns 0.  */
int
main ()
{
  /* Each later check is evaluated only while all earlier ones have
     succeeded, matching a single short-circuit && chain.  */
  int usable = __builtin_cpu_supports ("amx-tile");

#ifdef AMX_INT8
  usable = usable && __builtin_cpu_supports ("amx-int8");
#endif
#ifdef AMX_BF16
  usable = usable && __builtin_cpu_supports ("amx-bf16");
#endif
#ifdef AMX_FP16
  usable = usable && __builtin_cpu_supports ("amx-fp16");
#endif
#ifdef AMX_COMPLEX
  usable = usable && __builtin_cpu_supports ("amx-complex");
#endif
#ifdef __linux__
  usable = usable && request_perm_xtile_data ();
#endif

  if (!usable)
    {
#ifdef DEBUG
      printf ("SKIPPED\n");
#endif
      return 0;
    }

  DO_TEST ();
#ifdef DEBUG
  printf ("PASSED\n");
#endif

  return 0;
}
239
240 #endif