1  /* { dg-do run } */
       2  /* { dg-options "-O2 -mavx512vnni -mavx512vl" } */
       3  /* { dg-require-effective-target avx512vnni } */
       4  /* { dg-require-effective-target avx512vl } */
       5  
       6  static void vnni_test (void);
       7  #define DO_TEST vnni_test
       8  #define AVX512VNNI
       9  #define AVX512VL
      10  #include "avx512f-check.h"
      11  #include "vnni-auto-vectorize-1.c"
      12  
      13  #define N 256
      14  unsigned char a_u8[N];
      15  char b_i8[N];
      16  short a_i16[N], b_i16[N];
      17  int i8_exp, i8_ref, i16_exp, i16_ref;
      18  
      19  int __attribute__((noinline, noclone, optimize("no-tree-vectorize")))
      20  sdot_prod_hi_scalar (short * restrict a, short * restrict b,
      21  		     int c, int n)
      22  {
      23    int i;
      24    for (i = 0; i < n; i++)
      25      {
      26        c += ((int) a[i] * (int) b[i]);
      27      }
      28    return c;
      29  }
      30  
      31  int __attribute__((noinline, noclone, optimize("no-tree-vectorize")))
      32  usdot_prod_qi_scalar (unsigned char * restrict a, char *restrict b,
      33  	       int c, int n)
      34  {
      35    int i;
      36    for (i = 0; i < n; i++)
      37      {
      38        c += ((int) a[i] * (int) b[i]);
      39      }
      40    return c;
      41  }
      42  
      43  void init()
      44  {
      45    int i;
      46  
      47    i8_exp = i8_ref = 127;
      48    i16_exp = i16_ref = 65535;
      49  
      50    for (i = 0; i < N; i++)
      51      {
      52        a_u8[i] = (i + 3) % 256;
      53        b_i8[i] = (i + 1) % 128; 
      54        a_i16[i] = i * 2;
      55        b_i16[i] = -i + 2;
      56      }
      57  }
      58  
      59  static void vnni_test()
      60  {
      61    init ();
      62    i16_exp = sdot_prod_hi (a_i16, b_i16, i16_exp, N);
      63    i16_ref = sdot_prod_hi_scalar (a_i16, b_i16, i16_ref, N);
      64    if (i16_exp != i16_ref)
      65      abort ();
      66  
      67    init ();
      68    i8_exp = usdot_prod_qi (a_u8, b_i8, i8_exp, N);
      69    i8_ref = usdot_prod_qi_scalar (a_u8, b_i8, i8_ref, N);
      70    if (i8_exp != i8_ref)
      71      abort ();
      72  }