(root)/
Python-3.11.7/
Lib/
unittest/
loader.py
       1  """Loading unittests."""
       2  
       3  import os
       4  import re
       5  import sys
       6  import traceback
       7  import types
       8  import functools
       9  import warnings
      10  
      11  from fnmatch import fnmatch, fnmatchcase
      12  
      13  from . import case, suite, util
      14  
      15  __unittest = True
      16  
      17  # what about .pyc (etc)
      18  # we would need to avoid loading the same tests multiple times
      19  # from '.py', *and* '.pyc'
      20  VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
      21  
      22  
      23  class ESC[4;38;5;81m_FailedTest(ESC[4;38;5;149mcaseESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      24      _testMethodName = None
      25  
      26      def __init__(self, method_name, exception):
      27          self._exception = exception
      28          super(_FailedTest, self).__init__(method_name)
      29  
      30      def __getattr__(self, name):
      31          if name != self._testMethodName:
      32              return super(_FailedTest, self).__getattr__(name)
      33          def testFailure():
      34              raise self._exception
      35          return testFailure
      36  
      37  
      38  def _make_failed_import_test(name, suiteClass):
      39      message = 'Failed to import test module: %s\n%s' % (
      40          name, traceback.format_exc())
      41      return _make_failed_test(name, ImportError(message), suiteClass, message)
      42  
      43  def _make_failed_load_tests(name, exception, suiteClass):
      44      message = 'Failed to call load_tests:\n%s' % (traceback.format_exc(),)
      45      return _make_failed_test(
      46          name, exception, suiteClass, message)
      47  
      48  def _make_failed_test(methodname, exception, suiteClass, message):
      49      test = _FailedTest(methodname, exception)
      50      return suiteClass((test,)), message
      51  
      52  def _make_skipped_test(methodname, exception, suiteClass):
      53      @case.skip(str(exception))
      54      def testSkipped(self):
      55          pass
      56      attrs = {methodname: testSkipped}
      57      TestClass = type("ModuleSkipped", (case.TestCase,), attrs)
      58      return suiteClass((TestClass(methodname),))
      59  
      60  def _jython_aware_splitext(path):
      61      if path.lower().endswith('$py.class'):
      62          return path[:-9]
      63      return os.path.splitext(path)[0]
      64  
      65  
      66  class ESC[4;38;5;81mTestLoader(ESC[4;38;5;149mobject):
      67      """
      68      This class is responsible for loading tests according to various criteria
      69      and returning them wrapped in a TestSuite
      70      """
      71      testMethodPrefix = 'test'
      72      sortTestMethodsUsing = staticmethod(util.three_way_cmp)
      73      testNamePatterns = None
      74      suiteClass = suite.TestSuite
      75      _top_level_dir = None
      76  
      77      def __init__(self):
      78          super(TestLoader, self).__init__()
      79          self.errors = []
      80          # Tracks packages which we have called into via load_tests, to
      81          # avoid infinite re-entrancy.
      82          self._loading_packages = set()
      83  
      84      def loadTestsFromTestCase(self, testCaseClass):
      85          """Return a suite of all test cases contained in testCaseClass"""
      86          if issubclass(testCaseClass, suite.TestSuite):
      87              raise TypeError("Test cases should not be derived from "
      88                              "TestSuite. Maybe you meant to derive from "
      89                              "TestCase?")
      90          if testCaseClass in (case.TestCase, case.FunctionTestCase):
      91              # We don't load any tests from base types that should not be loaded.
      92              testCaseNames = []
      93          else:
      94              testCaseNames = self.getTestCaseNames(testCaseClass)
      95              if not testCaseNames and hasattr(testCaseClass, 'runTest'):
      96                  testCaseNames = ['runTest']
      97          loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
      98          return loaded_suite
      99  
     100      # XXX After Python 3.5, remove backward compatibility hacks for
     101      # use_load_tests deprecation via *args and **kws.  See issue 16662.
     102      def loadTestsFromModule(self, module, *args, pattern=None, **kws):
     103          """Return a suite of all test cases contained in the given module"""
     104          # This method used to take an undocumented and unofficial
     105          # use_load_tests argument.  For backward compatibility, we still
     106          # accept the argument (which can also be the first position) but we
     107          # ignore it and issue a deprecation warning if it's present.
     108          if len(args) > 0 or 'use_load_tests' in kws:
     109              warnings.warn('use_load_tests is deprecated and ignored',
     110                            DeprecationWarning)
     111              kws.pop('use_load_tests', None)
     112          if len(args) > 1:
     113              # Complain about the number of arguments, but don't forget the
     114              # required `module` argument.
     115              complaint = len(args) + 1
     116              raise TypeError('loadTestsFromModule() takes 1 positional argument but {} were given'.format(complaint))
     117          if len(kws) != 0:
     118              # Since the keyword arguments are unsorted (see PEP 468), just
     119              # pick the alphabetically sorted first argument to complain about,
     120              # if multiple were given.  At least the error message will be
     121              # predictable.
     122              complaint = sorted(kws)[0]
     123              raise TypeError("loadTestsFromModule() got an unexpected keyword argument '{}'".format(complaint))
     124          tests = []
     125          for name in dir(module):
     126              obj = getattr(module, name)
     127              if (
     128                  isinstance(obj, type)
     129                  and issubclass(obj, case.TestCase)
     130                  and obj not in (case.TestCase, case.FunctionTestCase)
     131              ):
     132                  tests.append(self.loadTestsFromTestCase(obj))
     133  
     134          load_tests = getattr(module, 'load_tests', None)
     135          tests = self.suiteClass(tests)
     136          if load_tests is not None:
     137              try:
     138                  return load_tests(self, tests, pattern)
     139              except Exception as e:
     140                  error_case, error_message = _make_failed_load_tests(
     141                      module.__name__, e, self.suiteClass)
     142                  self.errors.append(error_message)
     143                  return error_case
     144          return tests
     145  
     146      def loadTestsFromName(self, name, module=None):
     147          """Return a suite of all test cases given a string specifier.
     148  
     149          The name may resolve either to a module, a test case class, a
     150          test method within a test case class, or a callable object which
     151          returns a TestCase or TestSuite instance.
     152  
     153          The method optionally resolves the names relative to a given module.
     154          """
     155          parts = name.split('.')
     156          error_case, error_message = None, None
     157          if module is None:
     158              parts_copy = parts[:]
     159              while parts_copy:
     160                  try:
     161                      module_name = '.'.join(parts_copy)
     162                      module = __import__(module_name)
     163                      break
     164                  except ImportError:
     165                      next_attribute = parts_copy.pop()
     166                      # Last error so we can give it to the user if needed.
     167                      error_case, error_message = _make_failed_import_test(
     168                          next_attribute, self.suiteClass)
     169                      if not parts_copy:
     170                          # Even the top level import failed: report that error.
     171                          self.errors.append(error_message)
     172                          return error_case
     173              parts = parts[1:]
     174          obj = module
     175          for part in parts:
     176              try:
     177                  parent, obj = obj, getattr(obj, part)
     178              except AttributeError as e:
     179                  # We can't traverse some part of the name.
     180                  if (getattr(obj, '__path__', None) is not None
     181                      and error_case is not None):
     182                      # This is a package (no __path__ per importlib docs), and we
     183                      # encountered an error importing something. We cannot tell
     184                      # the difference between package.WrongNameTestClass and
     185                      # package.wrong_module_name so we just report the
     186                      # ImportError - it is more informative.
     187                      self.errors.append(error_message)
     188                      return error_case
     189                  else:
     190                      # Otherwise, we signal that an AttributeError has occurred.
     191                      error_case, error_message = _make_failed_test(
     192                          part, e, self.suiteClass,
     193                          'Failed to access attribute:\n%s' % (
     194                              traceback.format_exc(),))
     195                      self.errors.append(error_message)
     196                      return error_case
     197  
     198          if isinstance(obj, types.ModuleType):
     199              return self.loadTestsFromModule(obj)
     200          elif (
     201              isinstance(obj, type)
     202              and issubclass(obj, case.TestCase)
     203              and obj not in (case.TestCase, case.FunctionTestCase)
     204          ):
     205              return self.loadTestsFromTestCase(obj)
     206          elif (isinstance(obj, types.FunctionType) and
     207                isinstance(parent, type) and
     208                issubclass(parent, case.TestCase)):
     209              name = parts[-1]
     210              inst = parent(name)
     211              # static methods follow a different path
     212              if not isinstance(getattr(inst, name), types.FunctionType):
     213                  return self.suiteClass([inst])
     214          elif isinstance(obj, suite.TestSuite):
     215              return obj
     216          if callable(obj):
     217              test = obj()
     218              if isinstance(test, suite.TestSuite):
     219                  return test
     220              elif isinstance(test, case.TestCase):
     221                  return self.suiteClass([test])
     222              else:
     223                  raise TypeError("calling %s returned %s, not a test" %
     224                                  (obj, test))
     225          else:
     226              raise TypeError("don't know how to make test from: %s" % obj)
     227  
     228      def loadTestsFromNames(self, names, module=None):
     229          """Return a suite of all test cases found using the given sequence
     230          of string specifiers. See 'loadTestsFromName()'.
     231          """
     232          suites = [self.loadTestsFromName(name, module) for name in names]
     233          return self.suiteClass(suites)
     234  
     235      def getTestCaseNames(self, testCaseClass):
     236          """Return a sorted sequence of method names found within testCaseClass
     237          """
     238          def shouldIncludeMethod(attrname):
     239              if not attrname.startswith(self.testMethodPrefix):
     240                  return False
     241              testFunc = getattr(testCaseClass, attrname)
     242              if not callable(testFunc):
     243                  return False
     244              fullName = f'%s.%s.%s' % (
     245                  testCaseClass.__module__, testCaseClass.__qualname__, attrname
     246              )
     247              return self.testNamePatterns is None or \
     248                  any(fnmatchcase(fullName, pattern) for pattern in self.testNamePatterns)
     249          testFnNames = list(filter(shouldIncludeMethod, dir(testCaseClass)))
     250          if self.sortTestMethodsUsing:
     251              testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
     252          return testFnNames
     253  
     254      def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
     255          """Find and return all test modules from the specified start
     256          directory, recursing into subdirectories to find them and return all
     257          tests found within them. Only test files that match the pattern will
     258          be loaded. (Using shell style pattern matching.)
     259  
     260          All test modules must be importable from the top level of the project.
     261          If the start directory is not the top level directory then the top
     262          level directory must be specified separately.
     263  
     264          If a test package name (directory with '__init__.py') matches the
     265          pattern then the package will be checked for a 'load_tests' function. If
     266          this exists then it will be called with (loader, tests, pattern) unless
     267          the package has already had load_tests called from the same discovery
     268          invocation, in which case the package module object is not scanned for
     269          tests - this ensures that when a package uses discover to further
     270          discover child tests that infinite recursion does not happen.
     271  
     272          If load_tests exists then discovery does *not* recurse into the package,
     273          load_tests is responsible for loading all tests in the package.
     274  
     275          The pattern is deliberately not stored as a loader attribute so that
     276          packages can continue discovery themselves. top_level_dir is stored so
     277          load_tests does not need to pass this argument in to loader.discover().
     278  
     279          Paths are sorted before being imported to ensure reproducible execution
     280          order even on filesystems with non-alphabetical ordering like ext3/4.
     281          """
     282          set_implicit_top = False
     283          if top_level_dir is None and self._top_level_dir is not None:
     284              # make top_level_dir optional if called from load_tests in a package
     285              top_level_dir = self._top_level_dir
     286          elif top_level_dir is None:
     287              set_implicit_top = True
     288              top_level_dir = start_dir
     289  
     290          top_level_dir = os.path.abspath(top_level_dir)
     291  
     292          if not top_level_dir in sys.path:
     293              # all test modules must be importable from the top level directory
     294              # should we *unconditionally* put the start directory in first
     295              # in sys.path to minimise likelihood of conflicts between installed
     296              # modules and development versions?
     297              sys.path.insert(0, top_level_dir)
     298          self._top_level_dir = top_level_dir
     299  
     300          is_not_importable = False
     301          if os.path.isdir(os.path.abspath(start_dir)):
     302              start_dir = os.path.abspath(start_dir)
     303              if start_dir != top_level_dir:
     304                  is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
     305          else:
     306              # support for discovery from dotted module names
     307              try:
     308                  __import__(start_dir)
     309              except ImportError:
     310                  is_not_importable = True
     311              else:
     312                  the_module = sys.modules[start_dir]
     313                  top_part = start_dir.split('.')[0]
     314                  try:
     315                      start_dir = os.path.abspath(
     316                          os.path.dirname((the_module.__file__)))
     317                  except AttributeError:
     318                      if the_module.__name__ in sys.builtin_module_names:
     319                          # builtin module
     320                          raise TypeError('Can not use builtin modules '
     321                                          'as dotted module names') from None
     322                      else:
     323                          raise TypeError(
     324                              f"don't know how to discover from {the_module!r}"
     325                              ) from None
     326  
     327                  if set_implicit_top:
     328                      self._top_level_dir = self._get_directory_containing_module(top_part)
     329                      sys.path.remove(top_level_dir)
     330  
     331          if is_not_importable:
     332              raise ImportError('Start directory is not importable: %r' % start_dir)
     333  
     334          tests = list(self._find_tests(start_dir, pattern))
     335          return self.suiteClass(tests)
     336  
     337      def _get_directory_containing_module(self, module_name):
     338          module = sys.modules[module_name]
     339          full_path = os.path.abspath(module.__file__)
     340  
     341          if os.path.basename(full_path).lower().startswith('__init__.py'):
     342              return os.path.dirname(os.path.dirname(full_path))
     343          else:
     344              # here we have been given a module rather than a package - so
     345              # all we can do is search the *same* directory the module is in
     346              # should an exception be raised instead
     347              return os.path.dirname(full_path)
     348  
     349      def _get_name_from_path(self, path):
     350          if path == self._top_level_dir:
     351              return '.'
     352          path = _jython_aware_splitext(os.path.normpath(path))
     353  
     354          _relpath = os.path.relpath(path, self._top_level_dir)
     355          assert not os.path.isabs(_relpath), "Path must be within the project"
     356          assert not _relpath.startswith('..'), "Path must be within the project"
     357  
     358          name = _relpath.replace(os.path.sep, '.')
     359          return name
     360  
     361      def _get_module_from_name(self, name):
     362          __import__(name)
     363          return sys.modules[name]
     364  
     365      def _match_path(self, path, full_path, pattern):
     366          # override this method to use alternative matching strategy
     367          return fnmatch(path, pattern)
     368  
     369      def _find_tests(self, start_dir, pattern):
     370          """Used by discovery. Yields test suites it loads."""
     371          # Handle the __init__ in this package
     372          name = self._get_name_from_path(start_dir)
     373          # name is '.' when start_dir == top_level_dir (and top_level_dir is by
     374          # definition not a package).
     375          if name != '.' and name not in self._loading_packages:
     376              # name is in self._loading_packages while we have called into
     377              # loadTestsFromModule with name.
     378              tests, should_recurse = self._find_test_path(start_dir, pattern)
     379              if tests is not None:
     380                  yield tests
     381              if not should_recurse:
     382                  # Either an error occurred, or load_tests was used by the
     383                  # package.
     384                  return
     385          # Handle the contents.
     386          paths = sorted(os.listdir(start_dir))
     387          for path in paths:
     388              full_path = os.path.join(start_dir, path)
     389              tests, should_recurse = self._find_test_path(full_path, pattern)
     390              if tests is not None:
     391                  yield tests
     392              if should_recurse:
     393                  # we found a package that didn't use load_tests.
     394                  name = self._get_name_from_path(full_path)
     395                  self._loading_packages.add(name)
     396                  try:
     397                      yield from self._find_tests(full_path, pattern)
     398                  finally:
     399                      self._loading_packages.discard(name)
     400  
     401      def _find_test_path(self, full_path, pattern):
     402          """Used by discovery.
     403  
     404          Loads tests from a single file, or a directories' __init__.py when
     405          passed the directory.
     406  
     407          Returns a tuple (None_or_tests_from_file, should_recurse).
     408          """
     409          basename = os.path.basename(full_path)
     410          if os.path.isfile(full_path):
     411              if not VALID_MODULE_NAME.match(basename):
     412                  # valid Python identifiers only
     413                  return None, False
     414              if not self._match_path(basename, full_path, pattern):
     415                  return None, False
     416              # if the test file matches, load it
     417              name = self._get_name_from_path(full_path)
     418              try:
     419                  module = self._get_module_from_name(name)
     420              except case.SkipTest as e:
     421                  return _make_skipped_test(name, e, self.suiteClass), False
     422              except:
     423                  error_case, error_message = \
     424                      _make_failed_import_test(name, self.suiteClass)
     425                  self.errors.append(error_message)
     426                  return error_case, False
     427              else:
     428                  mod_file = os.path.abspath(
     429                      getattr(module, '__file__', full_path))
     430                  realpath = _jython_aware_splitext(
     431                      os.path.realpath(mod_file))
     432                  fullpath_noext = _jython_aware_splitext(
     433                      os.path.realpath(full_path))
     434                  if realpath.lower() != fullpath_noext.lower():
     435                      module_dir = os.path.dirname(realpath)
     436                      mod_name = _jython_aware_splitext(
     437                          os.path.basename(full_path))
     438                      expected_dir = os.path.dirname(full_path)
     439                      msg = ("%r module incorrectly imported from %r. Expected "
     440                             "%r. Is this module globally installed?")
     441                      raise ImportError(
     442                          msg % (mod_name, module_dir, expected_dir))
     443                  return self.loadTestsFromModule(module, pattern=pattern), False
     444          elif os.path.isdir(full_path):
     445              if not os.path.isfile(os.path.join(full_path, '__init__.py')):
     446                  return None, False
     447  
     448              load_tests = None
     449              tests = None
     450              name = self._get_name_from_path(full_path)
     451              try:
     452                  package = self._get_module_from_name(name)
     453              except case.SkipTest as e:
     454                  return _make_skipped_test(name, e, self.suiteClass), False
     455              except:
     456                  error_case, error_message = \
     457                      _make_failed_import_test(name, self.suiteClass)
     458                  self.errors.append(error_message)
     459                  return error_case, False
     460              else:
     461                  load_tests = getattr(package, 'load_tests', None)
     462                  # Mark this package as being in load_tests (possibly ;))
     463                  self._loading_packages.add(name)
     464                  try:
     465                      tests = self.loadTestsFromModule(package, pattern=pattern)
     466                      if load_tests is not None:
     467                          # loadTestsFromModule(package) has loaded tests for us.
     468                          return tests, False
     469                      return tests, True
     470                  finally:
     471                      self._loading_packages.discard(name)
     472          else:
     473              return None, False
     474  
     475  
     476  defaultTestLoader = TestLoader()
     477  
     478  
     479  # These functions are considered obsolete for long time.
     480  # They will be removed in Python 3.13.
     481  
     482  def _makeLoader(prefix, sortUsing, suiteClass=None, testNamePatterns=None):
     483      loader = TestLoader()
     484      loader.sortTestMethodsUsing = sortUsing
     485      loader.testMethodPrefix = prefix
     486      loader.testNamePatterns = testNamePatterns
     487      if suiteClass:
     488          loader.suiteClass = suiteClass
     489      return loader
     490  
     491  def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp, testNamePatterns=None):
     492      import warnings
     493      warnings.warn(
     494          "unittest.getTestCaseNames() is deprecated and will be removed in Python 3.13. "
     495          "Please use unittest.TestLoader.getTestCaseNames() instead.",
     496          DeprecationWarning, stacklevel=2
     497      )
     498      return _makeLoader(prefix, sortUsing, testNamePatterns=testNamePatterns).getTestCaseNames(testCaseClass)
     499  
     500  def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
     501                suiteClass=suite.TestSuite):
     502      import warnings
     503      warnings.warn(
     504          "unittest.makeSuite() is deprecated and will be removed in Python 3.13. "
     505          "Please use unittest.TestLoader.loadTestsFromTestCase() instead.",
     506          DeprecationWarning, stacklevel=2
     507      )
     508      return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(
     509          testCaseClass)
     510  
     511  def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
     512                    suiteClass=suite.TestSuite):
     513      import warnings
     514      warnings.warn(
     515          "unittest.findTestCases() is deprecated and will be removed in Python 3.13. "
     516          "Please use unittest.TestLoader.loadTestsFromModule() instead.",
     517          DeprecationWarning, stacklevel=2
     518      )
     519      return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(\
     520          module)