(root)/
Python-3.12.0/
Lib/
test/
test_decorators.py
       1  import unittest
       2  from types import MethodType
       3  
       4  def funcattrs(**kwds):
       5      def decorate(func):
       6          func.__dict__.update(kwds)
       7          return func
       8      return decorate
       9  
      10  class ESC[4;38;5;81mMiscDecorators (ESC[4;38;5;149mobject):
      11      @staticmethod
      12      def author(name):
      13          def decorate(func):
      14              func.__dict__['author'] = name
      15              return func
      16          return decorate
      17  
      18  # -----------------------------------------------
      19  
      20  class ESC[4;38;5;81mDbcheckError (ESC[4;38;5;149mException):
      21      def __init__(self, exprstr, func, args, kwds):
      22          # A real version of this would set attributes here
      23          Exception.__init__(self, "dbcheck %r failed (func=%s args=%s kwds=%s)" %
      24                             (exprstr, func, args, kwds))
      25  
      26  
      27  def dbcheck(exprstr, globals=None, locals=None):
      28      "Decorator to implement debugging assertions"
      29      def decorate(func):
      30          expr = compile(exprstr, "dbcheck-%s" % func.__name__, "eval")
      31          def check(*args, **kwds):
      32              if not eval(expr, globals, locals):
      33                  raise DbcheckError(exprstr, func, args, kwds)
      34              return func(*args, **kwds)
      35          return check
      36      return decorate
      37  
      38  # -----------------------------------------------
      39  
      40  def countcalls(counts):
      41      "Decorator to count calls to a function"
      42      def decorate(func):
      43          func_name = func.__name__
      44          counts[func_name] = 0
      45          def call(*args, **kwds):
      46              counts[func_name] += 1
      47              return func(*args, **kwds)
      48          call.__name__ = func_name
      49          return call
      50      return decorate
      51  
      52  # -----------------------------------------------
      53  
      54  def memoize(func):
      55      saved = {}
      56      def call(*args):
      57          try:
      58              return saved[args]
      59          except KeyError:
      60              res = func(*args)
      61              saved[args] = res
      62              return res
      63          except TypeError:
      64              # Unhashable argument
      65              return func(*args)
      66      call.__name__ = func.__name__
      67      return call
      68  
      69  # -----------------------------------------------
      70  
      71  class ESC[4;38;5;81mTestDecorators(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      72  
      73      def test_single(self):
      74          class ESC[4;38;5;81mC(ESC[4;38;5;149mobject):
      75              @staticmethod
      76              def foo(): return 42
      77          self.assertEqual(C.foo(), 42)
      78          self.assertEqual(C().foo(), 42)
      79  
      80      def check_wrapper_attrs(self, method_wrapper, format_str):
      81          def func(x):
      82              return x
      83          wrapper = method_wrapper(func)
      84  
      85          self.assertIs(wrapper.__func__, func)
      86          self.assertIs(wrapper.__wrapped__, func)
      87  
      88          for attr in ('__module__', '__qualname__', '__name__',
      89                       '__doc__', '__annotations__'):
      90              self.assertIs(getattr(wrapper, attr),
      91                            getattr(func, attr))
      92  
      93          self.assertEqual(repr(wrapper), format_str.format(func))
      94          return wrapper
      95  
      96      def test_staticmethod(self):
      97          wrapper = self.check_wrapper_attrs(staticmethod, '<staticmethod({!r})>')
      98  
      99          # bpo-43682: Static methods are callable since Python 3.10
     100          self.assertEqual(wrapper(1), 1)
     101  
     102      def test_classmethod(self):
     103          wrapper = self.check_wrapper_attrs(classmethod, '<classmethod({!r})>')
     104  
     105          self.assertRaises(TypeError, wrapper, 1)
     106  
     107      def test_dotted(self):
     108          decorators = MiscDecorators()
     109          @decorators.author('Cleese')
     110          def foo(): return 42
     111          self.assertEqual(foo(), 42)
     112          self.assertEqual(foo.author, 'Cleese')
     113  
     114      def test_argforms(self):
     115          # A few tests of argument passing, as we use restricted form
     116          # of expressions for decorators.
     117  
     118          def noteargs(*args, **kwds):
     119              def decorate(func):
     120                  setattr(func, 'dbval', (args, kwds))
     121                  return func
     122              return decorate
     123  
     124          args = ( 'Now', 'is', 'the', 'time' )
     125          kwds = dict(one=1, two=2)
     126          @noteargs(*args, **kwds)
     127          def f1(): return 42
     128          self.assertEqual(f1(), 42)
     129          self.assertEqual(f1.dbval, (args, kwds))
     130  
     131          @noteargs('terry', 'gilliam', eric='idle', john='cleese')
     132          def f2(): return 84
     133          self.assertEqual(f2(), 84)
     134          self.assertEqual(f2.dbval, (('terry', 'gilliam'),
     135                                       dict(eric='idle', john='cleese')))
     136  
     137          @noteargs(1, 2,)
     138          def f3(): pass
     139          self.assertEqual(f3.dbval, ((1, 2), {}))
     140  
     141      def test_dbcheck(self):
     142          @dbcheck('args[1] is not None')
     143          def f(a, b):
     144              return a + b
     145          self.assertEqual(f(1, 2), 3)
     146          self.assertRaises(DbcheckError, f, 1, None)
     147  
     148      def test_memoize(self):
     149          counts = {}
     150  
     151          @memoize
     152          @countcalls(counts)
     153          def double(x):
     154              return x * 2
     155          self.assertEqual(double.__name__, 'double')
     156  
     157          self.assertEqual(counts, dict(double=0))
     158  
     159          # Only the first call with a given argument bumps the call count:
     160          #
     161          self.assertEqual(double(2), 4)
     162          self.assertEqual(counts['double'], 1)
     163          self.assertEqual(double(2), 4)
     164          self.assertEqual(counts['double'], 1)
     165          self.assertEqual(double(3), 6)
     166          self.assertEqual(counts['double'], 2)
     167  
     168          # Unhashable arguments do not get memoized:
     169          #
     170          self.assertEqual(double([10]), [10, 10])
     171          self.assertEqual(counts['double'], 3)
     172          self.assertEqual(double([10]), [10, 10])
     173          self.assertEqual(counts['double'], 4)
     174  
     175      def test_errors(self):
     176  
     177          # Test SyntaxErrors:
     178          for stmt in ("x,", "x, y", "x = y", "pass", "import sys"):
     179              compile(stmt, "test", "exec")  # Sanity check.
     180              with self.assertRaises(SyntaxError):
     181                  compile(f"@{stmt}\ndef f(): pass", "test", "exec")
     182  
     183          # Test TypeErrors that used to be SyntaxErrors:
     184          for expr in ("1.+2j", "[1, 2][-1]", "(1, 2)", "True", "...", "None"):
     185              compile(expr, "test", "eval")  # Sanity check.
     186              with self.assertRaises(TypeError):
     187                  exec(f"@{expr}\ndef f(): pass")
     188  
     189          def unimp(func):
     190              raise NotImplementedError
     191          context = dict(nullval=None, unimp=unimp)
     192  
     193          for expr, exc in [ ("undef", NameError),
     194                             ("nullval", TypeError),
     195                             ("nullval.attr", AttributeError),
     196                             ("unimp", NotImplementedError)]:
     197              codestr = "@%s\ndef f(): pass\nassert f() is None" % expr
     198              code = compile(codestr, "test", "exec")
     199              self.assertRaises(exc, eval, code, context)
     200  
     201      def test_expressions(self):
     202          for expr in (
     203              "(x,)", "(x, y)", "x := y", "(x := y)", "x @y", "(x @ y)", "x[0]",
     204              "w[x].y.z", "w + x - (y + z)", "x(y)()(z)", "[w, x, y][z]", "x.y",
     205          ):
     206              compile(f"@{expr}\ndef f(): pass", "test", "exec")
     207  
     208      def test_double(self):
     209          class ESC[4;38;5;81mC(ESC[4;38;5;149mobject):
     210              @funcattrs(abc=1, xyz="haha")
     211              @funcattrs(booh=42)
     212              def foo(self): return 42
     213          self.assertEqual(C().foo(), 42)
     214          self.assertEqual(C.foo.abc, 1)
     215          self.assertEqual(C.foo.xyz, "haha")
     216          self.assertEqual(C.foo.booh, 42)
     217  
     218      def test_order(self):
     219          # Test that decorators are applied in the proper order to the function
     220          # they are decorating.
     221          def callnum(num):
     222              """Decorator factory that returns a decorator that replaces the
     223              passed-in function with one that returns the value of 'num'"""
     224              def deco(func):
     225                  return lambda: num
     226              return deco
     227          @callnum(2)
     228          @callnum(1)
     229          def foo(): return 42
     230          self.assertEqual(foo(), 2,
     231                              "Application order of decorators is incorrect")
     232  
     233      def test_eval_order(self):
     234          # Evaluating a decorated function involves four steps for each
     235          # decorator-maker (the function that returns a decorator):
     236          #
     237          #    1: Evaluate the decorator-maker name
     238          #    2: Evaluate the decorator-maker arguments (if any)
     239          #    3: Call the decorator-maker to make a decorator
     240          #    4: Call the decorator
     241          #
     242          # When there are multiple decorators, these steps should be
     243          # performed in the above order for each decorator, but we should
     244          # iterate through the decorators in the reverse of the order they
     245          # appear in the source.
     246  
     247          actions = []
     248  
     249          def make_decorator(tag):
     250              actions.append('makedec' + tag)
     251              def decorate(func):
     252                  actions.append('calldec' + tag)
     253                  return func
     254              return decorate
     255  
     256          class ESC[4;38;5;81mNameLookupTracer (ESC[4;38;5;149mobject):
     257              def __init__(self, index):
     258                  self.index = index
     259  
     260              def __getattr__(self, fname):
     261                  if fname == 'make_decorator':
     262                      opname, res = ('evalname', make_decorator)
     263                  elif fname == 'arg':
     264                      opname, res = ('evalargs', str(self.index))
     265                  else:
     266                      assert False, "Unknown attrname %s" % fname
     267                  actions.append('%s%d' % (opname, self.index))
     268                  return res
     269  
     270          c1, c2, c3 = map(NameLookupTracer, [ 1, 2, 3 ])
     271  
     272          expected_actions = [ 'evalname1', 'evalargs1', 'makedec1',
     273                               'evalname2', 'evalargs2', 'makedec2',
     274                               'evalname3', 'evalargs3', 'makedec3',
     275                               'calldec3', 'calldec2', 'calldec1' ]
     276  
     277          actions = []
     278          @c1.make_decorator(c1.arg)
     279          @c2.make_decorator(c2.arg)
     280          @c3.make_decorator(c3.arg)
     281          def foo(): return 42
     282          self.assertEqual(foo(), 42)
     283  
     284          self.assertEqual(actions, expected_actions)
     285  
     286          # Test the equivalence claim in chapter 7 of the reference manual.
     287          #
     288          actions = []
     289          def bar(): return 42
     290          bar = c1.make_decorator(c1.arg)(c2.make_decorator(c2.arg)(c3.make_decorator(c3.arg)(bar)))
     291          self.assertEqual(bar(), 42)
     292          self.assertEqual(actions, expected_actions)
     293  
     294      def test_wrapped_descriptor_inside_classmethod(self):
     295          class ESC[4;38;5;81mBoundWrapper:
     296              def __init__(self, wrapped):
     297                  self.__wrapped__ = wrapped
     298  
     299              def __call__(self, *args, **kwargs):
     300                  return self.__wrapped__(*args, **kwargs)
     301  
     302          class ESC[4;38;5;81mWrapper:
     303              def __init__(self, wrapped):
     304                  self.__wrapped__ = wrapped
     305  
     306              def __get__(self, instance, owner):
     307                  bound_function = self.__wrapped__.__get__(instance, owner)
     308                  return BoundWrapper(bound_function)
     309  
     310          def decorator(wrapped):
     311              return Wrapper(wrapped)
     312  
     313          class ESC[4;38;5;81mClass:
     314              @decorator
     315              @classmethod
     316              def inner(cls):
     317                  # This should already work.
     318                  return 'spam'
     319  
     320              @classmethod
     321              @decorator
     322              def outer(cls):
     323                  # Raised TypeError with a message saying that the 'Wrapper'
     324                  # object is not callable.
     325                  return 'eggs'
     326  
     327          self.assertEqual(Class.inner(), 'spam')
     328          self.assertEqual(Class.outer(), 'eggs')
     329          self.assertEqual(Class().inner(), 'spam')
     330          self.assertEqual(Class().outer(), 'eggs')
     331  
     332      def test_bound_function_inside_classmethod(self):
     333          class ESC[4;38;5;81mA:
     334              def foo(self, cls):
     335                  return 'spam'
     336  
     337          class ESC[4;38;5;81mB:
     338              bar = classmethod(A().foo)
     339  
     340          self.assertEqual(B.bar(), 'spam')
     341  
     342      def test_wrapped_classmethod_inside_classmethod(self):
     343          class ESC[4;38;5;81mMyClassMethod1:
     344              def __init__(self, func):
     345                  self.func = func
     346  
     347              def __call__(self, cls):
     348                  if hasattr(self.func, '__get__'):
     349                      return self.func.__get__(cls, cls)()
     350                  return self.func(cls)
     351  
     352              def __get__(self, instance, owner=None):
     353                  if owner is None:
     354                      owner = type(instance)
     355                  return MethodType(self, owner)
     356  
     357          class ESC[4;38;5;81mMyClassMethod2:
     358              def __init__(self, func):
     359                  if isinstance(func, classmethod):
     360                      func = func.__func__
     361                  self.func = func
     362  
     363              def __call__(self, cls):
     364                  return self.func(cls)
     365  
     366              def __get__(self, instance, owner=None):
     367                  if owner is None:
     368                      owner = type(instance)
     369                  return MethodType(self, owner)
     370  
     371          for myclassmethod in [MyClassMethod1, MyClassMethod2]:
     372              class ESC[4;38;5;81mA:
     373                  @myclassmethod
     374                  def f1(cls):
     375                      return cls
     376  
     377                  @classmethod
     378                  @myclassmethod
     379                  def f2(cls):
     380                      return cls
     381  
     382                  @myclassmethod
     383                  @classmethod
     384                  def f3(cls):
     385                      return cls
     386  
     387                  @classmethod
     388                  @classmethod
     389                  def f4(cls):
     390                      return cls
     391  
     392                  @myclassmethod
     393                  @MyClassMethod1
     394                  def f5(cls):
     395                      return cls
     396  
     397                  @myclassmethod
     398                  @MyClassMethod2
     399                  def f6(cls):
     400                      return cls
     401  
     402              self.assertIs(A.f1(), A)
     403              self.assertIs(A.f2(), A)
     404              self.assertIs(A.f3(), A)
     405              self.assertIs(A.f4(), A)
     406              self.assertIs(A.f5(), A)
     407              self.assertIs(A.f6(), A)
     408              a = A()
     409              self.assertIs(a.f1(), A)
     410              self.assertIs(a.f2(), A)
     411              self.assertIs(a.f3(), A)
     412              self.assertIs(a.f4(), A)
     413              self.assertIs(a.f5(), A)
     414              self.assertIs(a.f6(), A)
     415  
     416              def f(cls):
     417                  return cls
     418  
     419              self.assertIs(myclassmethod(f).__get__(a)(), A)
     420              self.assertIs(myclassmethod(f).__get__(a, A)(), A)
     421              self.assertIs(myclassmethod(f).__get__(A, A)(), A)
     422              self.assertIs(myclassmethod(f).__get__(A)(), type(A))
     423              self.assertIs(classmethod(f).__get__(a)(), A)
     424              self.assertIs(classmethod(f).__get__(a, A)(), A)
     425              self.assertIs(classmethod(f).__get__(A, A)(), A)
     426              self.assertIs(classmethod(f).__get__(A)(), type(A))
     427  
     428  class ESC[4;38;5;81mTestClassDecorators(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     429  
     430      def test_simple(self):
     431          def plain(x):
     432              x.extra = 'Hello'
     433              return x
     434          @plain
     435          class ESC[4;38;5;81mC(ESC[4;38;5;149mobject): pass
     436          self.assertEqual(C.extra, 'Hello')
     437  
     438      def test_double(self):
     439          def ten(x):
     440              x.extra = 10
     441              return x
     442          def add_five(x):
     443              x.extra += 5
     444              return x
     445  
     446          @add_five
     447          @ten
     448          class ESC[4;38;5;81mC(ESC[4;38;5;149mobject): pass
     449          self.assertEqual(C.extra, 15)
     450  
     451      def test_order(self):
     452          def applied_first(x):
     453              x.extra = 'first'
     454              return x
     455          def applied_second(x):
     456              x.extra = 'second'
     457              return x
     458          @applied_second
     459          @applied_first
     460          class ESC[4;38;5;81mC(ESC[4;38;5;149mobject): pass
     461          self.assertEqual(C.extra, 'second')
     462  
     463  if __name__ == "__main__":
     464      unittest.main()