Saturday, December 2, 2023

Array API Support in scikit-learn

The Consortium for Python Data API Standards developed the Array API standard, which aims to define consistent behavior across the ecosystem of array libraries, such as PyTorch, NumPy, and CuPy. The Array API standard enables libraries such as scikit-learn to write code once against the standard and have it work with multiple array libraries. With PyTorch tensors or CuPy arrays, it is now possible to run computations on accelerators such as GPUs.
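To make “write code once” concrete, here is a minimal sketch. It is not scikit-learn code, and the name `standardize` is our own. The function calls only functions defined by the Array API standard and receives the array namespace `xp` as a parameter. Below it is exercised with NumPy, whose main namespace provides `mean` and `std` with the same meaning as the standard.

```python
import numpy as np


def standardize(X, xp):
    """Scale each column of X to zero mean and unit variance.

    Only functions from the Array API standard are used (mean, std and
    arithmetic), so the same source can run on any conforming namespace.
    """
    mean = xp.mean(X, axis=0)
    std = xp.std(X, axis=0)  # population std (the standard's default correction=0)
    return (X - mean) / std
```

For other libraries, `xp` would typically come from `array_api_compat.array_namespace(X)`. Passing the `torch` module directly is not equivalent, because `torch.std` applies Bessel’s correction by default. Differences like this are what a compatibility layer smooths over.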

With the release of scikit-learn 1.3, we enabled experimental Array API support for a limited set of machine learning models. Array API support is gradually expanding to more models and functionality on the development branch. Scikit-learn relies on the array_api_compat library for Array API support. array_api_compat extends the Array API standard to the main namespaces of NumPy arrays, CuPy arrays, and PyTorch tensors. In this blog post, we cover scikit-learn’s public interface for enabling Array API, the performance gain from running on an accelerator, and the challenges we faced when integrating Array API.
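The snippet below illustrates the kind of gap array_api_compat bridges. It is a simplified sketch, not the library’s actual code. The standard names array concatenation `concat`, while NumPy 1.x only offered `concatenate` (NumPy 2.0 added `concat` to its main namespace). A hypothetical shim that prefers the standard name and falls back to the older one could look like this:

```python
import numpy as np


def concat(xp, arrays, axis=0):
    """Concatenate arrays via the standard's `concat`, else `concatenate`.

    Hypothetical helper: array_api_compat handles many differences like this
    one for NumPy, CuPy and PyTorch, not only this single function.
    """
    func = getattr(xp, "concat", None)
    if func is None:
        func = xp.concatenate  # NumPy 1.x spelling
    return func(arrays, axis=axis)
```

Code written against the standard can then call `concat` without knowing which NumPy version is installed.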


Scikit-learn was originally developed to run on CPUs with NumPy arrays. With Array API support, a limited set of scikit-learn models and tools can now run with other array libraries and on devices such as GPUs. The following benchmark results come from running scikit-learn’s LinearDiscriminantAnalysis using NumPy and PyTorch on an AMD 5950x CPU, and PyTorch on an Nvidia RTX 3090 GPU.

Bar plot with benchmark results comparing NumPy and PyTorch on an AMD 5950x CPU and PyTorch on an Nvidia RTX 3090 GPU running Linear Discriminant Analysis. The PyTorch GPU results are marked as 27 times faster than NumPy for fitting the model and 28 times faster than NumPy for prediction.

Both training and prediction times improve when using PyTorch instead of NumPy. Running the computation on PyTorch CPU tensors is faster than NumPy because PyTorch CPU operations are multi-threaded by default.
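The post does not include its benchmark script. A minimal timing harness in the same spirit might look like the sketch below; the function name and its parameters are our own assumptions. One detail matters for GPU measurements. CUDA work is queued asynchronously, so the clock should only be read after the device has finished. For that reason the helper accepts an optional `sync` callable such as `torch.cuda.synchronize`.

```python
import time


def time_fit_predict(make_estimator, X, y, n_repeats=3, sync=None):
    """Return the best wall-clock (fit, predict) times over n_repeats runs.

    make_estimator: zero-argument callable returning a fresh, unfitted estimator.
    sync: optional callable that blocks until queued accelerator work is done
          (for example torch.cuda.synchronize); None for CPU-only runs.
    """

    def _sync():
        if sync is not None:
            sync()

    fit_times, predict_times = [], []
    for _ in range(n_repeats):
        estimator = make_estimator()
        _sync()  # make sure earlier work does not leak into this measurement
        start = time.perf_counter()
        estimator.fit(X, y)
        _sync()
        fit_times.append(time.perf_counter() - start)

        start = time.perf_counter()
        estimator.predict(X)
        _sync()
        predict_times.append(time.perf_counter() - start)
    # The minimum is less sensitive to warm-up and background noise than the mean.
    return min(fit_times), min(predict_times)
```

For example, inside `sklearn.config_context(array_api_dispatch=True)`, one could call `time_fit_predict(LinearDiscriminantAnalysis, X_torch_cuda, y_torch_cuda, sync=torch.cuda.synchronize)` and compare the result with the same call on `X_np, y_np`.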

scikit-learn’s Array API interface

Scikit-learn extended its experimental Array API support in version 1.3 to NumPy ndarrays, CuPy ndarrays, and PyTorch tensors. By themselves, these array objects do not yet fully implement the Array API specification. To overcome this limitation, Quansight engineer Aaron Meurer led the development of array_api_compat, which bridges the gaps and provides Array API compatibility for NumPy, CuPy, and PyTorch. Scikit-learn uses array_api_compat directly for its Array API support. There are two ways of enabling Array API in scikit-learn: through a global configuration or with a context manager. The following example uses a context manager:

import sklearn
import torch
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.datasets import make_classification

X_np, y_np = make_classification(random_state=0, n_samples=500_000, n_features=300)
X_torch_cpu, y_torch_cpu = torch.asarray(X_np), torch.asarray(y_np)

with sklearn.config_context(array_api_dispatch=True):
    lda_torch_cpu = LinearDiscriminantAnalysis()
    lda_torch_cpu.fit(X_torch_cpu, y_torch_cpu)
    predictions = lda_torch_cpu.predict(X_torch_cpu)

Note that the estimator returns an array from the input’s array library. The following example uses the global configuration and PyTorch tensors on GPUs:


# Enable Array API dispatch globally instead of inside a context manager
sklearn.set_config(array_api_dispatch=True)

X_torch_cuda = torch.asarray(X_np, device="cuda")
y_torch_cuda = torch.asarray(y_np, device="cuda")

lda_torch_cuda = LinearDiscriminantAnalysis()
lda_torch_cuda.fit(X_torch_cuda, y_torch_cuda)
predictions = lda_torch_cuda.predict(X_torch_cuda)


A common machine learning use case is to train a model on a GPU and then transfer it to a CPU for deployment. Scikit-learn provides a private utility function to handle this device movement:

from sklearn.utils._array_api import _estimator_with_converted_arrays

# Move each tensor attribute to host memory and convert it to a NumPy array
tensor_to_ndarray = lambda array: array.cpu().numpy()
lda_np = _estimator_with_converted_arrays(lda_torch_cuda, tensor_to_ndarray)
X_trans = lda_np.transform(X_np)

print(type(X_trans))
# <class 'numpy.ndarray'>
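`_estimator_with_converted_arrays` is private, so its behavior may change between releases. Conceptually, it produces a copy of the fitted estimator whose array attributes have been passed through the converter. The following is a simplified, hypothetical analogue, not scikit-learn’s implementation. It treats any attribute that is a NumPy array or exposes `__dlpack__` as an array; recent NumPy arrays, PyTorch tensors, and CuPy arrays all expose it.

```python
import copy

import numpy as np


def with_converted_arrays(estimator, converter):
    """Return a shallow copy of `estimator` with array attributes converted.

    Hypothetical helper for illustration; non-array attributes are kept as-is
    and the original estimator is left untouched.
    """
    new_estimator = copy.copy(estimator)
    for name, value in vars(estimator).items():
        if isinstance(value, np.ndarray) or hasattr(value, "__dlpack__"):
            setattr(new_estimator, name, converter(value))
    return new_estimator
```

With a PyTorch model, the converter would be `lambda array: array.cpu().numpy()`, as in the snippet above.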

You can learn more about scikit-learn’s Array API support in its documentation.


Adopting the Array API standard in scikit-learn was not a straightforward task and required us to overcome some challenges. In this section, we cover the two main challenges:

  • The Array API standard is a subset of NumPy’s API.
  • Compiled code only runs on CPUs because it is written in C, C++, or Cython.

The Array API standard is a subset of NumPy’s API

NumPy’s API is extensive and includes a vast number of operations. By design, the Array API standard is a subset of the NumPy API. For functionality to be included in the Array API standard, it must be widely used and implemented by most array libraries. Scikit-learn’s codebase was originally written against NumPy’s API, so adopting the Array API meant rewriting some NumPy functions in terms of Array API functions. For example, nanmin is not part of the Array API standard, so we had to implement it ourselves:

def _nanmin(X, axis=None):
    xp = get_array_namespace(X)
    if _is_numpy_namespace(xp):
        return xp.asarray(numpy.nanmin(X, axis=axis))

    # Implements nanmin in terms of the Array API standard
    X = xp.min(xp.where(mask, ...), axis=axis)

NumPy arrays are still dispatched to np.nanmin, while all other libraries go through an implementation that uses the Array API standard.
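The `...` in the snippet above elides the rest of scikit-learn’s implementation. Here is one self-contained way to express nanmin with standard functions only; it is our own sketch, not scikit-learn’s code. It replaces NaNs with +inf before taking the minimum. It then restores NaN for slices that contained only NaNs, since that is what `numpy.nanmin` returns for them.

```python
import math

import numpy as np


def nanmin_array_api(X, xp, axis=None):
    """NaN-ignoring minimum built only from Array API standard functions.

    Assumes a floating-point X. A production version would also create the
    fill values on X's device (the standard's asarray accepts device=).
    """
    mask = xp.isnan(X)
    # NaNs become +inf so they can never be selected by min.
    filled = xp.where(mask, xp.asarray(math.inf, dtype=X.dtype), X)
    result = xp.min(filled, axis=axis)
    # Slices that were entirely NaN would now report +inf; restore NaN.
    all_nan = xp.all(mask, axis=axis)
    return xp.where(all_nan, xp.asarray(math.nan, dtype=X.dtype), result)
```

Calling it with NumPy exercises the standard code path. In scikit-learn’s version, NumPy inputs skip that path and go straight to `numpy.nanmin`.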

There is an open issue in the Array API repository to discuss bringing nanmin into the standard. This process of introducing new functionality has been successful before. For example, take was introduced into the Array API standard in v2022.12 because we proposed it in the Array API repository. The group decided that selecting elements of an array with indices was a common operation, so they added take to the updated standard.
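To show why selecting with indices comes up so often, here is a hypothetical helper (our own name, not a scikit-learn function) that splits the rows of a dataset with the standard’s `take`. The standard expects the indices as an integer array rather than a Python list.

```python
import numpy as np


def train_test_rows(X, train_idx, test_idx, xp):
    """Select training and test rows of a 2-D array with the standard's take."""
    return xp.take(X, train_idx, axis=0), xp.take(X, test_idx, axis=0)
```

PyTorch’s own `torch.take` indexes into the flattened tensor and has no `axis` argument. Library-specific differences of this kind are what a compatibility layer like array_api_compat is there to smooth over.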

The Array API standard includes optional extensions for linear algebra and Fourier transforms. These extensions are commonly implemented across array libraries, but the standard does not require them. As a machine learning library, scikit-learn relies extensively on the linalg module. Under the Array API standard, NumPy arrays would go through numpy.linalg rather than scipy.linalg, and the two have subtle differences. We decided to be conservative and maintain backward compatibility by dispatching NumPy arrays to SciPy:

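# True if Array API is enabled and the input follows the standard
is_array_api_compliant = ...

if is_array_api_compliant:
    # Use the svd from the Array API linalg extension
    svd = xp.linalg.svd
else:
    # NumPy arrays keep going through SciPy, as in earlier releases
    svd = scipy.linalg.svd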

This implementation was a compromise to ensure that NumPy arrays go through the same code path as before and keep the same performance characteristics as in previous scikit-learn versions.
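Below is a self-contained sketch of the same dispatch. `get_svd` and `reconstruct` are hypothetical helpers of our own, and the `is_array_api_compliant` flag follows the snippet above. Because linalg is an optional extension, the sketch checks that the namespace provides it. SciPy is imported only on the path that needs it.

```python
import numpy as np


def get_svd(xp, is_array_api_compliant):
    """Pick an SVD implementation: the standard's linalg extension or SciPy."""
    if is_array_api_compliant:
        # linalg is optional in the standard, so verify it is available.
        if not hasattr(xp, "linalg"):
            raise TypeError("array namespace does not provide the linalg extension")
        return xp.linalg.svd
    import scipy.linalg  # historical NumPy/SciPy code path

    return scipy.linalg.svd


def reconstruct(X, xp, is_array_api_compliant):
    """Rebuild X from its thin SVD using whichever implementation was chosen."""
    svd = get_svd(xp, is_array_api_compliant)
    U, S, Vt = svd(X, full_matrices=False)
    return (U * S) @ Vt  # scale U's columns by the singular values
```

Both `xp.linalg.svd` in the standard and `scipy.linalg.svd` accept `full_matrices`, so the calling code stays the same on either path.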

Compiled Code

Scikit-learn contains a mix of Python code and compiled code written in Cython, C, and C++. For example, common machine learning algorithms such as random forests, linear models, and gradient-boosted trees all rely on compiled code. Since the Array API standard is a Python API, it is most practical to adapt scikit-learn’s Python code to use the standard. This restricts how much functionality can take advantage of the Array API.

Currently, scikit-learn contributors are experimenting with a plugin system to dispatch compiled code to external libraries. Even with compiled code involved, the Array API will still play a critical role in getting the plugin system up and running. For example, scikit-learn often performs computation in Python before and after dispatching to an external library:

X_prep, y_prep = preprocess_X_y(X, y)

model_state_array = plugin_dispatched_op(X_prep, y_prep)

self.model_state_post = post_process(model_state_array)

With plugins, the dispatched code will ingest and return arrays that follow the Array API standard. The standard therefore provides a common Python interface for preprocessing and postprocessing those arrays.
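The plugin system is experimental and its interface is not described here, so everything in this sketch is hypothetical. `preprocess_X_y` reuses the name from the snippet above, while the other functions are made-up stand-ins. The Python-side steps use only standard functions through `xp`. The “plugin” is an opaque callable that accepts and returns standard-conforming arrays; here it is a NumPy function computing per-class means.

```python
import numpy as np


def preprocess_X_y(X, y, xp):
    """Hypothetical Python-side preprocessing written against the standard."""
    X = xp.asarray(X, dtype=xp.float64)
    y = xp.asarray(y)
    X_mean = xp.mean(X, axis=0)
    return X - X_mean, y, X_mean


def numpy_centroids_plugin(X, y):
    """Stand-in for a compiled plugin: per-class means of the rows of X."""
    classes = np.unique(y)
    return np.stack([np.mean(X[y == c], axis=0) for c in classes])


def fit_centroids(X, y, xp, plugin_op):
    X_prep, y_prep, X_mean = preprocess_X_y(X, y, xp)
    model_state_array = plugin_op(X_prep, y_prep)
    # Post-processing in Python: undo the centering applied before dispatch.
    return model_state_array + X_mean
```

Only `plugin_op` would change when swapping in a different backend; the pre- and post-processing stay written once against the standard.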


In recent years, accelerators have been used increasingly for computation in many domains. The Array API standard gives Python libraries like scikit-learn access to these accelerators with the same source code. Depending on your code, adopting the Array API brings various challenges, but using it yields performance and compatibility benefits. If you observe any limitations, you are welcome to open issues on their issue tracker. For more information about the Array API, you can watch Aaron’s SciPy presentation, read the SciPy proceedings paper, or read the Array API documentation.

This work was made possible by Meta funding the effort, which enabled us to make progress on this topic quickly. This topic had been a longer-term goal on scikit-learn’s roadmap for quite some time. Similar steps are underway to incorporate the Array API standard into SciPy. As adoption of the Array API standard increases, we aim to make it easier for domain libraries and their users to make better use of their hardware for computation.

I want to thank Aaron Meurer, Matthew Barber, and Ralf Gommers for developing array_api_compat, which was a major part of this project’s success. I also want to thank Olivier Grisel and Tim Head for helping with this project and for continuing to push forward on expanding support.


