# JAX bindings to FINUFFT

This package provides a JAX interface to (a subset of) the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library. Take a look at the FINUFFT docs for all the necessary definitions, conventions, and more information about the algorithms and their implementation. This package uses a low-level interface to expose the FINUFFT library directly to JAX's XLA backend, and it also implements differentiation rules for the transforms.
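As a quick reference, here is a summary of the one-dimensional transforms for $M$ nonuniform points $x_j$ and $N$ Fourier modes. The FINUFFT docs remain the authoritative statement and also give the 2D and 3D forms. The sign in the exponent is set by `iflag`, assuming it is passed through as FINUFFT's sign flag ($+$ for nonnegative values, $-$ otherwise).

```math
\begin{aligned}
\text{type 1:}\quad & f_k = \sum_{j=1}^{M} c_j \, e^{\pm i k x_j}, && k = -\lfloor N/2 \rfloor, \ldots, \lceil N/2 \rceil - 1, \\
\text{type 2:}\quad & c_j = \sum_{k=-\lfloor N/2 \rfloor}^{\lceil N/2 \rceil - 1} f_k \, e^{\pm i k x_j}, && j = 1, \ldots, M.
\end{aligned}
```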

## Included features

*This library is currently CPU-only, but GPU support is in the works using the cuFINUFFT library.*

Type 1 and 2 transforms are supported in one, two, and three dimensions. All of these functions support forward-mode, reverse-mode, and higher-order differentiation, as well as batching using `vmap`.
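Because `nufft1` and `nufft2` support these transformations, they should compose with `jax.grad` and `jax.vmap` like any other JAX function. The sketch below shows that pattern using a hypothetical pure-JAX stand-in, `direct_nufft1`, which evaluates the type 1 sum exactly in $O(NM)$ time so that it runs without the compiled extension. Assuming the `nufft1(N, c, x)` signature shown under Usage, swapping `nufft1` in for `direct_nufft1` should follow the same pattern. Note that `jax.grad` needs a real scalar output, so the loss reduces the complex modes to their total power.

```python
import jax
import jax.numpy as jnp

# Use double precision so results are comparable to the NumPy examples
jax.config.update("jax_enable_x64", True)


def direct_nufft1(n_modes, c, x, iflag=1):
    # Hypothetical stand-in for nufft1: an exact O(N * M) sum
    # f_k = sum_j c_j * exp(iflag * i * k * x_j), k = -(N // 2), ..., (N - 1) // 2
    k = jnp.arange(-(n_modes // 2), (n_modes + 1) // 2)
    phase = jnp.exp(iflag * 1j * k[:, None] * x[None, :])  # shape (N, M)
    return phase @ c


def loss(x, c):
    # jax.grad needs a real scalar, so reduce the complex modes to their total power
    return jnp.sum(jnp.abs(direct_nufft1(16, c, x)) ** 2)


kx, kr, ki = jax.random.split(jax.random.PRNGKey(0), 3)
x = 2 * jnp.pi * jax.random.uniform(kx, (50,))
c = jax.random.normal(kr, (50,)) + 1j * jax.random.normal(ki, (50,))

# Reverse-mode derivative of the loss with respect to the point locations
dloss_dx = jax.grad(loss)(x, c)

# Batch over 3 independent sets of strengths that share the same points
c_batch = jnp.stack([c, 2 * c, 1j * c])
f_batch = jax.vmap(lambda cb: direct_nufft1(16, cb, x))(c_batch)
print(dloss_dx.shape, f_batch.shape)
```

The direct sum builds an $N \times M$ matrix, so it is only practical for small problems; large $N$ and $M$ are exactly the regime FINUFFT is designed for.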

## Installation

*For now, only a source build is supported.*

For building, you should only need a recent version of Python (>3.6) and FFTW. At runtime, you'll need `numpy`, `scipy`, and `jax`. To set up such an environment, you can use `conda` (but you're welcome to use whatever workflow works for you!):

```
conda create -n jax-finufft -c conda-forge python=3.9 numpy scipy fftw
python -m pip install "jax[cpu]"
```

Then you can install from source using the following (don't forget the `--recursive` flag, because FINUFFT is included as a submodule):

```
git clone --recursive https://github.com/dfm/jax-finufft
cd jax-finufft
python -m pip install .
```

## Usage

This library provides two high-level functions (and these should be all that you generally need to interact with): `nufft1` and `nufft2` (for the two "types" of transforms). If you're already familiar with the Python interface to FINUFFT, *please note that the function signatures here are different*!

For example, here's how you can do a 1-dimensional type 1 transform:

```
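import numpy as np
from jax_finufft import nufft1

M = 100000  # number of nonuniform points
N = 200000  # number of output Fourier modes

# Nonuniform points in [0, 2*pi) and complex strengths at those points
x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)

# Type 1 transform: eps is the requested precision, iflag sets the exponent sign
f = nufft1(N, c, x, eps=1e-6, iflag=1)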
```

Note that `eps` and `iflag` are optional, and that (for good reason, I promise!) the order of the positional arguments is reversed from the `finufft` Python package.
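When trying out `nufft1` on small problems, it can help to compare against the exact (slow) sum that it approximates. The function below is a hypothetical NumPy helper, not part of `jax_finufft`. It takes its positional arguments in the same `(N, c, x)` order as `nufft1` and assumes the output modes are ordered from most negative to most positive frequency, which is FINUFFT's default ordering.

```python
import numpy as np


def direct_nufft1_1d(n_modes, c, x, iflag=1):
    """Exact type 1 sum: f[k] = sum_j c[j] * exp(sign * 1j * k * x[j]).

    Hypothetical reference helper; O(n_modes * len(x)), so only for small sizes.
    Output modes run k = -(n_modes // 2), ..., (n_modes - 1) // 2 in increasing order.
    """
    sign = 1 if iflag >= 0 else -1
    k = np.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return np.exp(sign * 1j * np.outer(k, x)) @ c


rng = np.random.default_rng(0)
M, N = 200, 64
x = 2 * np.pi * rng.uniform(size=M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
f_ref = direct_nufft1_1d(N, c, x, iflag=1)

# With jax_finufft installed, something like the following should agree to roughly eps:
# f = nufft1(N, c, x, eps=1e-6, iflag=1)
# rel_err = np.linalg.norm(f - f_ref) / np.linalg.norm(f_ref)
```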

The syntax for a 2- or 3-dimensional transform is:

```
f = nufft1((Nx, Ny), c, x, y) # 2D
f = nufft1((Nx, Ny, Nz), c, x, y, z) # 3D
```
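For intuition about the multi-dimensional case, here is a hypothetical exact 2D type 1 sum in NumPy (not part of the package), assuming the output axes follow the order of the `(Nx, Ny)` tuple and that each axis uses the same increasing mode ordering as in 1D. Because the exponential factorizes over dimensions, the sum can be written as a single contraction over the points.

```python
import numpy as np


def mode_indices(n):
    # Integer frequencies -(n // 2), ..., (n - 1) // 2 in increasing order
    return np.arange(-(n // 2), (n + 1) // 2)


def direct_nufft1_2d(n_modes, c, x, y, iflag=1):
    """Hypothetical exact 2D type 1 sum (small sizes only).

    f[k1, k2] = sum_j c[j] * exp(sign * 1j * (k1 * x[j] + k2 * y[j]))
    """
    nx, ny = n_modes
    sign = 1 if iflag >= 0 else -1
    ex = np.exp(sign * 1j * np.outer(mode_indices(nx), x))  # shape (nx, M)
    ey = np.exp(sign * 1j * np.outer(mode_indices(ny), y))  # shape (ny, M)
    # Contract over the points j; the result has shape (nx, ny)
    return np.einsum("aj,bj,j->ab", ex, ey, c)


rng = np.random.default_rng(1)
M = 100
x, y = 2 * np.pi * rng.uniform(size=(2, M))
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
f = direct_nufft1_2d((8, 6), c, x, y)
print(f.shape)
```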

The syntax for a type 2 transform is (also allowing optional `iflag` and `eps` parameters):

```
c = nufft2(f, x) # 1D
c = nufft2(f, x, y) # 2D
c = nufft2(f, x, y, z) # 3D
```
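To illustrate what `nufft2` computes, here is an exact (slow) NumPy version of the 1D type 2 sum. The helpers are hypothetical and not part of the package; the number of modes is read off the length of `f`, mirroring `nufft2(f, x)`. The check at the end uses a standard identity that makes a handy sanity test: a type 2 transform with the opposite sign is the adjoint of the type 1 transform.

```python
import numpy as np


def modes(n):
    # Integer frequencies -(n // 2), ..., (n - 1) // 2 in increasing order
    return np.arange(-(n // 2), (n + 1) // 2)


def direct_nufft1_1d(n_modes, c, x, iflag):
    # Exact type 1: f[k] = sum_j c[j] * exp(sign * i * k * x[j])
    sign = 1 if iflag >= 0 else -1
    return np.exp(sign * 1j * np.outer(modes(n_modes), x)) @ c


def direct_nufft2_1d(f, x, iflag):
    # Exact type 2: c[j] = sum_k f[k] * exp(sign * i * k * x[j])
    sign = 1 if iflag >= 0 else -1
    return np.exp(sign * 1j * np.outer(x, modes(f.shape[0]))) @ f


rng = np.random.default_rng(2)
M, N = 40, 16
x = 2 * np.pi * rng.uniform(size=M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
f = rng.standard_normal(N) + 1j * rng.standard_normal(N)

# Type 2 with the opposite sign is the adjoint of type 1: <T1 c, f> == <c, T2 f>
lhs = np.vdot(direct_nufft1_1d(N, c, x, iflag=1), f)
rhs = np.vdot(c, direct_nufft2_1d(f, x, iflag=-1))
print(np.isclose(lhs, rhs))
```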

## Similar libraries

- finufft: The "official" Python bindings to FINUFFT. A good choice if you're not already using JAX and if you don't need to differentiate through your transform.
- mrphys/tensorflow-nufft: TensorFlow bindings for FINUFFT and cuFINUFFT.

## License & attribution

This package, developed by Dan Foreman-Mackey, is licensed under the Apache License, Version 2.0, with the following copyright:

Copyright 2021 The Simons Foundation, Inc.

If you use this software, please cite the primary references listed in the FINUFFT docs.