1  """bytecode_helper - support tools for testing correct bytecode generation"""
       2  
       3  import unittest
       4  import dis
       5  import io
       6  from _testinternalcapi import compiler_codegen, optimize_cfg, assemble_code_object
       7  
       8  _UNSPECIFIED = object()
       9  
      10  class ESC[4;38;5;81mBytecodeTestCase(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      11      """Custom assertion methods for inspecting bytecode."""
      12  
      13      def get_disassembly_as_string(self, co):
      14          s = io.StringIO()
      15          dis.dis(co, file=s)
      16          return s.getvalue()
      17  
      18      def assertInBytecode(self, x, opname, argval=_UNSPECIFIED):
      19          """Returns instr if opname is found, otherwise throws AssertionError"""
      20          self.assertIn(opname, dis.opmap)
      21          for instr in dis.get_instructions(x):
      22              if instr.opname == opname:
      23                  if argval is _UNSPECIFIED or instr.argval == argval:
      24                      return instr
      25          disassembly = self.get_disassembly_as_string(x)
      26          if argval is _UNSPECIFIED:
      27              msg = '%s not found in bytecode:\n%s' % (opname, disassembly)
      28          else:
      29              msg = '(%s,%r) not found in bytecode:\n%s'
      30              msg = msg % (opname, argval, disassembly)
      31          self.fail(msg)
      32  
      33      def assertNotInBytecode(self, x, opname, argval=_UNSPECIFIED):
      34          """Throws AssertionError if opname is found"""
      35          self.assertIn(opname, dis.opmap)
      36          for instr in dis.get_instructions(x):
      37              if instr.opname == opname:
      38                  disassembly = self.get_disassembly_as_string(x)
      39                  if argval is _UNSPECIFIED:
      40                      msg = '%s occurs in bytecode:\n%s' % (opname, disassembly)
      41                      self.fail(msg)
      42                  elif instr.argval == argval:
      43                      msg = '(%s,%r) occurs in bytecode:\n%s'
      44                      msg = msg % (opname, argval, disassembly)
      45                      self.fail(msg)
      46  
      47  class ESC[4;38;5;81mCompilationStepTestCase(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      48  
      49      HAS_ARG = set(dis.hasarg)
      50      HAS_TARGET = set(dis.hasjrel + dis.hasjabs + dis.hasexc)
      51      HAS_ARG_OR_TARGET = HAS_ARG.union(HAS_TARGET)
      52  
      53      class ESC[4;38;5;81mLabel:
      54          pass
      55  
      56      def assertInstructionsMatch(self, actual_, expected_):
      57          # get two lists where each entry is a label or
      58          # an instruction tuple. Normalize the labels to the
      59          # instruction count of the target, and compare the lists.
      60  
      61          self.assertIsInstance(actual_, list)
      62          self.assertIsInstance(expected_, list)
      63  
      64          actual = self.normalize_insts(actual_)
      65          expected = self.normalize_insts(expected_)
      66          self.assertEqual(len(actual), len(expected))
      67  
      68          # compare instructions
      69          for act, exp in zip(actual, expected):
      70              if isinstance(act, int):
      71                  self.assertEqual(exp, act)
      72                  continue
      73              self.assertIsInstance(exp, tuple)
      74              self.assertIsInstance(act, tuple)
      75              # crop comparison to the provided expected values
      76              if len(act) > len(exp):
      77                  act = act[:len(exp)]
      78              self.assertEqual(exp, act)
      79  
      80      def resolveAndRemoveLabels(self, insts):
      81          idx = 0
      82          res = []
      83          for item in insts:
      84              assert isinstance(item, (self.Label, tuple))
      85              if isinstance(item, self.Label):
      86                  item.value = idx
      87              else:
      88                  idx += 1
      89                  res.append(item)
      90  
      91          return res
      92  
      93      def normalize_insts(self, insts):
      94          """ Map labels to instruction index.
      95              Map opcodes to opnames.
      96          """
      97          insts = self.resolveAndRemoveLabels(insts)
      98          res = []
      99          for item in insts:
     100              assert isinstance(item, tuple)
     101              opcode, oparg, *loc = item
     102              opcode = dis.opmap.get(opcode, opcode)
     103              if isinstance(oparg, self.Label):
     104                  arg = oparg.value
     105              else:
     106                  arg = oparg if opcode in self.HAS_ARG else None
     107              opcode = dis.opname[opcode]
     108              res.append((opcode, arg, *loc))
     109          return res
     110  
     111      def complete_insts_info(self, insts):
     112          # fill in omitted fields in location, and oparg 0 for ops with no arg.
     113          res = []
     114          for item in insts:
     115              assert isinstance(item, tuple)
     116              inst = list(item)
     117              opcode = dis.opmap[inst[0]]
     118              oparg = inst[1]
     119              loc = inst[2:] + [-1] * (6 - len(inst))
     120              res.append((opcode, oparg, *loc))
     121          return res
     122  
     123  
     124  class ESC[4;38;5;81mCodegenTestCase(ESC[4;38;5;149mCompilationStepTestCase):
     125  
     126      def generate_code(self, ast):
     127          insts, _ = compiler_codegen(ast, "my_file.py", 0)
     128          return insts
     129  
     130  
     131  class ESC[4;38;5;81mCfgOptimizationTestCase(ESC[4;38;5;149mCompilationStepTestCase):
     132  
     133      def get_optimized(self, insts, consts, nlocals=0):
     134          insts = self.normalize_insts(insts)
     135          insts = self.complete_insts_info(insts)
     136          insts = optimize_cfg(insts, consts, nlocals)
     137          return insts, consts
     138  
     139  class ESC[4;38;5;81mAssemblerTestCase(ESC[4;38;5;149mCompilationStepTestCase):
     140  
     141      def get_code_object(self, filename, insts, metadata):
     142          co = assemble_code_object(filename, insts, metadata)
     143          return co