# Implement marshal.loads() in pure Python
2
3 import ast
4
5 from typing import Any, Tuple
6
7
class Type:
    """Marshal type codes, one byte each, adapted from CPython's marshal.c.

    The high bit of a type byte (FLAG_REF) is masked off before the byte is
    compared against these values.
    """

    # Singletons and sentinels
    NULL: int = ord("0")
    NONE: int = ord("N")
    FALSE: int = ord("F")
    TRUE: int = ord("T")
    STOPITER: int = ord("S")
    ELLIPSIS: int = ord(".")
    UNKNOWN: int = ord("?")

    # Numbers
    INT: int = ord("i")
    INT64: int = ord("I")
    FLOAT: int = ord("f")
    BINARY_FLOAT: int = ord("g")
    COMPLEX: int = ord("x")
    BINARY_COMPLEX: int = ord("y")
    LONG: int = ord("l")

    # Byte strings and text
    STRING: int = ord("s")
    INTERNED: int = ord("t")
    UNICODE: int = ord("u")
    ASCII: int = ord("a")
    ASCII_INTERNED: int = ord("A")
    SHORT_ASCII: int = ord("z")
    SHORT_ASCII_INTERNED: int = ord("Z")

    # Containers
    TUPLE: int = ord("(")
    SMALL_TUPLE: int = ord(")")
    LIST: int = ord("[")
    DICT: int = ord("{")
    SET: int = ord("<")
    FROZENSET: int = ord(">")

    # Code objects and back-references
    CODE: int = ord("c")
    REF: int = ord("r")
39
40
41 FLAG_REF = 0x80 # with a type, add obj to index
42
43 NULL = object() # marker
44
45 # Cell kinds
46 CO_FAST_LOCAL = 0x20
47 CO_FAST_CELL = 0x40
48 CO_FAST_FREE = 0x80
49
50
class Code:
    """Attribute bag standing in for a decoded code object.

    The marshal reader creates an empty instance and assigns each co_* field
    it decodes.  The properties below derive the familiar co_varnames,
    co_cellvars, co_freevars and co_nlocals views from the combined
    "locals plus" tables.
    """

    def __init__(self, **kwds: Any):
        self.__dict__.update(kwds)

    def __repr__(self) -> str:
        return f"Code(**{self.__dict__})"

    # Names of every local, cell and free variable, one entry per slot.
    co_localsplusnames: Tuple[str, ...]
    # One kind byte per entry of co_localsplusnames (a mask of CO_FAST_*
    # bits).  CPython stores this table as a bytes object, so iterating it
    # yields ints.
    co_localspluskinds: bytes

    def get_localsplus_names(self, select_kind: int) -> Tuple[str, ...]:
        """Return, in slot order, the names whose kind shares a bit with
        *select_kind*."""
        pairs = zip(self.co_localsplusnames, self.co_localspluskinds)
        return tuple(name for name, kind in pairs if kind & select_kind)

    @property
    def co_varnames(self) -> Tuple[str, ...]:
        """Names of the slots flagged CO_FAST_LOCAL."""
        return self.get_localsplus_names(CO_FAST_LOCAL)

    @property
    def co_cellvars(self) -> Tuple[str, ...]:
        """Names of the slots flagged CO_FAST_CELL."""
        return self.get_localsplus_names(CO_FAST_CELL)

    @property
    def co_freevars(self) -> Tuple[str, ...]:
        """Names of the slots flagged CO_FAST_FREE."""
        return self.get_localsplus_names(CO_FAST_FREE)

    @property
    def co_nlocals(self) -> int:
        """Number of CO_FAST_LOCAL slots."""
        return len(self.co_varnames)
84
85
class Reader:
    """Decode CPython's marshal format into Python objects.

    A fairly literal translation of the reader in CPython's marshal.c.
    Malformed input trips an ``assert`` (AssertionError), matching how the
    rest of this module validates data.
    """

    def __init__(self, data: bytes):
        self.data: bytes = data
        self.end: int = len(self.data)
        self.pos: int = 0
        # Back-reference table: objects whose type byte carried FLAG_REF, in
        # the order their type bytes were read.  Type.REF records index it.
        self.refs: list[Any] = []
        # Nesting depth of r_object() calls; tracked for debugging only.
        self.level: int = 0

    def r_string(self, n: int) -> bytes:
        """Consume and return the next *n* bytes of input."""
        assert 0 <= n <= self.end - self.pos
        buf = self.data[self.pos : self.pos + n]
        self.pos += n
        return buf

    def r_byte(self) -> int:
        """Read one unsigned byte."""
        return self.r_string(1)[0]

    def r_short(self) -> int:
        """Read a signed 16-bit little-endian integer."""
        return int.from_bytes(self.r_string(2), "little", signed=True)

    def r_long(self) -> int:
        """Read a signed 32-bit little-endian integer."""
        return int.from_bytes(self.r_string(4), "little", signed=True)

    def r_long64(self) -> int:
        """Read a signed 64-bit little-endian integer."""
        return int.from_bytes(self.r_string(8), "little", signed=True)

    def r_PyLong(self) -> int:
        """Read an arbitrary-precision int stored as base-2**15 digits.

        A signed 32-bit count comes first: its magnitude is the number of
        digits and its sign is the sign of the result.  Each digit is a
        16-bit field, least significant digit first.
        """
        n = self.r_long()
        x = 0
        for i in range(abs(n)):
            digit = self.r_short()
            # Each digit must fit in 15 bits (marshal.c rejects others).
            assert 0 <= digit < 1 << 15, digit
            x |= digit << (i * 15)
        return -x if n < 0 else x

    def r_float_bin(self) -> float:
        """Read an 8-byte IEEE 754 double stored little-endian."""
        buf = self.r_string(8)
        import struct  # Lazy import to avoid breaking UNIX build
        # Marshal always writes doubles little-endian; plain "d" would use
        # the host's native byte order and misread data on big-endian hosts.
        return struct.unpack("<d", buf)[0]

    def r_float_str(self) -> float:
        """Read a float stored as a 1-byte length followed by ASCII text."""
        n = self.r_byte()
        buf = self.r_string(n)
        # float() parses "inf", "-inf" and "nan" (emitted for non-finite
        # values), which ast.literal_eval rejects.
        return float(buf.decode("ascii"))

    def r_ref_reserve(self, flag: int) -> int:
        """Reserve a refs slot before reading a container's elements.

        Reserving first keeps the container's index ahead of those of its
        elements.  Returns the slot index, or 0 when *flag* is clear and
        nothing was reserved.
        """
        if flag:
            idx = len(self.refs)
            self.refs.append(None)
            return idx
        else:
            return 0

    def r_ref_insert(self, obj: Any, idx: int, flag: int) -> Any:
        """Fill the slot reserved by r_ref_reserve() and return *obj*."""
        if flag:
            self.refs[idx] = obj
        return obj

    def r_ref(self, obj: Any, flag: int) -> Any:
        """Append *obj* to the back-reference table and return it."""
        assert flag & FLAG_REF
        self.refs.append(obj)
        return obj

    def r_object(self) -> Any:
        """Read one complete object, restoring the nesting level afterwards."""
        old_level = self.level
        try:
            return self._r_object()
        finally:
            self.level = old_level

    def _r_object(self) -> Any:
        """Decode the object whose type byte is next in the input."""
        code = self.r_byte()
        flag = code & FLAG_REF
        typ = code & ~FLAG_REF
        self.level += 1
        retval: Any

        def R_REF(obj: Any) -> Any:
            # Register obj for later Type.REF lookups if the type byte asked.
            if flag:
                obj = self.r_ref(obj, flag)
            return obj

        if typ == Type.NULL:
            return NULL
        elif typ == Type.NONE:
            return None
        elif typ == Type.ELLIPSIS:
            return Ellipsis
        elif typ == Type.FALSE:
            return False
        elif typ == Type.TRUE:
            return True
        elif typ == Type.STOPITER:
            return StopIteration
        elif typ == Type.INT:
            return R_REF(self.r_long())
        elif typ == Type.INT64:
            return R_REF(self.r_long64())
        elif typ == Type.LONG:
            return R_REF(self.r_PyLong())
        elif typ == Type.FLOAT:
            return R_REF(self.r_float_str())
        elif typ == Type.BINARY_FLOAT:
            return R_REF(self.r_float_bin())
        elif typ == Type.COMPLEX:
            return R_REF(complex(self.r_float_str(),
                                 self.r_float_str()))
        elif typ == Type.BINARY_COMPLEX:
            return R_REF(complex(self.r_float_bin(),
                                 self.r_float_bin()))
        elif typ == Type.STRING:
            n = self.r_long()
            return R_REF(self.r_string(n))
        elif typ == Type.ASCII_INTERNED or typ == Type.ASCII:
            n = self.r_long()
            return R_REF(self.r_string(n).decode("ascii"))
        elif typ == Type.SHORT_ASCII_INTERNED or typ == Type.SHORT_ASCII:
            n = self.r_byte()
            return R_REF(self.r_string(n).decode("ascii"))
        elif typ == Type.INTERNED or typ == Type.UNICODE:
            n = self.r_long()
            return R_REF(self.r_string(n).decode("utf8", "surrogatepass"))
        elif typ == Type.SMALL_TUPLE:
            n = self.r_byte()
            idx = self.r_ref_reserve(flag)
            retval = tuple(self.r_object() for _ in range(n))
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif typ == Type.TUPLE:
            n = self.r_long()
            idx = self.r_ref_reserve(flag)
            retval = tuple(self.r_object() for _ in range(n))
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif typ == Type.LIST:
            n = self.r_long()
            retval = R_REF([])
            for _ in range(n):
                retval.append(self.r_object())
            return retval
        elif typ == Type.DICT:
            retval = R_REF({})
            while True:
                key = self.r_object()
                # Identity test: NULL is a sentinel, never an ordinary key.
                if key is NULL:
                    break
                val = self.r_object()
                retval[key] = val
            return retval
        elif typ == Type.SET:
            n = self.r_long()
            retval = R_REF(set())
            for _ in range(n):
                v = self.r_object()
                retval.add(v)
            return retval
        elif typ == Type.FROZENSET:
            n = self.r_long()
            s: set[Any] = set()
            idx = self.r_ref_reserve(flag)
            for _ in range(n):
                v = self.r_object()
                s.add(v)
            retval = frozenset(s)
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif typ == Type.CODE:
            retval = R_REF(Code())
            retval.co_argcount = self.r_long()
            retval.co_posonlyargcount = self.r_long()
            retval.co_kwonlyargcount = self.r_long()
            retval.co_stacksize = self.r_long()
            retval.co_flags = self.r_long()
            retval.co_code = self.r_object()
            retval.co_consts = self.r_object()
            retval.co_names = self.r_object()
            retval.co_localsplusnames = self.r_object()
            retval.co_localspluskinds = self.r_object()
            retval.co_filename = self.r_object()
            retval.co_name = self.r_object()
            retval.co_qualname = self.r_object()
            retval.co_firstlineno = self.r_long()
            retval.co_linetable = self.r_object()
            retval.co_exceptiontable = self.r_object()
            return retval
        elif typ == Type.REF:
            n = self.r_long()
            # Reject out-of-range indices; a negative n would otherwise
            # silently index from the end of the table.
            assert 0 <= n < len(self.refs), n
            retval = self.refs[n]
            # None marks a slot reserved by r_ref_reserve() whose object
            # has not been finished yet.
            assert retval is not None
            return retval
        else:
            raise AssertionError(f"Unknown type {typ} {chr(typ)!r}")
302
303
def loads(data: bytes) -> Any:
    """Decode one marshalled object from *data* and return it.

    *data* must be a ``bytes`` object; anything else fails the assertion.
    """
    assert isinstance(data, bytes)
    return Reader(data).r_object()
308
309
def main():
    """Self-test: round-trip sample objects through marshal.dumps() and loads()."""
    import marshal
    import pprint

    # A nested container mixing dict, set, tuple, int, str and float.
    nested = {'foo': {(42, "bar", 3.14)}}
    decoded = loads(marshal.dumps(nested))
    assert decoded == nested, decoded

    # A real code object: this very function's.
    code_obj = loads(marshal.dumps(main.__code__))
    assert isinstance(code_obj, Code), code_obj
    pprint.pprint(code_obj.__dict__)
322
323
if __name__ == "__main__":
    # Executed as a script: run the round-trip self-test.
    main()