_backend.py
5.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import scipy._lib.uarray as ua
from . import _pocketfft
class _ScipyBackend:
    """Default ``uarray`` backend for fft calculations.

    Every dispatched call is forwarded to the function of the same name in
    the ``_pocketfft`` module, when such a function exists.

    Notes
    -----
    The domain is ``numpy.scipy`` rather than plain ``scipy`` because, in
    the future, ``uarray`` will treat the domain as a hierarchy. A user can
    then install a single backend for ``numpy`` and have it implement
    ``numpy.scipy.fft`` as well.
    """

    __ua_domain__ = "numpy.scipy.fft"

    @staticmethod
    def __ua_function__(method, args, kwargs):
        # Dispatch purely by name: a missing attribute means this backend
        # has no implementation for the requested function.
        impl = getattr(_pocketfft, method.__name__, None)
        if impl is not None:
            return impl(*args, **kwargs)
        return NotImplemented
# Backends that may be selected by passing their name as a string.
_named_backends = {'scipy': _ScipyBackend}
def _backend_from_arg(backend):
    """Map a backend argument to a validated backend object.

    Parameters
    ----------
    backend : {object, str}
        Either the name of a known backend (a key of ``_named_backends``)
        or an object exposing a ``__ua_domain__`` attribute.

    Returns
    -------
    object
        The backend registered under the given name, or ``backend`` itself
        when it was not a string.

    Raises
    ------
    ValueError
        If ``backend`` is a string that does not name a known backend, or
        if the resolved backend's ``__ua_domain__`` is not
        ``"numpy.scipy.fft"``.
    """
    if isinstance(backend, str):
        try:
            backend = _named_backends[backend]
        except KeyError as e:
            # Chain explicitly so the traceback reports the failed lookup as
            # the direct cause instead of an error "during handling".
            raise ValueError('Unknown backend {}'.format(backend)) from e

    if backend.__ua_domain__ != 'numpy.scipy.fft':
        raise ValueError('Backend does not implement "numpy.scipy.fft"')

    return backend
def set_global_backend(backend):
    """Install the given backend as the global fft backend.

    The global backend takes priority over registered backends, but is
    overridden by context-specific backends set with `set_backend`.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to use.
        Can either be a ``str`` containing the name of a known backend
        {'scipy'} or an object that implements the uarray protocol.

    Raises
    ------
    ValueError
        If the backend does not implement ``numpy.scipy.fft``.

    Notes
    -----
    Any previously set global backend is overwritten. By default the global
    backend is the SciPy implementation.

    Examples
    --------
    We can set the global fft backend:

    >>> from scipy.fft import fft, set_global_backend
    >>> set_global_backend("scipy")  # Sets global backend. "scipy" is the default backend.
    >>> fft([1])  # Calls the global backend
    array([1.+0.j])
    """
    # Validate (and resolve names) first, then hand the object to uarray.
    ua.set_global_backend(_backend_from_arg(backend))
def register_backend(backend):
    """Register a backend for permanent use.

    Registered backends have the lowest priority; they are tried only after
    the global backend.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to use.
        Can either be a ``str`` containing the name of a known backend
        {'scipy'} or an object that implements the uarray protocol.

    Raises
    ------
    ValueError
        If the backend does not implement ``numpy.scipy.fft``.

    Examples
    --------
    We can register a new fft backend:

    >>> from scipy.fft import fft, register_backend, set_global_backend
    >>> class NoopBackend:  # Define an invalid Backend
    ...     __ua_domain__ = "numpy.scipy.fft"
    ...     def __ua_function__(self, func, args, kwargs):
    ...         return NotImplemented
    >>> set_global_backend(NoopBackend())  # Set the invalid backend as global
    >>> register_backend("scipy")  # Register a new backend
    >>> fft([1])  # The registered backend is called because the global backend returns `NotImplemented`
    array([1.+0.j])
    >>> set_global_backend("scipy")  # Restore global backend to default
    """
    # Validate (and resolve names) first, then hand the object to uarray.
    ua.register_backend(_backend_from_arg(backend))
def set_backend(backend, coerce=False, only=False):
    """Context manager that sets the backend within a fixed scope.

    On entering the ``with`` statement, the given backend is added to the
    list of available backends with the highest priority. On exit, the
    backend is reset to the state it had before the scope was entered.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to use.
        Can either be a ``str`` containing the name of a known backend
        {'scipy'} or an object that implements the uarray protocol.
    coerce : bool, optional
        Whether to allow expensive conversions for the ``x`` parameter. e.g.,
        copying a NumPy array to the GPU for a CuPy backend. Implies ``only``.
    only : bool, optional
        If only is ``True`` and this backend returns ``NotImplemented``, then
        a BackendNotImplemented error will be raised immediately, ignoring
        any lower priority backends.

    Examples
    --------
    >>> import scipy.fft as fft
    >>> with fft.set_backend('scipy', only=True):
    ...     fft.fft([1])  # Always calls the scipy implementation
    array([1.+0.j])
    """
    resolved = _backend_from_arg(backend)
    return ua.set_backend(resolved, coerce=coerce, only=only)
def skip_backend(backend):
    """Context manager that skips a backend within a fixed scope.

    Inside the ``with`` statement the given backend is never called,
    whether it was registered locally or globally. On exit, the backend is
    considered again.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to skip.
        Can either be a ``str`` containing the name of a known backend
        {'scipy'} or an object that implements the uarray protocol.

    Examples
    --------
    >>> import scipy.fft as fft
    >>> fft.fft([1])  # Calls default SciPy backend
    array([1.+0.j])
    >>> with fft.skip_backend('scipy'):  # We explicitly skip the SciPy backend
    ...     fft.fft([1])                 # leaving no implementation available
    Traceback (most recent call last):
        ...
    BackendNotImplementedError: No selected backends had an implementation ...
    """
    resolved = _backend_from_arg(backend)
    return ua.skip_backend(resolved)
# Runs at import time: make the named 'scipy' backend the global backend so
# that a default implementation is always installed.
set_global_backend('scipy')