1 import os
2 import pathlib
3 import tempfile
4 import functools
5 import contextlib
6 import types
7 import importlib
8 import inspect
9 import warnings
10 import itertools
11
12 from typing import Union, Optional, cast
13 from .abc import ResourceReader, Traversable
14
15 from ._adapters import wrap_spec
16
# Accepted forms of a package argument: an imported module object or the
# module's dotted import name (``resolve`` below handles both).
Package = Union[types.ModuleType, str]
# ``files()`` takes an "anchor" (its first parameter was renamed from
# "package"); it accepts the same forms as ``Package``.
Anchor = Package
19
20
def package_to_anchor(func):
    """
    Adapt *func* so callers may still pass the legacy ``package`` keyword.

    A ``package`` value given without an ``anchor`` is forwarded to *func*
    as the anchor, and a DeprecationWarning is emitted. When both are
    supplied they are handed to *func* positionally so that *func*'s own
    signature rejects them; other errors likewise propagate unchanged.

    >>> files('a', 'b')
    Traceback (most recent call last):
    TypeError: files() takes from 0 to 1 positional arguments but 2 were given
    """
    # Private sentinel: distinguishes "argument not passed" from None.
    missing = object()

    # NOTE: the inner function must stay named ``wrapper``; _infer_caller
    # skips stack frames whose function name is 'wrapper'.
    @functools.wraps(func)
    def wrapper(anchor=missing, package=missing):
        if package is missing:
            return func() if anchor is missing else func(anchor)
        if anchor is missing:
            warnings.warn(
                "First parameter to files is renamed to 'anchor'",
                DeprecationWarning,
                stacklevel=2,
            )
            return func(package)
        # Both given: let func's signature raise the appropriate TypeError.
        return func(anchor, package)

    return wrapper
49
50
@package_to_anchor
def files(anchor: Optional[Anchor] = None) -> Traversable:
    """
    Get a Traversable resource for an anchor.

    *anchor* may be a module object or an importable module name; when it
    is None, the calling module (located by walking the stack) is used.
    """
    module = resolve(anchor)
    return from_package(module)
57
58
59 def get_resource_reader(package: types.ModuleType) -> Optional[ResourceReader]:
60 """
61 Return the package's loader if it's a ResourceReader.
62 """
63 # We can't use
64 # a issubclass() check here because apparently abc.'s __subclasscheck__()
65 # hook wants to create a weak reference to the object, but
66 # zipimport.zipimporter does not support weak references, resulting in a
67 # TypeError. That seems terrible.
68 spec = package.__spec__
69 reader = getattr(spec.loader, 'get_resource_reader', None) # type: ignore
70 if reader is None:
71 return None
72 return reader(spec.name) # type: ignore
73
74
@functools.singledispatch
def resolve(cand: Optional[Anchor]) -> types.ModuleType:
    """
    Turn an anchor into a module object.

    This is the fallback for types without a more specific registration:
    *cand* is taken to already be a module and is returned unchanged.
    """
    return cast(types.ModuleType, cand)


@resolve.register(str)
def _(cand: str) -> types.ModuleType:
    """Import and return the module named by the dotted string *cand*."""
    return importlib.import_module(cand)


@resolve.register(type(None))
def _(cand: None) -> types.ModuleType:
    """Resolve to the module of the caller located by ``_infer_caller``."""
    caller_name = _infer_caller().f_globals['__name__']
    return resolve(caller_name)
88
89
90 def _infer_caller():
91 """
92 Walk the stack and find the frame of the first caller not in this module.
93 """
94
95 def is_this_file(frame_info):
96 return frame_info.filename == __file__
97
98 def is_wrapper(frame_info):
99 return frame_info.function == 'wrapper'
100
101 not_this_file = itertools.filterfalse(is_this_file, inspect.stack())
102 # also exclude 'wrapper' due to singledispatch in the call stack
103 callers = itertools.filterfalse(is_wrapper, not_this_file)
104 return next(callers).frame
105
106
def from_package(package: types.ModuleType):
    """
    Return a Traversable object for the given package.

    The package's spec is adapted with ``wrap_spec``; the adapted loader's
    resource reader for the spec's name supplies the Traversable through
    its ``files()`` method.
    """
    wrapped = wrap_spec(package)
    return wrapped.loader.get_resource_reader(wrapped.name).files()
115
116
117 @contextlib.contextmanager
118 def _tempfile(
119 reader,
120 suffix='',
121 # gh-93353: Keep a reference to call os.remove() in late Python
122 # finalization.
123 *,
124 _os_remove=os.remove,
125 ):
126 # Not using tempfile.NamedTemporaryFile as it leads to deeper 'try'
127 # blocks due to the need to close the temporary file to work on Windows
128 # properly.
129 fd, raw_path = tempfile.mkstemp(suffix=suffix)
130 try:
131 try:
132 os.write(fd, reader())
133 finally:
134 os.close(fd)
135 del reader
136 yield pathlib.Path(raw_path)
137 finally:
138 try:
139 _os_remove(raw_path)
140 except FileNotFoundError:
141 pass
142
143
def _temp_file(path):
    """
    Copy the Traversable file *path* into a temporary file.

    Returns a context manager yielding the temporary file's path; the
    temporary file's name ends with ``path.name``.
    """
    read = path.read_bytes
    return _tempfile(read, suffix=path.name)
146
147
148 def _is_present_dir(path: Traversable) -> bool:
149 """
150 Some Traversables implement ``is_dir()`` to raise an
151 exception (i.e. ``FileNotFoundError``) when the
152 directory doesn't exist. This function wraps that call
153 to always return a boolean and only return True
154 if there's a dir and it exists.
155 """
156 with contextlib.suppress(FileNotFoundError):
157 return path.is_dir()
158 return False
159
160
@functools.singledispatch
def as_file(path):
    """
    Given a Traversable object, return that object as a
    path on the local file system in a context manager.

    Directories are replicated into a temporary directory and files are
    copied into a temporary file; either copy is removed when the context
    exits.
    """
    if _is_present_dir(path):
        return _temp_dir(path)
    return _temp_file(path)


@as_file.register(pathlib.Path)
@contextlib.contextmanager
def _(path):
    """
    Degenerate behavior for pathlib.Path objects.

    The path is yielded unchanged: no copy is made and nothing is cleaned up.
    """
    yield path
177
178
179 @contextlib.contextmanager
180 def _temp_path(dir: tempfile.TemporaryDirectory):
181 """
182 Wrap tempfile.TemporyDirectory to return a pathlib object.
183 """
184 with dir as result:
185 yield pathlib.Path(result)
186
187
@contextlib.contextmanager
def _temp_dir(path):
    """
    Given a traversable dir, recursively replicate the whole tree
    to the file system in a context manager.

    Yields the location of the replica (``<tempdir>/<path.name>``); the
    temporary directory holding it is removed when the context exits.
    """
    # NOTE(review): assert is stripped under ``python -O``; as_file only
    # routes here after _is_present_dir(path) is true.
    assert path.is_dir()
    scratch = _temp_path(tempfile.TemporaryDirectory())
    with scratch as root:
        replica = _write_contents(root, path)
        yield replica
197
198
199 def _write_contents(target, source):
200 child = target.joinpath(source.name)
201 if source.is_dir():
202 child.mkdir()
203 for item in source.iterdir():
204 _write_contents(child, item)
205 else:
206 child.write_bytes(source.read_bytes())
207 return child