(root)/
glibc-2.38/
sysdeps/
aarch64/
fpu/
scripts/
bench_libmvec_sve.py
       1  #!/usr/bin/python3
       2  # Copyright (C) 2023 Free Software Foundation, Inc.
       3  # This file is part of the GNU C Library.
       4  #
       5  # The GNU C Library is free software; you can redistribute it and/or
       6  # modify it under the terms of the GNU Lesser General Public
       7  # License as published by the Free Software Foundation; either
       8  # version 2.1 of the License, or (at your option) any later version.
       9  #
      10  # The GNU C Library is distributed in the hope that it will be useful,
      11  # but WITHOUT ANY WARRANTY; without even the implied warranty of
      12  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
      13  # Lesser General Public License for more details.
      14  #
      15  # You should have received a copy of the GNU Lesser General Public
      16  # License along with the GNU C Library; if not, see
      17  # <https://www.gnu.org/licenses/>.
      18  
      19  import sys
      20  
# C source template for one SVE libmvec benchmark; main() below fills it in
# with str.format.  Named fields used here: max_stride, stride, prec_short,
# rtype, fname, stype, rowcount and in_data.  Because the text goes through
# str.format, literal C braces are written doubled ({{ and }}), and each
# "\\" becomes a single backslash (the C line continuation) in the output.
TEMPLATE = """
#include <math.h>
#include <arm_sve.h>

#define MAX_STRIDE {max_stride}
#define STRIDE {stride}
#define PTRUE svptrue_b{prec_short}
#define SV_LOAD svld1_f{prec_short}
#define SV_STORE svst1_f{prec_short}
#define REQUIRE_SVE

#define CALL_BENCH_FUNC(v, i) (__extension__ ({{                              \\
   {rtype} mx0 = {fname}(SV_LOAD (PTRUE(), variants[v].in[i].arg0), PTRUE()); \\
   mx0; }}))

struct args
{{
  {stype} arg0[MAX_STRIDE];
  double timing;
}};

struct _variants
{{
  const char *name;
  int count;
  const struct args *in;
}};

static const struct args in0[{rowcount}] = {{
{in_data}
}};

static const struct _variants variants[1] = {{
  {{"", {rowcount}, in0}},
}};

#define NUM_VARIANTS 1
#define NUM_SAMPLES(i) (variants[i].count)
#define VARIANT(i) (variants[i].name)

// Cannot pass volatile pointer to svst1. This still does not appear to get optimised out.
static {stype} /*volatile*/ ret[MAX_STRIDE];

#define BENCH_FUNC(i, j) ({{ SV_STORE(PTRUE(), ret, CALL_BENCH_FUNC(i, j)); }})
#define FUNCNAME "{fname}"
#include <bench-libmvec-skeleton.c>
"""
      68  
      69  def main(name):
      70      _, prec, _, func = name.split("-")
      71      scalar_to_sve_type = {"double": "svfloat64_t", "float": "svfloat32_t"}
      72  
      73      stride = {"double": "svcntd()", "float": "svcntw()"}[prec]
      74      rtype = scalar_to_sve_type[prec]
      75      atype = scalar_to_sve_type[prec]
      76      fname = f"_ZGVsMxv_{func}{'f' if prec == 'float' else ''}"
      77      prec_short = {"double": 64, "float": 32}[prec]
      78      # Max SVE vector length is 2048 bits. To ensure benchmarks are
      79      # vector-length-agnostic, but still use as wide vectors as
      80      # possible on any given target, divide input data into 2048-bit
      81      # rows, then load/store as many elements as the target will allow.
      82      max_stride = 2048 // prec_short
      83  
      84      with open(f"../benchtests/libmvec/{func}-inputs") as f:
      85          in_vals = [l.strip() for l in f.readlines() if l and not l.startswith("#")]
      86      in_vals = [in_vals[i:i+max_stride] for i in range(0, len(in_vals), max_stride)]
      87      rowcount= len(in_vals)
      88      in_data = ",\n".join("{{" + ", ".join(row) + "}, 0}" for row in in_vals)
      89  
      90      print(TEMPLATE.format(stride=stride,
      91                            rtype=rtype,
      92                            atype=atype,
      93                            fname=fname,
      94                            prec_short=prec_short,
      95                            in_data=in_data,
      96                            rowcount=rowcount,
      97                            stype=prec,
      98                            max_stride=max_stride))
      99  
     100  
     101  if __name__ == "__main__":
     102      main(sys.argv[1])