1 """Helper to provide extensibility for pickle.
2
3 This is only useful to add pickle support for extension types defined in
4 C, not for instances of user-defined classes.
5 """
6
# Names exported by ``from copyreg import *``.
__all__ = ["pickle", "constructor",
           "add_extension", "remove_extension", "clear_extension_cache"]

# Maps a type to its reduction function; entries are added by pickle().
dispatch_table = {}
11
def pickle(ob_type, pickle_function, constructor_ob=None):
    """Register *pickle_function* as the reduction function for *ob_type*.

    The function is stored in ``dispatch_table`` under *ob_type*.  The
    optional *constructor_ob* is only checked for callability (via
    ``constructor()``) after the registration has been made.

    Raises TypeError if *pickle_function* is not callable, or if
    *constructor_ob* is given and is not callable.
    """
    if not callable(pickle_function):
        raise TypeError("reduction functions must be callable")
    dispatch_table[ob_type] = pickle_function

    # constructor_ob is a leftover of the old "safe for unpickling"
    # machinery; callers have no reason to pass it any more.
    if constructor_ob is None:
        return
    constructor(constructor_ob)
21
def constructor(object):
    """Validate that *object* is callable.

    Nothing is recorded; the function returns None on success and raises
    TypeError otherwise.
    """
    if callable(object):
        return
    raise TypeError("constructors must be callable")
25
# Example: register reduction support for complex numbers.  The NameError
# guard only matters on a build that has no ``complex`` builtin.
try:
    complex
except NameError:
    pass
else:
    def pickle_complex(c):
        """Reduce complex number *c* to ``complex(c.real, c.imag)``."""
        components = (c.real, c.imag)
        return complex, components

    pickle(complex, pickle_complex, complex)
38
def pickle_union(obj):
    """Reduce a ``X | Y`` union to ``functools.reduce(operator.or_, args)``.

    Unpickling re-applies ``|`` across the union's ``__args__``.
    """
    from functools import reduce
    from operator import or_
    return reduce, (or_, obj.__args__)


pickle(type(int | str), pickle_union)
44
45 # Support for pickling new-style objects
46
47 def _reconstructor(cls, base, state):
48 if base is object:
49 obj = object.__new__(cls)
50 else:
51 obj = base.__new__(cls, state)
52 if base.__init__ != object.__init__:
53 base.__init__(obj, state)
54 return obj
55
# Bit in ``type.__flags__`` marking a heap-allocated type (CPython's
# Py_TPFLAGS_HEAPTYPE); classes without it are treated by _reduce_ex as
# the "base" to reconstruct from.
_HEAPTYPE = 1<<9
# Type of a builtin method such as ``int.__new__``; _reduce_ex uses it to
# detect a ``__new__`` implemented in C and bound to its own class.
_new_type = type(int.__new__)
58
59 # Python code for object.__reduce_ex__ for protocols 0 and 1
60
def _reduce_ex(self, proto):
    """Reduce *self* the way ``object.__reduce_ex__`` does for protocols 0/1.

    Returns ``(_reconstructor, args)`` or, when the instance state is
    non-empty, ``(_reconstructor, args, state)``.  ``args`` is
    ``(cls, base, base_state)``:

    * ``base`` is chosen by ``_reduce_ex_base(cls)``;
    * ``base_state`` is ``None`` when ``base`` is ``object``, otherwise
      ``base(self)``.

    The state is ``self.__getstate__()`` when that attribute exists,
    otherwise ``self.__dict__`` (``None`` if the instance has none).

    Raises TypeError when ``base`` is ``cls`` itself, or when the class
    defines ``__slots__`` but relies on the default ``__getstate__``.
    """
    assert proto < 2
    cls = self.__class__
    base = _reduce_ex_base(cls)
    if base is object:
        base_state = None
    elif base is cls:
        raise TypeError(f"cannot pickle {cls.__name__!r} object")
    else:
        # e.g. int(self) for an int subclass: capture the builtin part.
        base_state = base(self)
    args = (cls, base, base_state)

    try:
        getstate = self.__getstate__
    except AttributeError:
        if getattr(self, "__slots__", None):
            raise TypeError(f"cannot pickle {cls.__name__!r} object: "
                            f"a class that defines __slots__ without "
                            f"defining __getstate__ cannot be pickled "
                            f"with protocol {proto}") from None
        state = getattr(self, "__dict__", None)
    else:
        # An inherited object.__getstate__ is rejected for slotted classes.
        if (type(self).__getstate__ is object.__getstate__ and
                getattr(self, "__slots__", None)):
            raise TypeError("a class that defines __slots__ without "
                            "defining __getstate__ cannot be pickled")
        state = getstate()

    if not state:
        return _reconstructor, args
    return _reconstructor, args, state


def _reduce_ex_base(cls):
    """Return the class in ``cls.__mro__`` that _reduce_ex rebuilds from.

    That is the first class that is not a heap type, or whose ``__new__``
    is a builtin method bound to that very class; ``object`` as a fallback.
    """
    for base in cls.__mro__:
        if hasattr(base, '__flags__') and not base.__flags__ & _HEAPTYPE:
            return base
        new = base.__new__
        if isinstance(new, _new_type) and new.__self__ is base:
            return base
    return object  # not really reachable
101
102 # Helper for __reduce_ex__ protocol 2
103
104 def __newobj__(cls, *args):
105 return cls.__new__(cls, *args)
106
107 def __newobj_ex__(cls, args, kwargs):
108 """Used by pickle protocol 4, instead of __newobj__ to allow classes with
109 keyword-only arguments to be pickled correctly.
110 """
111 return cls.__new__(cls, *args, **kwargs)
112
113 def _slotnames(cls):
114 """Return a list of slot names for a given class.
115
116 This needs to find slots defined by the class and its bases, so we
117 can't simply return the __slots__ attribute. We must walk down
118 the Method Resolution Order and concatenate the __slots__ of each
119 class found there. (This assumes classes don't modify their
120 __slots__ attribute to misrepresent their slots after the class is
121 defined.)
122 """
123
124 # Get the value from a cache in the class if possible
125 names = cls.__dict__.get("__slotnames__")
126 if names is not None:
127 return names
128
129 # Not cached -- calculate the value
130 names = []
131 if not hasattr(cls, "__slots__"):
132 # This class has no slots
133 pass
134 else:
135 # Slots found -- gather slot names from all base classes
136 for c in cls.__mro__:
137 if "__slots__" in c.__dict__:
138 slots = c.__dict__['__slots__']
139 # if class has a single slot, it can be given as a string
140 if isinstance(slots, str):
141 slots = (slots,)
142 for name in slots:
143 # special descriptors
144 if name in ("__dict__", "__weakref__"):
145 continue
146 # mangled names
147 elif name.startswith('__') and not name.endswith('__'):
148 stripped = c.__name__.lstrip('_')
149 if stripped:
150 names.append('_%s%s' % (stripped, name))
151 else:
152 names.append(name)
153 else:
154 names.append(name)
155
156 # Cache the outcome in the class if at all possible
157 try:
158 cls.__slotnames__ = names
159 except:
160 pass # But don't die if we can't
161
162 return names
163
# A registry of extension codes.  This is an ad-hoc compression
# mechanism.  Whenever a global reference to <module>, <name> is about
# to be pickled, the (<module>, <name>) tuple is looked up here to see
# if it is a registered extension code for it.  Extension codes are
# universal, so that the meaning of a pickle does not depend on
# context.  (There are also some codes reserved for local use that
# don't have this restriction.)  Codes are positive ints; 0 is
# reserved (add_extension() accepts 1 through 0x7fffffff).

_extension_registry = {}                # key -> code, key = (module, name)
_inverted_registry = {}                 # code -> key
_extension_cache = {}                   # code -> object
# Don't ever rebind these names: pickling grabs a reference to them when
# it's initialized, and won't see a rebinding.  Mutate them in place, as
# add_extension(), remove_extension() and clear_extension_cache() do.
178
def add_extension(module, name, code):
    """Register an extension code.

    Associates *code* (converted with ``int()``) with the key
    ``(module, name)`` in both directions.  Registering the exact same
    pair again is a no-op.

    Raises ValueError if *code* is outside 1..0x7fffffff, if the key is
    already bound to another code, or if the code is already used by
    another key.
    """
    code = int(code)
    if code < 1 or code > 0x7fffffff:
        raise ValueError("code out of range")
    key = (module, name)
    already_same = (_extension_registry.get(key) == code and
                    _inverted_registry.get(code) == key)
    if already_same:
        return  # Redundant registrations are benign
    if key in _extension_registry:
        existing = _extension_registry[key]
        raise ValueError("key %s is already registered with code %s" %
                         (key, existing))
    if code in _inverted_registry:
        owner = _inverted_registry[code]
        raise ValueError("code %s is already in use for key %s" %
                         (code, owner))
    _extension_registry[key] = code
    _inverted_registry[code] = key
196
def remove_extension(module, name, code):
    """Unregister an extension code. For testing only.

    Removes the ``(module, name)`` <-> *code* mapping in both directions
    and drops any cached object for *code*.

    Raises ValueError unless the key is currently registered with
    exactly *code*.
    """
    key = (module, name)
    if (_extension_registry.get(key) != code or
            _inverted_registry.get(code) != key):
        raise ValueError("key %s is not registered with code %s" %
                         (key, code))
    del _extension_registry[key]
    del _inverted_registry[code]
    _extension_cache.pop(code, None)
208
def clear_extension_cache():
    """Empty the code -> object cache in place; registrations are kept."""
    _extension_cache.clear()
211
212 # Standard extension code assignments
213
214 # Reserved ranges
215
216 # First Last Count Purpose
217 # 1 127 127 Reserved for Python standard library
218 # 128 191 64 Reserved for Zope
219 # 192 239 48 Reserved for 3rd parties
220 # 240 255 16 Reserved for private use (will never be assigned)
221 # 256 Inf Inf Reserved for future assignment
222
223 # Extension codes are assigned by the Python Software Foundation.