(root)/
glibc-2.38/
sysdeps/
x86_64/
multiarch/
strstr-avx512.c
       1  /* strstr optimized with 512-bit AVX-512 instructions
       2     Copyright (C) 2022-2023 Free Software Foundation, Inc.
       3     This file is part of the GNU C Library.
       4  
       5     The GNU C Library is free software; you can redistribute it and/or
       6     modify it under the terms of the GNU Lesser General Public
       7     License as published by the Free Software Foundation; either
       8     version 2.1 of the License, or (at your option) any later version.
       9  
      10     The GNU C Library is distributed in the hope that it will be useful,
      11     but WITHOUT ANY WARRANTY; without even the implied warranty of
      12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
      13     Lesser General Public License for more details.
      14  
      15     You should have received a copy of the GNU Lesser General Public
      16     License along with the GNU C Library; if not, see
      17     <https://www.gnu.org/licenses/>.  */
      18  
      19  #include <immintrin.h>
      20  #include <inttypes.h>
      21  #include <stdbool.h>
      22  #include <string.h>
      23  
      24  #define FULL_MMASK64 0xffffffffffffffff
      25  #define ONE_64BIT 0x1ull
      26  #define ZMM_SIZE_IN_BYTES 64
      27  #define PAGESIZE 4096
      28  
      29  #define cvtmask64_u64(...) (uint64_t) (__VA_ARGS__)
      30  #define kshiftri_mask64(x, y) ((x) >> (y))
      31  #define kand_mask64(x, y) ((x) & (y))
      32  
      33  /*
      34   Returns the index of the first edge within the needle, returns 0 if no edge
      35   is found. Example: 'ab' is the first edge in 'aaaaaaaaaabaarddg'
      36   */
      37  static inline size_t
      38  find_edge_in_needle (const char *ned)
      39  {
      40    size_t ind = 0;
      41    while (ned[ind + 1] != '\0')
      42      {
      43        if (ned[ind] != ned[ind + 1])
      44          return ind;
      45        else
      46          ind = ind + 1;
      47      }
      48    return 0;
      49  }
      50  
      51  /*
      52   Compare needle with haystack byte by byte at specified location
      53   */
      54  static inline bool
      55  verify_string_match (const char *hay, const size_t hay_index, const char *ned,
      56                       size_t ind)
      57  {
      58    while (ned[ind] != '\0')
      59      {
      60        if (ned[ind] != hay[hay_index + ind])
      61          return false;
      62        ind = ind + 1;
      63      }
      64    return true;
      65  }
      66  
      67  /*
      68   Compare needle with haystack at specified location. The first 64 bytes are
      69   compared using a ZMM register.
      70   */
      71  static inline bool
      72  verify_string_match_avx512 (const char *hay, const size_t hay_index,
      73                              const char *ned, const __mmask64 ned_mask,
      74                              const __m512i ned_zmm)
      75  {
      76    /* check first 64 bytes using zmm and then scalar */
      77    __m512i hay_zmm = _mm512_loadu_si512 (hay + hay_index); // safe to do so
      78    __mmask64 match = _mm512_mask_cmpneq_epi8_mask (ned_mask, hay_zmm, ned_zmm);
      79    if (match != 0x0) // failed the first few chars
      80      return false;
      81    else if (ned_mask == FULL_MMASK64)
      82      return verify_string_match (hay, hay_index, ned, ZMM_SIZE_IN_BYTES);
      83    return true;
      84  }
      85  
      86  char *
      87  __strstr_avx512 (const char *haystack, const char *ned)
      88  {
      89    char first = ned[0];
      90    if (first == '\0')
      91      return (char *)haystack;
      92    if (ned[1] == '\0')
      93      return (char *)strchr (haystack, ned[0]);
      94  
      95    size_t edge = find_edge_in_needle (ned);
      96  
      97    /* ensure haystack is as long as the pos of edge in needle */
      98    for (int ii = 0; ii < edge; ++ii)
      99      {
     100        if (haystack[ii] == '\0')
     101          return NULL;
     102      }
     103  
     104    /*
     105     Load 64 bytes of the needle and save it to a zmm register
     106     Read one cache line at a time to avoid loading across a page boundary
     107     */
     108    __mmask64 ned_load_mask = _bzhi_u64 (
     109        FULL_MMASK64, 64 - ((uintptr_t) (ned) & 63));
     110    __m512i ned_zmm = _mm512_maskz_loadu_epi8 (ned_load_mask, ned);
     111    __mmask64 ned_nullmask
     112        = _mm512_mask_testn_epi8_mask (ned_load_mask, ned_zmm, ned_zmm);
     113  
     114    if (__glibc_unlikely (ned_nullmask == 0x0))
     115      {
     116        ned_zmm = _mm512_loadu_si512 (ned);
     117        ned_nullmask = _mm512_testn_epi8_mask (ned_zmm, ned_zmm);
     118        ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE_64BIT);
     119        if (ned_nullmask != 0x0)
     120          ned_load_mask = ned_load_mask >> 1;
     121      }
     122    else
     123      {
     124        ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE_64BIT);
     125        ned_load_mask = ned_load_mask >> 1;
     126      }
     127    const __m512i ned0 = _mm512_set1_epi8 (ned[edge]);
     128    const __m512i ned1 = _mm512_set1_epi8 (ned[edge + 1]);
     129  
     130    /*
     131     Read the bytes of haystack in the current cache line
     132     */
     133    size_t hay_index = edge;
     134    __mmask64 loadmask = _bzhi_u64 (
     135        FULL_MMASK64, 64 - ((uintptr_t) (haystack + hay_index) & 63));
     136    /* First load is a partial cache line */
     137    __m512i hay0 = _mm512_maskz_loadu_epi8 (loadmask, haystack + hay_index);
     138    /* Search for NULL and compare only till null char */
     139    uint64_t nullmask
     140        = cvtmask64_u64 (_mm512_mask_testn_epi8_mask (loadmask, hay0, hay0));
     141    uint64_t cmpmask = nullmask ^ (nullmask - ONE_64BIT);
     142    cmpmask = cmpmask & cvtmask64_u64 (loadmask);
     143    /* Search for the 2 characters of needle */
     144    __mmask64 k0 = _mm512_cmpeq_epi8_mask (hay0, ned0);
     145    __mmask64 k1 = _mm512_cmpeq_epi8_mask (hay0, ned1);
     146    k1 = kshiftri_mask64 (k1, 1);
     147    /* k2 masks tell us if both chars from needle match */
     148    uint64_t k2 = cvtmask64_u64 (kand_mask64 (k0, k1)) & cmpmask;
     149    /* For every match, search for the entire needle for a full match */
     150    while (k2)
     151      {
     152        uint64_t bitcount = _tzcnt_u64 (k2);
     153        k2 = _blsr_u64 (k2);
     154        size_t match_pos = hay_index + bitcount - edge;
     155        if (((uintptr_t) (haystack + match_pos) & (PAGESIZE - 1))
     156            < PAGESIZE - 1 - ZMM_SIZE_IN_BYTES)
     157          {
     158            /*
     159             * Use vector compare as long as you are not crossing a page
     160             */
     161            if (verify_string_match_avx512 (haystack, match_pos, ned,
     162                                            ned_load_mask, ned_zmm))
     163              return (char *)haystack + match_pos;
     164          }
     165        else
     166          {
     167            if (verify_string_match (haystack, match_pos, ned, 0))
     168              return (char *)haystack + match_pos;
     169          }
     170      }
     171    /* We haven't checked for potential match at the last char yet */
     172    haystack = (const char *)(((uintptr_t) (haystack + hay_index) | 63));
     173    hay_index = 0;
     174  
     175    /*
     176     Loop over one cache line at a time to prevent reading over page
     177     boundary
     178     */
     179    __m512i hay1;
     180    while (nullmask == 0)
     181      {
     182        hay0 = _mm512_loadu_si512 (haystack + hay_index);
     183        hay1 = _mm512_load_si512 (haystack + hay_index
     184                                  + 1); // Always 64 byte aligned
     185        nullmask = cvtmask64_u64 (_mm512_testn_epi8_mask (hay1, hay1));
     186        /* Compare only till null char */
     187        cmpmask = nullmask ^ (nullmask - ONE_64BIT);
     188        k0 = _mm512_cmpeq_epi8_mask (hay0, ned0);
     189        k1 = _mm512_cmpeq_epi8_mask (hay1, ned1);
     190        /* k2 masks tell us if both chars from needle match */
     191        k2 = cvtmask64_u64 (kand_mask64 (k0, k1)) & cmpmask;
     192        /* For every match, compare full strings for potential match */
     193        while (k2)
     194          {
     195            uint64_t bitcount = _tzcnt_u64 (k2);
     196            k2 = _blsr_u64 (k2);
     197            size_t match_pos = hay_index + bitcount - edge;
     198            if (((uintptr_t) (haystack + match_pos) & (PAGESIZE - 1))
     199                < PAGESIZE - 1 - ZMM_SIZE_IN_BYTES)
     200              {
     201                /*
     202                 * Use vector compare as long as you are not crossing a page
     203                 */
     204                if (verify_string_match_avx512 (haystack, match_pos, ned,
     205                                                ned_load_mask, ned_zmm))
     206                  return (char *)haystack + match_pos;
     207              }
     208            else
     209              {
     210                /* Compare byte by byte */
     211                if (verify_string_match (haystack, match_pos, ned, 0))
     212                  return (char *)haystack + match_pos;
     213              }
     214          }
     215        hay_index += ZMM_SIZE_IN_BYTES;
     216      }
     217    return NULL;
     218  }