(root)/
Python-3.11.7/
Lib/
test/
test_decorators.py
       1  from test import support
       2  import unittest
       3  from types import MethodType
       4  
       5  def funcattrs(**kwds):
       6      def decorate(func):
       7          func.__dict__.update(kwds)
       8          return func
       9      return decorate
      10  
      11  class ESC[4;38;5;81mMiscDecorators (ESC[4;38;5;149mobject):
      12      @staticmethod
      13      def author(name):
      14          def decorate(func):
      15              func.__dict__['author'] = name
      16              return func
      17          return decorate
      18  
      19  # -----------------------------------------------
      20  
      21  class ESC[4;38;5;81mDbcheckError (ESC[4;38;5;149mException):
      22      def __init__(self, exprstr, func, args, kwds):
      23          # A real version of this would set attributes here
      24          Exception.__init__(self, "dbcheck %r failed (func=%s args=%s kwds=%s)" %
      25                             (exprstr, func, args, kwds))
      26  
      27  
      28  def dbcheck(exprstr, globals=None, locals=None):
      29      "Decorator to implement debugging assertions"
      30      def decorate(func):
      31          expr = compile(exprstr, "dbcheck-%s" % func.__name__, "eval")
      32          def check(*args, **kwds):
      33              if not eval(expr, globals, locals):
      34                  raise DbcheckError(exprstr, func, args, kwds)
      35              return func(*args, **kwds)
      36          return check
      37      return decorate
      38  
      39  # -----------------------------------------------
      40  
      41  def countcalls(counts):
      42      "Decorator to count calls to a function"
      43      def decorate(func):
      44          func_name = func.__name__
      45          counts[func_name] = 0
      46          def call(*args, **kwds):
      47              counts[func_name] += 1
      48              return func(*args, **kwds)
      49          call.__name__ = func_name
      50          return call
      51      return decorate
      52  
      53  # -----------------------------------------------
      54  
      55  def memoize(func):
      56      saved = {}
      57      def call(*args):
      58          try:
      59              return saved[args]
      60          except KeyError:
      61              res = func(*args)
      62              saved[args] = res
      63              return res
      64          except TypeError:
      65              # Unhashable argument
      66              return func(*args)
      67      call.__name__ = func.__name__
      68      return call
      69  
      70  # -----------------------------------------------
      71  
class TestDecorators(unittest.TestCase):
    """Tests for decorators applied to functions and methods.

    Covers the staticmethod/classmethod wrappers, decorator argument forms,
    the helper decorators defined at module level, which expressions the
    compiler accepts or rejects as decorators, and the order in which
    stacked decorators are evaluated and applied.
    """

    def test_single(self):
        # A single @staticmethod: callable through both the class and an
        # instance.
        class C(object):
            @staticmethod
            def foo(): return 42
        self.assertEqual(C.foo(), 42)
        self.assertEqual(C().foo(), 42)

    def check_wrapper_attrs(self, method_wrapper, format_str):
        # Helper, not a test.  It wraps a fresh function with
        # *method_wrapper* (staticmethod or classmethod) and asserts that:
        #   - the wrapper exposes the function via __func__ and __wrapped__;
        #   - it shares the listed metadata attributes by identity;
        #   - its repr matches *format_str* formatted with the function.
        # The wrapper is returned so callers can make further checks.
        def func(x):
            return x
        wrapper = method_wrapper(func)

        self.assertIs(wrapper.__func__, func)
        self.assertIs(wrapper.__wrapped__, func)

        for attr in ('__module__', '__qualname__', '__name__',
                     '__doc__', '__annotations__'):
            self.assertIs(getattr(wrapper, attr),
                          getattr(func, attr))

        self.assertEqual(repr(wrapper), format_str.format(func))
        return wrapper

    def test_staticmethod(self):
        wrapper = self.check_wrapper_attrs(staticmethod, '<staticmethod({!r})>')

        # bpo-43682: Static methods are callable since Python 3.10
        self.assertEqual(wrapper(1), 1)

    def test_classmethod(self):
        wrapper = self.check_wrapper_attrs(classmethod, '<classmethod({!r})>')

        # Unlike a staticmethod object, a bare classmethod object is not
        # callable.
        self.assertRaises(TypeError, wrapper, 1)

    def test_dotted(self):
        # A dotted name (attribute access on an object) used as the
        # decorator factory.
        decorators = MiscDecorators()
        @decorators.author('Cleese')
        def foo(): return 42
        self.assertEqual(foo(), 42)
        self.assertEqual(foo.author, 'Cleese')

    def test_argforms(self):
        # Tests of argument passing in a decorator-factory call.  These date
        # from when decorator expressions used a restricted grammar.  That
        # restriction no longer holds (test_expressions compiles arbitrary
        # expression forms), but the call-argument checks remain useful.

        def noteargs(*args, **kwds):
            # Decorator factory that records its call arguments on the
            # decorated function as f.dbval == (args, kwds).
            def decorate(func):
                setattr(func, 'dbval', (args, kwds))
                return func
            return decorate

        # Star-unpacked positional and keyword arguments.
        args = ( 'Now', 'is', 'the', 'time' )
        kwds = dict(one=1, two=2)
        @noteargs(*args, **kwds)
        def f1(): return 42
        self.assertEqual(f1(), 42)
        self.assertEqual(f1.dbval, (args, kwds))

        # Literal positional and keyword arguments.
        @noteargs('terry', 'gilliam', eric='idle', john='cleese')
        def f2(): return 84
        self.assertEqual(f2(), 84)
        self.assertEqual(f2.dbval, (('terry', 'gilliam'),
                                     dict(eric='idle', john='cleese')))

        # A trailing comma in the argument list is accepted.
        @noteargs(1, 2,)
        def f3(): pass
        self.assertEqual(f3.dbval, ((1, 2), {}))

    def test_dbcheck(self):
        # dbcheck evaluates its expression before every call.  Here the
        # expression reads the wrapper's positional-argument tuple, so a
        # None second argument triggers DbcheckError.
        @dbcheck('args[1] is not None')
        def f(a, b):
            return a + b
        self.assertEqual(f(1, 2), 3)
        self.assertRaises(DbcheckError, f, 1, None)

    def test_memoize(self):
        # memoize stacked on countcalls: the call count shows when the
        # underlying function actually runs.
        counts = {}

        @memoize
        @countcalls(counts)
        def double(x):
            return x * 2
        self.assertEqual(double.__name__, 'double')

        self.assertEqual(counts, dict(double=0))

        # Only the first call with a given argument bumps the call count:
        #
        self.assertEqual(double(2), 4)
        self.assertEqual(counts['double'], 1)
        self.assertEqual(double(2), 4)
        self.assertEqual(counts['double'], 1)
        self.assertEqual(double(3), 6)
        self.assertEqual(counts['double'], 2)

        # Unhashable arguments do not get memoized:
        #
        self.assertEqual(double([10]), [10, 10])
        self.assertEqual(counts['double'], 3)
        self.assertEqual(double([10]), [10, 10])
        self.assertEqual(counts['double'], 4)

    def test_errors(self):

        # Test SyntaxErrors:
        # each of these is valid on its own as a statement (the sanity
        # check), but is rejected by the compiler after '@'.
        for stmt in ("x,", "x, y", "x = y", "pass", "import sys"):
            compile(stmt, "test", "exec")  # Sanity check.
            with self.assertRaises(SyntaxError):
                compile(f"@{stmt}\ndef f(): pass", "test", "exec")

        # Test TypeErrors that used to be SyntaxErrors:
        # these expressions compile as decorators, but their values are not
        # callable, so applying them raises TypeError at run time.
        for expr in ("1.+2j", "[1, 2][-1]", "(1, 2)", "True", "...", "None"):
            compile(expr, "test", "eval")  # Sanity check.
            with self.assertRaises(TypeError):
                exec(f"@{expr}\ndef f(): pass")

        # A decorator that always fails when applied.
        def unimp(func):
            raise NotImplementedError
        context = dict(nullval=None, unimp=unimp)

        # Exceptions raised while evaluating the decorator expression
        # (NameError, AttributeError) or while applying it (TypeError for
        # calling None, NotImplementedError from unimp) propagate out of the
        # decorated definition.
        for expr, exc in [ ("undef", NameError),
                           ("nullval", TypeError),
                           ("nullval.attr", AttributeError),
                           ("unimp", NotImplementedError)]:
            codestr = "@%s\ndef f(): pass\nassert f() is None" % expr
            code = compile(codestr, "test", "exec")
            self.assertRaises(exc, eval, code, context)

    def test_expressions(self):
        # Each of these expression forms must be accepted by the compiler as
        # a decorator.  The code is only compiled, never executed, so the
        # names need not exist.
        for expr in (
            "(x,)", "(x, y)", "x := y", "(x := y)", "x @y", "(x @ y)", "x[0]",
            "w[x].y.z", "w + x - (y + z)", "x(y)()(z)", "[w, x, y][z]", "x.y",
        ):
            compile(f"@{expr}\ndef f(): pass", "test", "exec")

    def test_double(self):
        # Two stacked attribute-setting decorators: attributes from both end
        # up on the function, which is still callable as a method.
        class C(object):
            @funcattrs(abc=1, xyz="haha")
            @funcattrs(booh=42)
            def foo(self): return 42
        self.assertEqual(C().foo(), 42)
        self.assertEqual(C.foo.abc, 1)
        self.assertEqual(C.foo.xyz, "haha")
        self.assertEqual(C.foo.booh, 42)

    def test_order(self):
        # Test that decorators are applied in the proper order to the function
        # they are decorating.
        def callnum(num):
            """Decorator factory that returns a decorator that replaces the
            passed-in function with one that returns the value of 'num'"""
            def deco(func):
                return lambda: num
            return deco
        # The decorator closest to 'def' runs first, so callnum(2)'s
        # replacement is the final result.
        @callnum(2)
        @callnum(1)
        def foo(): return 42
        self.assertEqual(foo(), 2,
                            "Application order of decorators is incorrect")

    def test_eval_order(self):
        # Evaluating a decorated function involves four steps for each
        # decorator-maker (the function that returns a decorator):
        #
        #    1: Evaluate the decorator-maker name
        #    2: Evaluate the decorator-maker arguments (if any)
        #    3: Call the decorator-maker to make a decorator
        #    4: Call the decorator
        #
        # When there are multiple decorators, these steps should be
        # performed in the above order for each decorator, but we should
        # iterate through the decorators in the reverse of the order they
        # appear in the source.

        # Log of observed steps.  It is rebound to a fresh list before each
        # scenario below; make_decorator and NameLookupTracer read it through
        # the closure, so they always append to the current binding.
        actions = []

        def make_decorator(tag):
            # Step 3 is logged as 'makedec<tag>', step 4 as 'calldec<tag>'.
            actions.append('makedec' + tag)
            def decorate(func):
                actions.append('calldec' + tag)
                return func
            return decorate

        class NameLookupTracer (object):
            # Logs attribute lookups.  'make_decorator' is step 1 (logged as
            # 'evalname<index>') and 'arg' is step 2 (logged as
            # 'evalargs<index>', returning str(index) as the tag).
            def __init__(self, index):
                self.index = index

            def __getattr__(self, fname):
                if fname == 'make_decorator':
                    opname, res = ('evalname', make_decorator)
                elif fname == 'arg':
                    opname, res = ('evalargs', str(self.index))
                else:
                    assert False, "Unknown attrname %s" % fname
                actions.append('%s%d' % (opname, self.index))
                return res

        c1, c2, c3 = map(NameLookupTracer, [ 1, 2, 3 ])

        # Steps 1-3 run top to bottom; the decorators are then called
        # bottom to top.
        expected_actions = [ 'evalname1', 'evalargs1', 'makedec1',
                             'evalname2', 'evalargs2', 'makedec2',
                             'evalname3', 'evalargs3', 'makedec3',
                             'calldec3', 'calldec2', 'calldec1' ]

        actions = []
        @c1.make_decorator(c1.arg)
        @c2.make_decorator(c2.arg)
        @c3.make_decorator(c3.arg)
        def foo(): return 42
        self.assertEqual(foo(), 42)

        self.assertEqual(actions, expected_actions)

        # Test the equivalence claim in chapter 7 of the reference manual.
        #
        actions = []
        def bar(): return 42
        bar = c1.make_decorator(c1.arg)(c2.make_decorator(c2.arg)(c3.make_decorator(c3.arg)(bar)))
        self.assertEqual(bar(), 42)
        self.assertEqual(actions, expected_actions)

    def test_wrapped_descriptor_inside_classmethod(self):
        # Wrapper is not callable itself; it only binds through __get__,
        # returning a callable BoundWrapper.  It is stacked both outside
        # (inner) and inside (outer) a classmethod.
        class BoundWrapper:
            def __init__(self, wrapped):
                self.__wrapped__ = wrapped

            def __call__(self, *args, **kwargs):
                return self.__wrapped__(*args, **kwargs)

        class Wrapper:
            def __init__(self, wrapped):
                self.__wrapped__ = wrapped

            def __get__(self, instance, owner):
                bound_function = self.__wrapped__.__get__(instance, owner)
                return BoundWrapper(bound_function)

        def decorator(wrapped):
            return Wrapper(wrapped)

        class Class:
            @decorator
            @classmethod
            def inner(cls):
                # This should already work.
                return 'spam'

            @classmethod
            @decorator
            def outer(cls):
                # Raised TypeError with a message saying that the 'Wrapper'
                # object is not callable.
                return 'eggs'

        self.assertEqual(Class.inner(), 'spam')
        self.assertEqual(Class.outer(), 'eggs')
        self.assertEqual(Class().inner(), 'spam')
        self.assertEqual(Class().outer(), 'eggs')

    def test_bound_function_inside_classmethod(self):
        # classmethod applied to an already-bound method.  B.bar() is called
        # with no arguments, yet foo's remaining 'cls' parameter must still
        # be filled for it to return 'spam'.
        class A:
            def foo(self, cls):
                return 'spam'

        class B:
            bar = classmethod(A().foo)

        self.assertEqual(B.bar(), 'spam')

    def test_wrapped_classmethod_inside_classmethod(self):
        # Two user-defined classmethod look-alikes, each binding itself to
        # the owner class via MethodType:
        #   - MyClassMethod1 delegates to the wrapped object's __get__ when
        #     it has one;
        #   - MyClassMethod2 unwraps a built-in classmethod to its __func__.
        # Every stacking combination with the built-in classmethod must
        # still pass the owner class as 'cls'.
        class MyClassMethod1:
            def __init__(self, func):
                self.func = func

            def __call__(self, cls):
                if hasattr(self.func, '__get__'):
                    return self.func.__get__(cls, cls)()
                return self.func(cls)

            def __get__(self, instance, owner=None):
                if owner is None:
                    owner = type(instance)
                return MethodType(self, owner)

        class MyClassMethod2:
            def __init__(self, func):
                if isinstance(func, classmethod):
                    func = func.__func__
                self.func = func

            def __call__(self, cls):
                return self.func(cls)

            def __get__(self, instance, owner=None):
                if owner is None:
                    owner = type(instance)
                return MethodType(self, owner)

        for myclassmethod in [MyClassMethod1, MyClassMethod2]:
            class A:
                @myclassmethod
                def f1(cls):
                    return cls

                @classmethod
                @myclassmethod
                def f2(cls):
                    return cls

                @myclassmethod
                @classmethod
                def f3(cls):
                    return cls

                @classmethod
                @classmethod
                def f4(cls):
                    return cls

                @myclassmethod
                @MyClassMethod1
                def f5(cls):
                    return cls

                @myclassmethod
                @MyClassMethod2
                def f6(cls):
                    return cls

            self.assertIs(A.f1(), A)
            self.assertIs(A.f2(), A)
            self.assertIs(A.f3(), A)
            self.assertIs(A.f4(), A)
            self.assertIs(A.f5(), A)
            self.assertIs(A.f6(), A)
            a = A()
            self.assertIs(a.f1(), A)
            self.assertIs(a.f2(), A)
            self.assertIs(a.f3(), A)
            self.assertIs(a.f4(), A)
            self.assertIs(a.f5(), A)
            self.assertIs(a.f6(), A)

            def f(cls):
                return cls

            # Direct __get__ calls.  With owner omitted, the owner is taken
            # from type(instance), so binding on the class A itself yields
            # type(A).
            self.assertIs(myclassmethod(f).__get__(a)(), A)
            self.assertIs(myclassmethod(f).__get__(a, A)(), A)
            self.assertIs(myclassmethod(f).__get__(A, A)(), A)
            self.assertIs(myclassmethod(f).__get__(A)(), type(A))
            self.assertIs(classmethod(f).__get__(a)(), A)
            self.assertIs(classmethod(f).__get__(a, A)(), A)
            self.assertIs(classmethod(f).__get__(A, A)(), A)
            self.assertIs(classmethod(f).__get__(A)(), type(A))
     429  class ESC[4;38;5;81mTestClassDecorators(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     430  
     431      def test_simple(self):
     432          def plain(x):
     433              x.extra = 'Hello'
     434              return x
     435          @plain
     436          class ESC[4;38;5;81mC(ESC[4;38;5;149mobject): pass
     437          self.assertEqual(C.extra, 'Hello')
     438  
     439      def test_double(self):
     440          def ten(x):
     441              x.extra = 10
     442              return x
     443          def add_five(x):
     444              x.extra += 5
     445              return x
     446  
     447          @add_five
     448          @ten
     449          class ESC[4;38;5;81mC(ESC[4;38;5;149mobject): pass
     450          self.assertEqual(C.extra, 15)
     451  
     452      def test_order(self):
     453          def applied_first(x):
     454              x.extra = 'first'
     455              return x
     456          def applied_second(x):
     457              x.extra = 'second'
     458              return x
     459          @applied_second
     460          @applied_first
     461          class ESC[4;38;5;81mC(ESC[4;38;5;149mobject): pass
     462          self.assertEqual(C.extra, 'second')
     463  
if __name__ == "__main__":
    # Run this module's test cases when the file is executed as a script.
    unittest.main()