(root)/
Python-3.12.0/
Lib/
unittest/
loader.py
       1  """Loading unittests."""
       2  
       3  import os
       4  import re
       5  import sys
       6  import traceback
       7  import types
       8  import functools
       9  
      10  from fnmatch import fnmatch, fnmatchcase
      11  
      12  from . import case, suite, util
      13  
# NOTE(review): presumably a marker that other parts of unittest look for in
# a frame's globals (e.g. to recognise framework-internal frames) - confirm
# before relying on it.
__unittest = True

# Basenames that discovery accepts as candidate test modules: a name starting
# with "_" or a letter a-z (matched case-insensitively), followed by word
# characters, and ending in ".py".
# Only ".py" files are matched.  Also accepting ".pyc" (etc.) would require
# avoiding loading the same tests twice, once from the ".py" file *and* once
# from the ".pyc" file.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
      20  
      21  
      22  class ESC[4;38;5;81m_FailedTest(ESC[4;38;5;149mcaseESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      23      _testMethodName = None
      24  
      25      def __init__(self, method_name, exception):
      26          self._exception = exception
      27          super(_FailedTest, self).__init__(method_name)
      28  
      29      def __getattr__(self, name):
      30          if name != self._testMethodName:
      31              return super(_FailedTest, self).__getattr__(name)
      32          def testFailure():
      33              raise self._exception
      34          return testFailure
      35  
      36  
      37  def _make_failed_import_test(name, suiteClass):
      38      message = 'Failed to import test module: %s\n%s' % (
      39          name, traceback.format_exc())
      40      return _make_failed_test(name, ImportError(message), suiteClass, message)
      41  
      42  def _make_failed_load_tests(name, exception, suiteClass):
      43      message = 'Failed to call load_tests:\n%s' % (traceback.format_exc(),)
      44      return _make_failed_test(
      45          name, exception, suiteClass, message)
      46  
      47  def _make_failed_test(methodname, exception, suiteClass, message):
      48      test = _FailedTest(methodname, exception)
      49      return suiteClass((test,)), message
      50  
      51  def _make_skipped_test(methodname, exception, suiteClass):
      52      @case.skip(str(exception))
      53      def testSkipped(self):
      54          pass
      55      attrs = {methodname: testSkipped}
      56      TestClass = type("ModuleSkipped", (case.TestCase,), attrs)
      57      return suiteClass((TestClass(methodname),))
      58  
      59  def _splitext(path):
      60      return os.path.splitext(path)[0]
      61  
      62  
      63  class ESC[4;38;5;81mTestLoader(ESC[4;38;5;149mobject):
      64      """
      65      This class is responsible for loading tests according to various criteria
      66      and returning them wrapped in a TestSuite
      67      """
      68      testMethodPrefix = 'test'
      69      sortTestMethodsUsing = staticmethod(util.three_way_cmp)
      70      testNamePatterns = None
      71      suiteClass = suite.TestSuite
      72      _top_level_dir = None
      73  
      74      def __init__(self):
      75          super(TestLoader, self).__init__()
      76          self.errors = []
      77          # Tracks packages which we have called into via load_tests, to
      78          # avoid infinite re-entrancy.
      79          self._loading_packages = set()
      80  
      81      def loadTestsFromTestCase(self, testCaseClass):
      82          """Return a suite of all test cases contained in testCaseClass"""
      83          if issubclass(testCaseClass, suite.TestSuite):
      84              raise TypeError("Test cases should not be derived from "
      85                              "TestSuite. Maybe you meant to derive from "
      86                              "TestCase?")
      87          testCaseNames = self.getTestCaseNames(testCaseClass)
      88          if not testCaseNames and hasattr(testCaseClass, 'runTest'):
      89              testCaseNames = ['runTest']
      90          loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
      91          return loaded_suite
      92  
      93      def loadTestsFromModule(self, module, *, pattern=None):
      94          """Return a suite of all test cases contained in the given module"""
      95          tests = []
      96          for name in dir(module):
      97              obj = getattr(module, name)
      98              if isinstance(obj, type) and issubclass(obj, case.TestCase):
      99                  tests.append(self.loadTestsFromTestCase(obj))
     100  
     101          load_tests = getattr(module, 'load_tests', None)
     102          tests = self.suiteClass(tests)
     103          if load_tests is not None:
     104              try:
     105                  return load_tests(self, tests, pattern)
     106              except Exception as e:
     107                  error_case, error_message = _make_failed_load_tests(
     108                      module.__name__, e, self.suiteClass)
     109                  self.errors.append(error_message)
     110                  return error_case
     111          return tests
     112  
     113      def loadTestsFromName(self, name, module=None):
     114          """Return a suite of all test cases given a string specifier.
     115  
     116          The name may resolve either to a module, a test case class, a
     117          test method within a test case class, or a callable object which
     118          returns a TestCase or TestSuite instance.
     119  
     120          The method optionally resolves the names relative to a given module.
     121          """
     122          parts = name.split('.')
     123          error_case, error_message = None, None
     124          if module is None:
     125              parts_copy = parts[:]
     126              while parts_copy:
     127                  try:
     128                      module_name = '.'.join(parts_copy)
     129                      module = __import__(module_name)
     130                      break
     131                  except ImportError:
     132                      next_attribute = parts_copy.pop()
     133                      # Last error so we can give it to the user if needed.
     134                      error_case, error_message = _make_failed_import_test(
     135                          next_attribute, self.suiteClass)
     136                      if not parts_copy:
     137                          # Even the top level import failed: report that error.
     138                          self.errors.append(error_message)
     139                          return error_case
     140              parts = parts[1:]
     141          obj = module
     142          for part in parts:
     143              try:
     144                  parent, obj = obj, getattr(obj, part)
     145              except AttributeError as e:
     146                  # We can't traverse some part of the name.
     147                  if (getattr(obj, '__path__', None) is not None
     148                      and error_case is not None):
     149                      # This is a package (no __path__ per importlib docs), and we
     150                      # encountered an error importing something. We cannot tell
     151                      # the difference between package.WrongNameTestClass and
     152                      # package.wrong_module_name so we just report the
     153                      # ImportError - it is more informative.
     154                      self.errors.append(error_message)
     155                      return error_case
     156                  else:
     157                      # Otherwise, we signal that an AttributeError has occurred.
     158                      error_case, error_message = _make_failed_test(
     159                          part, e, self.suiteClass,
     160                          'Failed to access attribute:\n%s' % (
     161                              traceback.format_exc(),))
     162                      self.errors.append(error_message)
     163                      return error_case
     164  
     165          if isinstance(obj, types.ModuleType):
     166              return self.loadTestsFromModule(obj)
     167          elif isinstance(obj, type) and issubclass(obj, case.TestCase):
     168              return self.loadTestsFromTestCase(obj)
     169          elif (isinstance(obj, types.FunctionType) and
     170                isinstance(parent, type) and
     171                issubclass(parent, case.TestCase)):
     172              name = parts[-1]
     173              inst = parent(name)
     174              # static methods follow a different path
     175              if not isinstance(getattr(inst, name), types.FunctionType):
     176                  return self.suiteClass([inst])
     177          elif isinstance(obj, suite.TestSuite):
     178              return obj
     179          if callable(obj):
     180              test = obj()
     181              if isinstance(test, suite.TestSuite):
     182                  return test
     183              elif isinstance(test, case.TestCase):
     184                  return self.suiteClass([test])
     185              else:
     186                  raise TypeError("calling %s returned %s, not a test" %
     187                                  (obj, test))
     188          else:
     189              raise TypeError("don't know how to make test from: %s" % obj)
     190  
     191      def loadTestsFromNames(self, names, module=None):
     192          """Return a suite of all test cases found using the given sequence
     193          of string specifiers. See 'loadTestsFromName()'.
     194          """
     195          suites = [self.loadTestsFromName(name, module) for name in names]
     196          return self.suiteClass(suites)
     197  
     198      def getTestCaseNames(self, testCaseClass):
     199          """Return a sorted sequence of method names found within testCaseClass
     200          """
     201          def shouldIncludeMethod(attrname):
     202              if not attrname.startswith(self.testMethodPrefix):
     203                  return False
     204              testFunc = getattr(testCaseClass, attrname)
     205              if not callable(testFunc):
     206                  return False
     207              fullName = f'%s.%s.%s' % (
     208                  testCaseClass.__module__, testCaseClass.__qualname__, attrname
     209              )
     210              return self.testNamePatterns is None or \
     211                  any(fnmatchcase(fullName, pattern) for pattern in self.testNamePatterns)
     212          testFnNames = list(filter(shouldIncludeMethod, dir(testCaseClass)))
     213          if self.sortTestMethodsUsing:
     214              testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
     215          return testFnNames
     216  
     217      def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
     218          """Find and return all test modules from the specified start
     219          directory, recursing into subdirectories to find them and return all
     220          tests found within them. Only test files that match the pattern will
     221          be loaded. (Using shell style pattern matching.)
     222  
     223          All test modules must be importable from the top level of the project.
     224          If the start directory is not the top level directory then the top
     225          level directory must be specified separately.
     226  
     227          If a test package name (directory with '__init__.py') matches the
     228          pattern then the package will be checked for a 'load_tests' function. If
     229          this exists then it will be called with (loader, tests, pattern) unless
     230          the package has already had load_tests called from the same discovery
     231          invocation, in which case the package module object is not scanned for
     232          tests - this ensures that when a package uses discover to further
     233          discover child tests that infinite recursion does not happen.
     234  
     235          If load_tests exists then discovery does *not* recurse into the package,
     236          load_tests is responsible for loading all tests in the package.
     237  
     238          The pattern is deliberately not stored as a loader attribute so that
     239          packages can continue discovery themselves. top_level_dir is stored so
     240          load_tests does not need to pass this argument in to loader.discover().
     241  
     242          Paths are sorted before being imported to ensure reproducible execution
     243          order even on filesystems with non-alphabetical ordering like ext3/4.
     244          """
     245          set_implicit_top = False
     246          if top_level_dir is None and self._top_level_dir is not None:
     247              # make top_level_dir optional if called from load_tests in a package
     248              top_level_dir = self._top_level_dir
     249          elif top_level_dir is None:
     250              set_implicit_top = True
     251              top_level_dir = start_dir
     252  
     253          top_level_dir = os.path.abspath(top_level_dir)
     254  
     255          if not top_level_dir in sys.path:
     256              # all test modules must be importable from the top level directory
     257              # should we *unconditionally* put the start directory in first
     258              # in sys.path to minimise likelihood of conflicts between installed
     259              # modules and development versions?
     260              sys.path.insert(0, top_level_dir)
     261          self._top_level_dir = top_level_dir
     262  
     263          is_not_importable = False
     264          if os.path.isdir(os.path.abspath(start_dir)):
     265              start_dir = os.path.abspath(start_dir)
     266              if start_dir != top_level_dir:
     267                  is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
     268          else:
     269              # support for discovery from dotted module names
     270              try:
     271                  __import__(start_dir)
     272              except ImportError:
     273                  is_not_importable = True
     274              else:
     275                  the_module = sys.modules[start_dir]
     276                  top_part = start_dir.split('.')[0]
     277                  try:
     278                      start_dir = os.path.abspath(
     279                          os.path.dirname((the_module.__file__)))
     280                  except AttributeError:
     281                      if the_module.__name__ in sys.builtin_module_names:
     282                          # builtin module
     283                          raise TypeError('Can not use builtin modules '
     284                                          'as dotted module names') from None
     285                      else:
     286                          raise TypeError(
     287                              f"don't know how to discover from {the_module!r}"
     288                              ) from None
     289  
     290                  if set_implicit_top:
     291                      self._top_level_dir = self._get_directory_containing_module(top_part)
     292                      sys.path.remove(top_level_dir)
     293  
     294          if is_not_importable:
     295              raise ImportError('Start directory is not importable: %r' % start_dir)
     296  
     297          tests = list(self._find_tests(start_dir, pattern))
     298          return self.suiteClass(tests)
     299  
     300      def _get_directory_containing_module(self, module_name):
     301          module = sys.modules[module_name]
     302          full_path = os.path.abspath(module.__file__)
     303  
     304          if os.path.basename(full_path).lower().startswith('__init__.py'):
     305              return os.path.dirname(os.path.dirname(full_path))
     306          else:
     307              # here we have been given a module rather than a package - so
     308              # all we can do is search the *same* directory the module is in
     309              # should an exception be raised instead
     310              return os.path.dirname(full_path)
     311  
     312      def _get_name_from_path(self, path):
     313          if path == self._top_level_dir:
     314              return '.'
     315          path = _splitext(os.path.normpath(path))
     316  
     317          _relpath = os.path.relpath(path, self._top_level_dir)
     318          assert not os.path.isabs(_relpath), "Path must be within the project"
     319          assert not _relpath.startswith('..'), "Path must be within the project"
     320  
     321          name = _relpath.replace(os.path.sep, '.')
     322          return name
     323  
     324      def _get_module_from_name(self, name):
     325          __import__(name)
     326          return sys.modules[name]
     327  
     328      def _match_path(self, path, full_path, pattern):
     329          # override this method to use alternative matching strategy
     330          return fnmatch(path, pattern)
     331  
     332      def _find_tests(self, start_dir, pattern):
     333          """Used by discovery. Yields test suites it loads."""
     334          # Handle the __init__ in this package
     335          name = self._get_name_from_path(start_dir)
     336          # name is '.' when start_dir == top_level_dir (and top_level_dir is by
     337          # definition not a package).
     338          if name != '.' and name not in self._loading_packages:
     339              # name is in self._loading_packages while we have called into
     340              # loadTestsFromModule with name.
     341              tests, should_recurse = self._find_test_path(start_dir, pattern)
     342              if tests is not None:
     343                  yield tests
     344              if not should_recurse:
     345                  # Either an error occurred, or load_tests was used by the
     346                  # package.
     347                  return
     348          # Handle the contents.
     349          paths = sorted(os.listdir(start_dir))
     350          for path in paths:
     351              full_path = os.path.join(start_dir, path)
     352              tests, should_recurse = self._find_test_path(full_path, pattern)
     353              if tests is not None:
     354                  yield tests
     355              if should_recurse:
     356                  # we found a package that didn't use load_tests.
     357                  name = self._get_name_from_path(full_path)
     358                  self._loading_packages.add(name)
     359                  try:
     360                      yield from self._find_tests(full_path, pattern)
     361                  finally:
     362                      self._loading_packages.discard(name)
     363  
     364      def _find_test_path(self, full_path, pattern):
     365          """Used by discovery.
     366  
     367          Loads tests from a single file, or a directories' __init__.py when
     368          passed the directory.
     369  
     370          Returns a tuple (None_or_tests_from_file, should_recurse).
     371          """
     372          basename = os.path.basename(full_path)
     373          if os.path.isfile(full_path):
     374              if not VALID_MODULE_NAME.match(basename):
     375                  # valid Python identifiers only
     376                  return None, False
     377              if not self._match_path(basename, full_path, pattern):
     378                  return None, False
     379              # if the test file matches, load it
     380              name = self._get_name_from_path(full_path)
     381              try:
     382                  module = self._get_module_from_name(name)
     383              except case.SkipTest as e:
     384                  return _make_skipped_test(name, e, self.suiteClass), False
     385              except:
     386                  error_case, error_message = \
     387                      _make_failed_import_test(name, self.suiteClass)
     388                  self.errors.append(error_message)
     389                  return error_case, False
     390              else:
     391                  mod_file = os.path.abspath(
     392                      getattr(module, '__file__', full_path))
     393                  realpath = _splitext(
     394                      os.path.realpath(mod_file))
     395                  fullpath_noext = _splitext(
     396                      os.path.realpath(full_path))
     397                  if realpath.lower() != fullpath_noext.lower():
     398                      module_dir = os.path.dirname(realpath)
     399                      mod_name = _splitext(
     400                          os.path.basename(full_path))
     401                      expected_dir = os.path.dirname(full_path)
     402                      msg = ("%r module incorrectly imported from %r. Expected "
     403                             "%r. Is this module globally installed?")
     404                      raise ImportError(
     405                          msg % (mod_name, module_dir, expected_dir))
     406                  return self.loadTestsFromModule(module, pattern=pattern), False
     407          elif os.path.isdir(full_path):
     408              if not os.path.isfile(os.path.join(full_path, '__init__.py')):
     409                  return None, False
     410  
     411              load_tests = None
     412              tests = None
     413              name = self._get_name_from_path(full_path)
     414              try:
     415                  package = self._get_module_from_name(name)
     416              except case.SkipTest as e:
     417                  return _make_skipped_test(name, e, self.suiteClass), False
     418              except:
     419                  error_case, error_message = \
     420                      _make_failed_import_test(name, self.suiteClass)
     421                  self.errors.append(error_message)
     422                  return error_case, False
     423              else:
     424                  load_tests = getattr(package, 'load_tests', None)
     425                  # Mark this package as being in load_tests (possibly ;))
     426                  self._loading_packages.add(name)
     427                  try:
     428                      tests = self.loadTestsFromModule(package, pattern=pattern)
     429                      if load_tests is not None:
     430                          # loadTestsFromModule(package) has loaded tests for us.
     431                          return tests, False
     432                      return tests, True
     433                  finally:
     434                      self._loading_packages.discard(name)
     435          else:
     436              return None, False
     437  
     438  
# Ready-made, module-level TestLoader instance created with default settings.
defaultTestLoader = TestLoader()
     440  
     441  
# These functions have been considered obsolete for a long time.
# They will be removed in Python 3.13.
     444  
     445  def _makeLoader(prefix, sortUsing, suiteClass=None, testNamePatterns=None):
     446      loader = TestLoader()
     447      loader.sortTestMethodsUsing = sortUsing
     448      loader.testMethodPrefix = prefix
     449      loader.testNamePatterns = testNamePatterns
     450      if suiteClass:
     451          loader.suiteClass = suiteClass
     452      return loader
     453  
     454  def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp, testNamePatterns=None):
     455      import warnings
     456      warnings.warn(
     457          "unittest.getTestCaseNames() is deprecated and will be removed in Python 3.13. "
     458          "Please use unittest.TestLoader.getTestCaseNames() instead.",
     459          DeprecationWarning, stacklevel=2
     460      )
     461      return _makeLoader(prefix, sortUsing, testNamePatterns=testNamePatterns).getTestCaseNames(testCaseClass)
     462  
     463  def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
     464                suiteClass=suite.TestSuite):
     465      import warnings
     466      warnings.warn(
     467          "unittest.makeSuite() is deprecated and will be removed in Python 3.13. "
     468          "Please use unittest.TestLoader.loadTestsFromTestCase() instead.",
     469          DeprecationWarning, stacklevel=2
     470      )
     471      return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(
     472          testCaseClass)
     473  
     474  def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
     475                    suiteClass=suite.TestSuite):
     476      import warnings
     477      warnings.warn(
     478          "unittest.findTestCases() is deprecated and will be removed in Python 3.13. "
     479          "Please use unittest.TestLoader.loadTestsFromModule() instead.",
     480          DeprecationWarning, stacklevel=2
     481      )
     482      return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(\
     483          module)