Speeding up matrix multiplication ~ 5 million times

Speeding up matrix multiplication in PyTorch
Published: January 1, 2024

from pathlib import Path
import pickle, gzip, math, os, time, shutil, matplotlib as mpl, matplotlib.pyplot as plt
import torch
import numpy as np
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
np.set_printoptions(precision=2, linewidth=140)
torch.manual_seed(1)
weights = torch.randn(100,500) # flattened out mnist digit * 10 possible digits
bias = torch.zeros(500)
A = torch.randn(5,1000)
B = torch.randn(1000,500)
A.shape,B.shape
(torch.Size([5, 1000]), torch.Size([1000, 500]))
Ar,Ac = A.shape # n_rows * n_cols
Br,Bc = B.shape
(Ar,Ac),(Br,Bc)
((5, 1000), (1000, 500))
C = torch.zeros(Ar, Bc) # will store product of A and B
C.shape
torch.Size([5, 500])

A naive matmul for benchmarking

for i in range(Ar):
  for j in range(Bc):
    for k in range(Ac):
      C[i,j] += A[i,k] * B[k,j]
C
tensor([[ -47.17,  -54.96,  -10.88,  ...,   21.26,    4.28,   -4.78],
        [  77.24,   23.64,  -20.62,  ...,   13.62,  -26.03,   22.42],
        [-108.64,   27.16,   49.40,  ...,   27.56,    9.35,   16.46],
        [ -13.44,   45.17,   -2.30,  ...,  -79.52,  -58.32,  -13.37],
        [ -21.50,  -12.12,   55.95,  ...,   28.55,  -32.96,  -35.81]])
C.shape
torch.Size([5, 500])
  • Also have PyTorch produce the matrix product, to use as a reference for checking mathematical correctness.
reference = torch.mm(A,B)
reference
tensor([[ -47.17,  -54.96,  -10.87,  ...,   21.26,    4.28,   -4.78],
        [  77.24,   23.64,  -20.62,  ...,   13.62,  -26.03,   22.42],
        [-108.64,   27.16,   49.40,  ...,   27.56,    9.35,   16.46],
        [ -13.44,   45.17,   -2.30,  ...,  -79.52,  -58.32,  -13.37],
        [ -21.50,  -12.12,   55.95,  ...,   28.55,  -32.96,  -35.81]])
torch.allclose(C.to('cpu'), reference.to('cpu'),atol=1e-04, rtol=1e-04)
True
def matmul_naive(A,B):
  """
  Perform naive matrix multiplication of matrices A and B
  """
  Ar, Ac = A.shape
  Br, Bc = B.shape
  C = torch.zeros(Ar, Bc)
  for i in range(Ar):
    for j in range(Bc):
      for k in range(Ac):
        C[i,j] += A[i,k] * B[k,j]
  return C
torch.allclose(matmul_naive(A,B), reference,atol=1e-04, rtol=1e-04)
True
print(f"Performed O({Ar*Bc*Ac}) operations")
Performed O(2500000) operations
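The speedup factors printed throughout the rest of the post divide by a baseline of roughly 740,000 µs for this naive triple loop. As a minimal sketch of how that baseline might be measured (the loop is far too slow for %timeit's usual repeat counts, and the exact figure will vary by machine):

import time

# Time the naive triple loop once; repeating it many times is impractical.
start = time.perf_counter()
_ = matmul_naive(A, B)
elapsed_us = (time.perf_counter() - start) * 1e6
print(f"naive matmul: {elapsed_us:.0f} µs")  # ~740,000 µs on the run the ratios below assume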

Matmul with numba for speeding up the dot product

from numba import njit
from numpy import array
@njit
def dot(a,b):
  res = 0.
  for i in range(len(a)): res += a[i]*b[i]
  return res
dot(array([1,2,3]),array([2,0,1]))
5.0
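One thing to keep in mind: @njit compiles dot the first time it is called with a new argument type, so the first call is much slower than later ones. A small sketch to see that one-time cost (the numbers are illustrative):

import time

a, b = np.random.rand(1000), np.random.rand(1000)
t0 = time.perf_counter(); dot(a, b); t1 = time.perf_counter()   # first float64 call: includes JIT compilation
t2 = time.perf_counter(); dot(a, b); t3 = time.perf_counter()   # second call: compiled machine code only
print(f"first call: {(t1 - t0)*1e3:.1f} ms, second call: {(t3 - t2)*1e6:.1f} µs")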
def matmul_numba(A,B):
  """
  Perform matrix multiplication of matrices A and B, with the
  inner (dot) product compiled by numba
  """
  Ar, Ac = A.shape
  Br, Bc = B.shape
  C = torch.zeros(Ar, Bc)
  for i in range(Ar):
    for j in range(Bc):
      C[i,j] = dot(A[i,:],B[:,j])
  return C
matmul_numba(A.numpy(),B.numpy())
tensor([[ -47.17,  -54.96,  -10.88,  ...,   21.26,    4.28,   -4.78],
        [  77.24,   23.64,  -20.62,  ...,   13.62,  -26.03,   22.42],
        [-108.64,   27.16,   49.40,  ...,   27.56,    9.35,   16.46],
        [ -13.44,   45.17,   -2.30,  ...,  -79.52,  -58.32,  -13.37],
        [ -21.50,  -12.12,   55.95,  ...,   28.55,  -32.96,  -35.81]])
np.allclose(matmul_numba(A.numpy(),B.numpy()), reference.numpy())
False
  • This returns False most likely because np.allclose defaults to tighter tolerances (rtol=1e-05, atol=1e-08) than the torch.allclose checks above; the discrepancy is small float32 accumulation error, not an incorrect result.
Anp, Bnp = A.numpy(), B.numpy()
%timeit -n 50 matmul_numba(Anp,Bnp)
7.14 ms ± 31.3 μs per loop (mean ± std. dev. of 7 runs, 50 loops each)
print(f"Speedup factor over naive matmul: {740000/344:.0f}")
Speedup factor over naive matmul: 2151

Matmul with PyTorch inner product

def matmul_innertorch(A,B):
  Ar, Ac = A.shape
  Br, Bc = B.shape
  C = torch.zeros(Ar, Bc)
  for i in range(Ar):
    for j in range(Bc):
      C[i,j] = (A[i,:]*B[:,j]).sum()
  return C
matmul_innertorch(A,B)
tensor([[ -47.17,  -54.96,  -10.88,  ...,   21.26,    4.28,   -4.78],
        [  77.24,   23.64,  -20.62,  ...,   13.62,  -26.03,   22.42],
        [-108.64,   27.16,   49.40,  ...,   27.56,    9.35,   16.46],
        [ -13.44,   45.17,   -2.30,  ...,  -79.52,  -58.32,  -13.37],
        [ -21.50,  -12.12,   55.95,  ...,   28.55,  -32.96,  -35.81]])
torch.allclose(reference,matmul_innertorch(A, B),atol=1e-04, rtol=1e-04)
True
%timeit -n 50 _=matmul_innertorch(A, B)
18 ms ± 575 μs per loop (mean ± std. dev. of 7 runs, 50 loops each)
print(f"Speedup factor over naive matmul: {740000/882:.0f}")
Speedup factor over naive matmul: 839

Matmul with broadcasting

  • We can multiply each row of A by all columns of B simultaneously.
  • A[i,:] has shape [1000] while B has shape [1000, 500].
  • By adding an extra dimension to A[i,:] via A[i,:,None] (or A[i,:].unsqueeze(1)), we get shape [1000, 1] and are able to broadcast along the column dimension.
    - To put it loosely, each row of A is stretched out into a column and multiplied elementwise with B; summing along dim 0 then yields one row of the final product (see the small example after this list). We only need to iterate over the rows, reducing the number of for loops from three in the naive matmul to one.
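Here is a small concrete illustration of one broadcast step, using made-up 2×3 and 3×2 matrices (A_small and B_small exist only for this sketch):

A_small = torch.tensor([[1., 2., 3.],
                        [4., 5., 6.]])
B_small = torch.tensor([[1., 0.],
                        [0., 1.],
                        [1., 1.]])
row0 = A_small[0, :, None]           # shape [3, 1]: row 0 turned into a column
print((row0 * B_small).shape)        # torch.Size([3, 2]): row 0 broadcast across B_small's columns
print((row0 * B_small).sum(dim=0))   # tensor([4., 5.]) == (A_small @ B_small)[0]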
print(f"""B.shape: {B.shape} \n\nA[i,:].shape: {A[i,:].shape} \n
A[i,:].unsqueeze(1).shape: {A[i,:].unsqueeze(1).shape} \n
A[i,:,None].shape: {A[i,:,None].shape} \n
(A[i,:,None]*B).sum(dim=0).shape: {(A[i,:,None]*B).sum(dim=0).shape}""")
B.shape: torch.Size([1000, 500]) 

A[i,:].shape: torch.Size([1000]) 

A[i,:].unsqueeze(1).shape: torch.Size([1000, 1]) 

A[i,:,None].shape: torch.Size([1000, 1]) 

(A[i,:,None]*B).sum(dim=0).shape: torch.Size([500])
  • To convince yourself, uncomment the pdb line below and experiment with A[i,:].shape, A[i,:,None].shape, (A[i,:,None]*B).shape and (A[i,:,None]*B).sum(dim=0), keeping in mind that the final shape is 5×500 here.
def matmul_broadcast(A,B):
  Ar, Ac = A.shape
  Br, Bc = B.shape
  C = torch.zeros(Ar, Bc)
  for i in range(Ar):
    C[i] = (A[i,:,None]*B).sum(dim=0)
    #import pdb; pdb.set_trace()
  return C
matmul_broadcast(A,B)
tensor([[ -47.17,  -54.96,  -10.88,  ...,   21.26,    4.28,   -4.78],
        [  77.24,   23.64,  -20.62,  ...,   13.62,  -26.03,   22.42],
        [-108.64,   27.16,   49.40,  ...,   27.56,    9.35,   16.46],
        [ -13.44,   45.17,   -2.30,  ...,  -79.52,  -58.32,  -13.37],
        [ -21.50,  -12.12,   55.95,  ...,   28.55,  -32.96,  -35.81]])
torch.allclose(reference,matmul_broadcast(A, B),atol=1e-04, rtol=1e-04)
True
%timeit -n 50 _=matmul_broadcast(A, B)
358 μs ± 36 μs per loop (mean ± std. dev. of 7 runs, 50 loops each)
print(f"Speedup factor over naive matmul: {740000/180:.0f}")
Speedup factor over naive matmul: 4111

Matmul via Einstein summation

Einstein summation (einsum) is a compact representation for combining products and sums in a general way. The key rules are:

  • Repeating letters between input arrays means that values along those axes will be multiplied together.
  • Omitting a letter from the output means that values along that axis will be summed.
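A quick sketch of these two rules on small random tensors (not part of the benchmark):

x = torch.randn(3, 4)
y = torch.randn(4, 5)

# Repeated index k multiplies along that axis; omitting k from the output sums over it -> matrix product.
print(torch.allclose(torch.einsum('ik,kj->ij', x, y), x @ y))

# Omitting every index sums over everything: the total sum of elementwise products.
print(torch.allclose(torch.einsum('ij,ij->', x, x), (x * x).sum()))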
def matmul_einsum(A,B):
  return torch.einsum('ik,kj->ij',A,B)
%timeit -n 50 _=matmul_einsum(A,B)
The slowest run took 8.89 times longer than the fastest. This could mean that an intermediate result is being cached.
45.8 μs ± 49.8 μs per loop (mean ± std. dev. of 7 runs, 50 loops each)
print(f"Speedup factor over naive matmul: {740000/77:.0f}")
Speedup factor over naive matmul: 9610

Default PyTorch matmul on CPU

We can use PyTorch's built-in matrix multiplication directly, via torch.mm / torch.matmul or the @ operator.

torch.allclose(reference, A.to('cpu')@B.to('cpu'))
True
# Warm-up run
Acpu, Bcpu = A.to('cpu'),B.to('cpu')
%timeit -n 50 _=Acpu@Bcpu
12.5 μs ± 1.02 μs per loop (mean ± std. dev. of 7 runs, 50 loops each)
print(f"Speedup factor over naive matmul: {740000/65.6:.0f}")
Speedup factor over naive matmul: 11280
torch.cuda.is_available()
True

CUDA

  • Switching device from CPU to GPU
from pathlib import Path
import pickle, gzip, math, os, time, shutil, matplotlib as mpl, matplotlib.pyplot as plt
import torch
import numpy as np
from numba import cuda, float32
torch.cuda.is_available()
True
A = torch.randn(5,1000,device='cuda')
B = torch.randn(1000,5000,device='cuda')
reference = A@B
Ar,Ac = A.shape # n_rows * n_cols
Br,Bc = B.shape
C = torch.zeros(Ar, Bc) # will store product of A and B
C.shape
torch.Size([5, 5000])
def matmul_almost_cuda(grid,a,b,c):
  """Fills in one piece of the grid successfully"""
  i, j = grid
  if i < c.shape[0] and j < c.shape[1]:
    tmp = 0.
    for k in range(a.shape[1]): tmp += a[i,k]*b[k,j]
    c[i,j] = tmp
matmul_almost_cuda((0,0), A, B, C)
C
tensor([[-53.03,   0.00,   0.00,  ...,   0.00,   0.00,   0.00],
        [  0.00,   0.00,   0.00,  ...,   0.00,   0.00,   0.00],
        [  0.00,   0.00,   0.00,  ...,   0.00,   0.00,   0.00],
        [  0.00,   0.00,   0.00,  ...,   0.00,   0.00,   0.00],
        [  0.00,   0.00,   0.00,  ...,   0.00,   0.00,   0.00]])
  • Wiki: a compute kernel is a routine compiled for high throughput accelerators. Kernels correspond roughly to inner loops, doing a piece of the computation.
def launch_kernel(kernel,grid_x,grid_y,*args,**kwargs):
  for i in range(grid_x):
    for j in range(grid_y):
      kernel((i,j),*args,**kwargs)
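Before launching on the full matrices, a quick sanity check of launch_kernel on a tiny made-up problem (a sketch; the 2×3 / 3×2 sizes are arbitrary):

a = torch.randn(2, 3)
b = torch.randn(3, 2)
c = torch.zeros(2, 2)
launch_kernel(matmul_almost_cuda, 2, 2, a, b, c)   # one "kernel" call per output element
print(torch.allclose(c, a @ b, atol=1e-4, rtol=1e-4))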
  • The code below has the gist of what we want to do, but does not run in parallel
# Had to keyboard-interrupt; done sequentially like this, it runs in nowhere near a reasonable time
C = torch.zeros(Ar,Bc)
# grid_x <-> Ar, grid_y <-> Bc, args: A,B,C passed to matmul_cuda
launch_kernel(matmul_almost_cuda, Ar, Bc, A, B, C)
C
tensor([[-53.03,  -0.23, -24.30,  ..., -11.28, -24.68, -31.85],
        [-15.85,  13.23,  19.40,  ...,  -8.92,  28.37,  16.67],
        [ 53.55, -15.54,  13.36,  ...,  12.86,   1.21, -12.67],
        [ 37.49, -30.42,  13.84,  ..., -16.57,  24.10,  -3.87],
        [  3.07,  -0.18,   4.91,  ..., -25.19, -40.16, -32.40]])
  • To run the code in parallel, use CUDA
from numba import cuda
# Decorator below will compile into GPU code
@cuda.jit
def matmul_cuda(a,b,c):
  # cuda.grid(2) gives this thread's (i, j) position in the launch grid
  i, j = cuda.grid(2)
  if i < c.shape[0] and j < c.shape[1]:
    tmp = 0.
    for k in range(a.shape[1]): tmp += a[i,k]*b[k,j]
    c[i,j] = tmp
  • Each grid item is computed in parallel across the GPU's many processors
C.shape
torch.Size([5, 5000])
TPB = 16
C = torch.zeros(Ar,Bc,device='cuda')
Cr, Cc = C.shape
blockspergrid = (math.ceil(Cr/TPB), math.ceil(Cc/TPB))
blockspergrid
(1, 313)
A.device
device(type='cuda', index=0)
#matmul_cuda[blockspergrid, (TPB, TPB)](A,B,C)
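The launch above is left commented out; a minimal sketch of how it could be run and checked, assuming (as numba supports) that @cuda.jit kernels can consume torch CUDA tensors through the CUDA array interface:

# One thread per output element; synchronize before reading, since launches are asynchronous.
matmul_cuda[blockspergrid, (TPB, TPB)](A, B, C)
torch.cuda.synchronize()
print(torch.allclose(C, A @ B, atol=1e-3, rtol=1e-3))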

Another implementation

  • The higher-level idea is to unroll the two outer for loops into a single row-column calculation by replacing them with two dimensions of threads.

Details

  • Every thread corresponds to one output element.

  • Make a 2-D grid of threads so that each output element is addressed by a (row, column) pair.

  • Proper profiling appears to be more involved, see https://dev-discuss.pytorch.org/t/using-nsight-systems-to-profile-gpu-workload/59

from numba import cuda, float32
import numpy as np
import math

@cuda.jit
def fast_matmul(A, B, C):
    # Define an array in the shared memory
    # The size and type of the arrays must be known at compile time
    sA = cuda.shared.array(shape=(TPB, TPB), dtype=float32)
    sB = cuda.shared.array(shape=(TPB, TPB), dtype=float32)

    x, y = cuda.grid(2)

    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    bpg = cuda.gridDim.x    # blocks per grid

    # Each thread computes one element in the result matrix.
    # The dot product is chunked into dot products of TPB-long vectors.
    tmp = float32(0.)
    for i in range(bpg):
        # Preload data into shared memory
        sA[ty, tx] = 0
        sB[ty, tx] = 0
        if y < A.shape[0] and (tx+i*TPB) < A.shape[1]:
          sA[ty, tx] = A[y, tx + i * TPB]
        if x < B.shape[1] and (ty+i*TPB) < B.shape[0]:
          sB[ty, tx] = B[ty + i * TPB, x]

        # Wait until all threads finish preloading
        cuda.syncthreads()

        # Computes partial product on the shared memory
        for j in range(TPB):
            tmp += sA[ty, j] * sB[j, tx]

        # Wait until all threads finish computing
        cuda.syncthreads()
    if y < C.shape[0] and x < C.shape[1]:
        C[y, x] = tmp

A_h = np.random.rand(5,1000)
B_h = np.random.rand(1000,10)
C_h = np.zeros([5,10])

A = cuda.to_device(A_h)
B = cuda.to_device(B_h)
C = cuda.to_device(C_h)

#TPB must be an integer between 1 and 32
TPB = 32
threadsperblock = (TPB, TPB)
grid_y_max = max(A_h.shape[0],B_h.shape[0])
grid_x_max = max(A_h.shape[1],B_h.shape[1])
blockspergrid_x = math.ceil(grid_x_max / threadsperblock[0])
blockspergrid_y = math.ceil(grid_y_max / threadsperblock[1])
blockspergrid = (blockspergrid_x, blockspergrid_y)

fast_matmul[blockspergrid, threadsperblock](A, B, C)
C_h = C.copy_to_host()
print(C_h)
print(A_h@B_h)

The single for loop in the kernel is sufficient because it iterates over the TPB-wide tiles of the shared (inner) dimension, and with the grid configured above that tile count equals the number of blocks per grid along x (bpg). Each block contains TPB x TPB threads, where TPB is the threads-per-block parameter. Each thread computes one element of the result matrix C by performing a dot product of a row of A and a column of B. However, since A and B may be larger than the shared memory available to each block, the dot product is chunked into smaller segments of length TPB. This means that each thread loads multiple segments of data from global memory into shared memory and accumulates the partial products in a temporary variable (tmp). The final value of tmp is then stored in the corresponding element of C.

To illustrate this, let’s assume that TPB = 2 and bpg = 2. Suppose we have the following matrices A and B:

A = [[ 1,  2,  3,  4],
     [ 5,  6,  7,  8],
     [ 9, 10, 11, 12],
     [13, 14, 15, 16]]

B = [[17, 18, 19, 20],
     [21, 22, 23, 24],
     [25, 26, 27, 28],
     [29, 30, 31, 32]]

The result matrix C is:

C = A × B = [[ 250,  260,  270,  280],
             [ 618,  644,  670,  696],
             [ 986, 1028, 1070, 1112],
             [1354, 1412, 1470, 1528]]

The grid and block dimensions are:

dim3 dim_grid (2, 2, 1);  // 2 x 2 blocks per grid
dim3 dim_block (2, 2, 1); // 2 x 2 threads per block

The thread indices are:

(tx, ty) = (0, 0), (0, 1), (1, 0), (1, 1)           // thread indices within each block
(x, y) = (block_x*TPB + tx, block_y*TPB + ty)       // global indices from cuda.grid(2), ranging over 0..3 here

The code will execute as follows:

For i = 0, each thread loads the first segment of data from A and B to the shared memory sA and sB. For example, the thread with (x, y) = (0, 0) and (tx, ty) = (0, 0) will load A[0, 0] and B[0, 0] to sA[0, 0] and sB[0, 0], respectively. The shared memory arrays will look like this:

sA = [[1, 2],  [5, 6]]     // for blocks with block index y = 0 (output rows 0-1); sA does not depend on the block's x index
sA = [[9, 10], [13, 14]]   // for blocks with block index y = 1 (output rows 2-3)

sB = [[17, 18], [21, 22]]  // for blocks with block index x = 0 (output columns 0-1); sB does not depend on the block's y index
sB = [[19, 20], [23, 24]]  // for blocks with block index x = 1 (output columns 2-3)

After synchronizing the threads, each thread computes the partial product of the first segment using sA and sB. For example, the thread with (x, y) = (0, 0) and (tx, ty) = (0, 0) will compute tmp = sA[0, 0] * sB[0, 0] + sA[0, 1] * sB[1, 0] = 1 * 17 + 2 * 21 = 59. The other threads will compute similar values for their corresponding elements.

For i = 1, each thread loads the second segment of data from A and B to the shared memory sA and sB, overwriting the previous values. For example, the thread with (x, y) = (0, 0) and (tx, ty) = (0, 0) will load A[0, 2] and B[2, 0] to sA[0, 0] and sB[0, 0], respectively. The shared memory arrays will look like this:

sA = [[3, 4],   [7, 8]]    // for blocks with block index y = 0
sA = [[11, 12], [15, 16]]  // for blocks with block index y = 1

sB = [[25, 26], [29, 30]]  // for blocks with block index x = 0
sB = [[27, 28], [31, 32]]  // for blocks with block index x = 1

After synchronizing the threads, each thread computes the partial product of the second segment using sA and sB, and adds it to the previous value of tmp. For example, the thread with (x, y) = (0, 0) and (tx, ty) = (0, 0) will compute tmp = tmp + sA[0, 0] * sB[0, 0] + sA[0, 1] * sB[1, 0] = 59 + 3 * 25 + 4 * 29 = 250. The other threads will compute similar values for their corresponding elements.

Finally, each thread stores the final value of tmp in the result matrix C.
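The arithmetic of this small worked example is easy to verify directly with NumPy (purely a check, unrelated to the kernel):

A4 = np.arange(1, 17, dtype=np.float64).reshape(4, 4)   # the 4x4 A from the example
B4 = np.arange(17, 33, dtype=np.float64).reshape(4, 4)  # the 4x4 B from the example
print(A4 @ B4)   # first row: [250. 260. 270. 280.], matching the tiled computation above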
# Warm up
for _ in range(10):
    fast_matmul[blockspergrid, threadsperblock](A, B, C)

torch.cuda.synchronize()

# Benchmark
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

times = []
for _ in range(50):
  start.record()
  fast_matmul[blockspergrid, threadsperblock](A, B, C)
  end.record()

  # Wait for all operations to finish
  torch.cuda.synchronize()

  # Append the time in microseconds
  times.append(start.elapsed_time(end) * 1000)

# Print the average time
print(f"Average time elapsed: {np.mean(times)} µs\n\
Standard deviation of time elapsed: {np.std(times)} µs")
Average time elapsed: 949.6985602378845 µs
Standard deviation of time elapsed: 33.04722920348339 µs
%timeit -n 50 fast_matmul[blockspergrid, threadsperblock](A, B, C)
72.3 µs ± 31.9 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
  • This is much lower than the CUDA-event measurement above because kernel launches are asynchronous: without a torch.cuda.synchronize() inside the timed statement, %timeit mostly measures launch overhead rather than the kernel's execution time.

Default PyTorch matmul on GPU

# Warm-up run
%timeit -n 50 _=A@B
17.9 µs ± 6.4 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
print(f"Speedup factor over naive matmul: {740000/18.3:.0f}")
Speedup factor over naive matmul: 40437
C.shape
torch.Size([5, 10])
TPB = 16
rr,rc = C_h.shape
blockspergrid = (math.ceil(rr / TPB), math.ceil(rc / TPB))
blockspergrid
(1, 1)
# fast_matmul[blockspergrid, (TPB,TPB)](A, B, C)
# r = C.copy_to_host()
# np.allclose(r, A_h@B_h)
type(C),type(r)
(torch.Tensor, numpy.ndarray)
%%timeit -n 10
fast_matmul[blockspergrid, (TPB,TPB)](A, B, C)
r = C.copy_to_host()
/usr/local/lib/python3.10/dist-packages/numba/cuda/dispatcher.py:536: NumbaPerformanceWarning: Grid size 1 will likely result in GPU under-utilization due to low occupancy.
  warn(NumbaPerformanceWarning(msg))
The slowest run took 136.47 times longer than the fastest. This could mean that an intermediate result is being cached.
8.26 ms ± 19.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Torch CUDA copies of the same matrices, for comparison against the handwritten kernel
m1c, m2c = torch.from_numpy(A_h).cuda(), torch.from_numpy(B_h).cuda()
r=(m1c@m2c).cpu()
%timeit -n 10 r=(m1c@m2c).cpu()

Our broadcasting version was >500ms, and our CUDA version is around 0.5ms, which is another 1000x improvement compared to broadcasting. So our total speedup is around 5 million times!