How to write fast (or slow) Python code

How to write fast (or slow) Python code

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

import numpy as np

Avoid large Python loops

def matMult(A, B):
    """Multiply two 2-D arrays with an explicit triple Python loop.

    This is a deliberately naive reference implementation: every scalar
    multiply-add is executed by the Python interpreter, which is what makes
    it slow compared with ``A @ B``.

    Args:
        A: 2-D array of shape ``(n, m)``.
        B: 2-D array of shape ``(m, k)``.

    Returns:
        A new NumPy array of shape ``(n, k)`` (allocated with ``np.zeros``)
        holding the matrix product of ``A`` and ``B``.

    Raises:
        ValueError: If the number of columns of ``A`` differs from the number
            of rows of ``B``.
    """
    n_rows = A.shape[0]
    n_inner = A.shape[1]
    n_cols = B.shape[1]

    # Without this check a B with extra rows would be silently truncated and
    # a wrong product returned; reject the mismatch the same way ``A @ B`` does.
    if n_inner != B.shape[0]:
        raise ValueError(
            f"matMult: inner dimensions do not match: A is {A.shape}, B is {B.shape}"
        )

    C = np.zeros((n_rows, n_cols))

    for i in range(n_rows):
        for j in range(n_cols):
            for k in range(n_inner):
                C[i, j] += A[i, k] * B[k, j]
    return C


# Benchmark problem size: A is N x M and B is M x K, so the product is N x K.
N, M, K = 100, 200, 300

# Random dense operands shared by both timing runs below.
A = np.random.rand(N, M)
B = np.random.rand(M, K)

%timeit matMult(A,B)
3.5 s ± 46.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Now let’s compare that to NumPy’s matrix-matrix multiplication with the same matrices:

%timeit A @ B
330 μs ± 3.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

NumPy is roughly 10,000x faster than our pure-Python triple loop (3.5 s versus 330 μs)!

When using JAX, avoid modifying array values in place

def q(x, y, Lx, Ly, kappa, coeff):
    """Evaluate the source term q at the point ``(x, y)``.

    The value is ``2 * coeff * (s * (1 - s) + t * (1 - t))`` with the
    normalised coordinates ``s = x / Lx`` and ``t = y / Ly``, so it is zero
    at the four corners of the rectangle ``[0, Lx] x [0, Ly]``.

    Args:
        x: x-coordinate of the evaluation point.
        y: y-coordinate of the evaluation point.
        Lx: Domain length in the x-direction.
        Ly: Domain length in the y-direction.
        kappa: Accepted to keep the call signature uniform; it does not
            enter the expression.
        coeff: Scaling coefficient of the source term.

    Returns:
        The source value at ``(x, y)``.
    """
    s = x / Lx
    t = y / Ly
    return (s * (1 - s) + t * (1 - t)) * (2 * coeff)


def heat_conduction_2D_slow(Lx=2.0, Ly=1.0, Nx=10, Ny=5, kappa=0.5, coeff=1.0, T0=0.0):
    """Solve a 2D heat-conduction problem, writing the linear system entry by entry.

    The rectangle ``[0, Lx] x [0, Ly]`` is discretised with
    ``(Nx + 1) x (Ny + 1)`` nodes. Every interior node gets a 9-point
    finite-difference stencil row with right-hand side ``q / kappa``; every
    node on the four edges gets an identity row with right-hand side ``T0``.
    Each matrix and vector entry is written with its own
    ``.at[...].set(...)`` call, each of which returns a new array; this is
    the slow variant of the assembly.

    Args:
        Lx: Domain length in the x-direction.
        Ly: Domain length in the y-direction.
        Nx: Number of grid intervals in the x-direction.
        Ny: Number of grid intervals in the y-direction.
        kappa: Divisor applied to the source term on the right-hand side.
        coeff: Scaling coefficient forwarded to ``q``.
        T0: Value imposed on all boundary nodes.

    Returns:
        A tuple ``(T_out, X, Y)``: the solution reshaped to
        ``(Nx + 1, Ny + 1)`` and the two coordinate grids produced by
        ``jnp.meshgrid``.

    Raises:
        AssertionError: If ``Lx / Nx`` is not equal to ``Ly / Ny``.
    """
    dx = Lx / Nx  # node spacing along x
    dy = Ly / Ny  # node spacing along y

    assert dx == dy, "dx must be equal to dy for the 9-point stencil to work"

    h = dx

    # Coordinate grids returned to the caller alongside the solution.
    x_nodes = jnp.linspace(0, Lx, Nx + 1)
    y_nodes = jnp.linspace(0, Ly, Ny + 1)
    Y, X = jnp.meshgrid(y_nodes, x_nodes)

    # Nodes are numbered row by row: node (ix, iy) has index iy * stride + ix.
    stride = Nx + 1
    n_nodes = (Nx + 1) * (Ny + 1)

    A = jnp.zeros([n_nodes, n_nodes])
    b = jnp.zeros(n_nodes)

    # Interior nodes: one stencil row each, written one entry at a time.
    for iy in range(1, Ny):
        for ix in range(1, Nx):
            row = iy * stride + ix
            stencil = (
                (row - stride - 1, -1 / (4 * h**2)),
                (row - stride, -1 / (2 * h**2)),
                (row - stride + 1, -1 / (4 * h**2)),
                (row - 1, -1 / (2 * h**2)),
                (row, 3 / h**2),
                (row + 1, -1 / (2 * h**2)),
                (row + stride - 1, -1 / (4 * h**2)),
                (row + stride, -1 / (2 * h**2)),
                (row + stride + 1, -1 / (4 * h**2)),
            )
            for col, weight in stencil:
                A = A.at[row, col].set(weight)

            source = q(ix * dx, iy * dy, Lx, Ly, kappa, coeff)
            b = b.at[row].set(source / kappa)

    # Boundary nodes: identity row with T0 on the right-hand side.
    # First the bottom (iy = 0) and top (iy = Ny) edges ...
    for ix in range(Nx + 1):
        for node in (ix, Ny * (Nx + 1) + ix):
            A = A.at[node, node].set(1.0)
            b = b.at[node].set(T0)
    # ... then the left (ix = 0) and right (ix = Nx) edges.
    for iy in range(Ny + 1):
        for node in (iy * (Nx + 1), iy * (Nx + 1) + Nx):
            A = A.at[node, node].set(1.0)
            b = b.at[node].set(T0)

    # Solve and undo the row-by-row numbering: column-major reshape puts
    # node (ix, iy) at T_out[ix, iy].
    T = jnp.linalg.solve(A, b)
    T_out = jnp.reshape(T, (Nx + 1, Ny + 1), order="F")
    return T_out, X, Y


%timeit heat_conduction_2D_slow(Nx=20, Ny=10)
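1.28 s ± 5.34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

JAX arrays cannot be modified in place: every `A.at[...].set(...)` call above returns a new array, so the loops perform one separate array update per entry. The faster version below first collects all row indices, column indices and values in Python lists, and then writes them with a single `.at[...].set(...)` call per array:
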
def heat_conduction_2D_fast(Lx=2.0, Ly=1.0, Nx=40, Ny=20, kappa=0.5, coeff=1.0, T0=0.0):
    """Solve a 2D heat-conduction problem, writing the linear system in one batch.

    The rectangle ``[0, Lx] x [0, Ly]`` is discretised with
    ``(Nx + 1) x (Ny + 1)`` nodes. The matrix starts as the identity and the
    right-hand side as a constant ``T0`` vector, so boundary rows need no
    further work. The 9-point stencil entries of all interior rows are
    gathered in Python lists and written with a single
    ``.at[...].set(...)`` call per array.

    Args:
        Lx: Domain length in the x-direction.
        Ly: Domain length in the y-direction.
        Nx: Number of grid intervals in the x-direction.
        Ny: Number of grid intervals in the y-direction.
        kappa: Divisor applied to the source term on the right-hand side.
        coeff: Scaling coefficient forwarded to ``q``.
        T0: Value imposed on all boundary nodes.

    Returns:
        A tuple ``(T_out, X, Y)``: the solution reshaped to
        ``(Nx + 1, Ny + 1)`` and the two coordinate grids produced by
        ``jnp.meshgrid``.

    Raises:
        AssertionError: If ``Lx / Nx`` is not equal to ``Ly / Ny``.
    """
    dx = Lx / Nx  # node spacing along x
    dy = Ly / Ny  # node spacing along y

    assert dx == dy, "dx must be equal to dy for the 9-point stencil to work"

    h = dx

    # Coordinate grids returned to the caller alongside the solution.
    x_nodes = jnp.linspace(0, Lx, Nx + 1)
    y_nodes = jnp.linspace(0, Ly, Ny + 1)
    Y, X = jnp.meshgrid(y_nodes, x_nodes)

    # Nodes are numbered row by row: node (ix, iy) has index iy * stride + ix.
    stride = Nx + 1
    n_nodes = (Nx + 1) * (Ny + 1)

    # Identity matrix and constant right-hand side already encode the
    # boundary rows; only interior rows are overwritten below.
    A = jnp.eye(n_nodes)
    b = jnp.ones(n_nodes) * T0

    # Stencil weights, listed in the same order as the column offsets.
    weights = [
        -0.25 / h**2,
        -0.5 / h**2,
        -0.25 / h**2,
        -0.5 / h**2,
        3.0 / h**2,
        -0.5 / h**2,
        -0.25 / h**2,
        -0.5 / h**2,
        -0.25 / h**2,
    ]
    offsets = [
        -stride - 1,
        -stride,
        -stride + 1,
        -1,
        0,
        1,
        stride - 1,
        stride,
        stride + 1,
    ]

    # Interior nodes in row-by-row order, and their matrix row indices.
    interior = [(ix, iy) for iy in range(1, Ny) for ix in range(1, Nx)]
    b_rows = [iy * stride + ix for ix, iy in interior]

    # Nine (row, column, value) triplets per interior node.
    a_rows = [row for row in b_rows for _ in offsets]
    a_cols = [row + offset for row in b_rows for offset in offsets]
    a_vals = weights * len(b_rows)
    b_vals = [q(ix * dx, iy * dy, Lx, Ly, kappa, coeff) / kappa for ix, iy in interior]

    # One batched update per array instead of one update per entry.
    A = A.at[jnp.array(a_rows), jnp.array(a_cols)].set(jnp.array(a_vals))
    b = b.at[jnp.array(b_rows)].set(jnp.array(b_vals))

    # Solve and undo the row-by-row numbering: column-major reshape puts
    # node (ix, iy) at T_out[ix, iy].
    T = jnp.linalg.solve(A, b)
    T_out = jnp.reshape(T, (Nx + 1, Ny + 1), order="F")
    return T_out, X, Y


%timeit heat_conduction_2D_fast(Nx=20, Ny=10)

# Solve the same problem with both implementations; keep only the temperature field.
T_slow = heat_conduction_2D_slow(Nx=20, Ny=10)[0]
T_fast = heat_conduction_2D_fast(Nx=20, Ny=10)[0]

# The two assembly strategies should agree to within floating-point tolerance.
results_match = jnp.allclose(T_slow, T_fast)
print(
    "The fast and slow versions of the code give the same results!"
    if results_match
    else "The fast and slow versions of the code give different results!"
)
19.7 ms ± 403 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
The fast and slow versions of the code give the same results!