Tiling and reuse: matrix multiplication

Matrix multiplication is one of the most important GPU workloads because it underpins many computational problems. With the increased popularity of deep neural networks (DNNs), matrix multiplication has become vital: it is found in the most computationally intensive layers of neural networks, such as fully connected and convolutional layers. Because GPUs are optimized to execute matrix multiplication efficiently, it can pay to reformulate a problem as matrix multiplication, so that it outperforms an iterative implementation of the original problem.

Matrix multiplication fundamentals

A matrix multiplication operation involves three matrices: A, B, and C. Matrices A and B are the input matrices, and matrix C is the output matrix. Matrix A has m rows and k columns, and matrix B has k rows and n columns; matrix multiplication requires the number of columns in A to equal the number of rows in B. Output matrix C then has m rows and n columns. Each element of C is the sum of the k products of corresponding elements in a row of A and a column of B, as formally represented in the figure below:

Diagram demonstrating matrix multiplication
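Written out with zero-based indices, matching the row-major indexing a[row * k + i] and b[i * n + col] used by the kernels in this section (the symbols i, j, and l are chosen here for illustration), the definition is:

```latex
C_{ij} = \sum_{l=0}^{k-1} A_{il}\, B_{lj}, \qquad 0 \le i < m,\quad 0 \le j < n
```

Element A_{il} is stored at a[i * k + l], B_{lj} at b[l * n + j], and C_{ij} at out[i * n + j].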

An iterative CPU implementation uses three nested loops over m, n, and k. Each innermost iteration performs one multiplication and one accumulation, so a total of 2 × m × n × k floating-point operations is required.
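As a reference point for the GPU kernels, here is a minimal C++ sketch of the three-loop CPU implementation, assuming row-major storage as in the kernels of this section. The function name matmul_reference is hypothetical, and the operation counter exists only to confirm the 2 × m × n × k figure.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// CPU reference for out = a * b with row-major storage:
// a is m x k, b is k x n, out is m x n.
// Returns the number of floating-point operations performed
// (one multiply and one add per innermost iteration).
std::uint64_t matmul_reference(const std::vector<float>& a,
                               const std::vector<float>& b,
                               std::vector<float>& out,
                               int m, int n, int k)
{
    assert(a.size() == static_cast<std::size_t>(m) * k);
    assert(b.size() == static_cast<std::size_t>(k) * n);
    out.assign(static_cast<std::size_t>(m) * n, 0.0f);

    std::uint64_t flops = 0;
    for (int row = 0; row < m; ++row) {          // loop over m
        for (int col = 0; col < n; ++col) {      // loop over n
            float sum = 0.0f;
            for (int i = 0; i < k; ++i) {        // loop over k
                sum += a[row * k + i] * b[i * n + col];
                flops += 2;                      // one multiply + one add
            }
            out[row * n + col] = sum;
        }
    }
    return flops;
}
```

For m = n = k = 1024, the count is 2 × 1024³, about 2.15 × 10⁹ operations.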

By comparison, a GPU implementation is embarrassingly parallel, because every output element can be computed independently. The GPU kernel parallelizes the m and n loops by assigning one thread to each element of the output matrix. Thus, you need a total of m × n threads, each running a loop of k iterations.
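When m or n is not a multiple of the workgroup dimensions, the grid is rounded up, so slightly more than m × n threads are launched. The following sketch computes that launch shape on the host, assuming 2D workgroups with x along the columns of C and y along its rows, as in the naive kernel below. LaunchShape, ceil_div, naive_launch_shape, and launched_threads are hypothetical helpers, not HIP API.

```cpp
#include <cassert>

// Hypothetical launch description: number of workgroups along each axis.
struct LaunchShape {
    int grid_x; // workgroups along n (columns of C)
    int grid_y; // workgroups along m (rows of C)
};

// Integer division rounded up, so every output element gets a thread.
constexpr int ceil_div(int value, int divisor)
{
    return (value + divisor - 1) / divisor;
}

// One thread per output element, grouped into block_x x block_y workgroups.
LaunchShape naive_launch_shape(int m, int n, int block_x, int block_y)
{
    assert(m > 0 && n > 0 && block_x > 0 && block_y > 0);
    return {ceil_div(n, block_x), ceil_div(m, block_y)};
}

// Threads actually launched. This exceeds m * n whenever m or n is not a
// multiple of the workgroup size, which is why the kernel checks bounds.
long long launched_threads(const LaunchShape& s, int block_x, int block_y)
{
    return static_cast<long long>(s.grid_x) * s.grid_y * block_x * block_y;
}
```

For a 100 × 70 output and 16 × 16 workgroups, this gives a 5 × 7 grid and 8,960 launched threads for 7,000 output elements; the extra threads are filtered out by the kernel's bounds check.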

Naive implementation and memory access patterns

Although the naive implementation produces the correct result, its performance is far from ideal. Profiling the kernel shows that it reads several times more data from the GPU's Dynamic Random Access Memory (DRAM) than the combined size of the input matrices. To understand why, examine how the threads read the input matrices.

In the naive implementation, each thread in the kernel reads an entire row of matrix A and an entire column of matrix B. As a result, the same data item is read by many threads, often from different wavefronts: every element of A is read by the n threads that compute the corresponding row of C, and every element of B by the m threads that compute the corresponding column. All of these reads go to the main memory system. Ideally, the repeated references to the same location would be serviced by the cache, but given the limited cache space and the large number of threads involved, they place inordinate pressure on the caches. As the matrix size increases, the caches can no longer hold the working set, data is evicted before it can be reused, and reads fall through to main memory, causing significant delays.
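The scale of the reuse problem can be estimated by counting requests. The sketch below, with hypothetical helper names, counts the element reads the naive kernel issues and compares them with the number of distinct input elements. It counts requests only; how many of them actually reach DRAM depends on cache hits.

```cpp
#include <cassert>
#include <cstdint>

// Element reads requested by the naive kernel: each of the m * n threads
// reads k elements of A and k elements of B.
std::uint64_t naive_requested_reads(std::uint64_t m, std::uint64_t n, std::uint64_t k)
{
    return 2 * m * n * k;
}

// Distinct input elements: the combined size of A (m x k) and B (k x n).
std::uint64_t input_elements(std::uint64_t m, std::uint64_t n, std::uint64_t k)
{
    return m * k + k * n;
}

// Average number of times each input element is requested
// (for square N x N matrices this is N).
std::uint64_t average_requests_per_element(std::uint64_t m, std::uint64_t n, std::uint64_t k)
{
    const std::uint64_t elements = input_elements(m, n, k);
    assert(elements > 0);
    return naive_requested_reads(m, n, k) / elements;
}
```

For m = n = k = 1024, the kernel issues 2³¹ (about 2.1 × 10⁹) element reads against 2²¹ distinct input elements, so each element is requested 1,024 times on average.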

Ideally, you want more control over caching: once data is loaded into the cache, it should be reused by as many threads as possible before it is evicted. In CPU programming, caches are transparent to the programmer and typically not under their control. In contrast, GPUs provide Local Data Store (LDS) memory, which is effectively a programmable L1 cache. In HIP, LDS is allocated by declaring variables with the __shared__ qualifier.

#include <hip/hip_runtime.h>

// Naive matrix multiplication kernel: one thread per element of out.
// a is m x k, b is k x n, out is m x n, all stored row-major.
__global__ void matrix_multiply_naive(float *a, float *b, float *out, int m, int n, int k)
{
    int gid_x = blockDim.x * blockIdx.x + threadIdx.x; // column of out
    int gid_y = blockDim.y * blockIdx.y + threadIdx.y; // row of out

    // Threads in a rounded-up grid that fall outside the matrix do nothing
    if (gid_x < n && gid_y < m)
    {
        float sum = 0.0f;
        for (int i = 0; i < k; ++i)
        {
            sum += a[gid_y * k + i] * b[i * n + gid_x];
        }
        out[gid_y * n + gid_x] = sum;
    }
}
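To check the thread indexing without a GPU, the kernel body can be replayed on the host. This sketch (the function emulate_naive_kernel is hypothetical) loops over workgroups and threads in place of the hardware scheduler and runs the same per-thread code, including the bounds check that discards the padding threads of a rounded-up grid. Running threads sequentially is valid here only because no thread depends on another thread's results.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side replay of matrix_multiply_naive. The outer loop pair plays the
// grid (blockIdx), the inner pair plays the workgroup (threadIdx).
void emulate_naive_kernel(const std::vector<float>& a, const std::vector<float>& b,
                          std::vector<float>& out, int m, int n, int k,
                          int block_x, int block_y)
{
    assert(block_x > 0 && block_y > 0);
    assert(a.size() == static_cast<std::size_t>(m) * k);
    assert(b.size() == static_cast<std::size_t>(k) * n);
    out.assign(static_cast<std::size_t>(m) * n, 0.0f);

    const int grid_x = (n + block_x - 1) / block_x; // rounded up, as on the GPU
    const int grid_y = (m + block_y - 1) / block_y;
    for (int by = 0; by < grid_y; ++by)
        for (int bx = 0; bx < grid_x; ++bx)
            for (int ty = 0; ty < block_y; ++ty)
                for (int tx = 0; tx < block_x; ++tx) {
                    int gid_x = block_x * bx + tx; // column of out
                    int gid_y = block_y * by + ty; // row of out
                    if (gid_x < n && gid_y < m) {  // skip padding threads
                        float sum = 0.0f;
                        for (int i = 0; i < k; ++i)
                            sum += a[gid_y * k + i] * b[i * n + gid_x];
                        out[gid_y * n + gid_x] = sum;
                    }
                }
}
```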

LDS tiling optimization

To improve matrix multiplication performance and make better use of LDS, you can apply a common GPU programming technique called “tiling.” Tiling uses all threads in a workgroup to collectively process one portion of the data before moving on to the next portion of the problem. Each workgroup first loads a tile of matrix A and a tile of matrix B into the LDS, then multiplies the loaded tiles. Next, the workgroup moves to the next pair of tiles along the k dimension, continuing to accumulate the multiplication results.

The LDS tiling-based kernel uses LDS memory to dramatically reduce global memory accesses. The kernel's interface is identical to the naive implementation. The only additional requirement is a tile size parameter, base_tile_size, which sets the number of elements along each of the two dimensions of a tile. The workgroup size must match the tile size (base_tile_size × base_tile_size threads), because each thread loads one element of each tile and computes one output element; any other workgroup size produces incorrect results or out-of-bounds LDS accesses.

#include <hip/hip_runtime.h>

// LDS tiling parameters
constexpr int base_tile_size = 16; // LDS tile dimension

// LDS tiling kernel.
// Launch with base_tile_size x base_tile_size threads per workgroup and a grid of
// ceil(n / base_tile_size) x ceil(m / base_tile_size) workgroups.
__global__ void matrix_multiply_lds_tiling(float *a, float *b, float *out, int m, int n, int k)
{
    __shared__ float tilea[base_tile_size][base_tile_size]; // current tile of A in LDS
    __shared__ float tileb[base_tile_size][base_tile_size]; // current tile of B in LDS

    int tx = threadIdx.x;
    int ty = threadIdx.y;

    int row = blockIdx.y * base_tile_size + ty;
    int col = blockIdx.x * base_tile_size + tx;

    float sum = 0.0f;

    int numtiles = (k + base_tile_size - 1) / base_tile_size;
    for (int t = 0; t < numtiles; t++)
    {
        // Each thread loads one element of the A tile (zero outside the matrix)
        int acol = t * base_tile_size + tx;
        if (row < m && acol < k)
        {
            tilea[ty][tx] = a[row * k + acol];
        }
        else
        {
            tilea[ty][tx] = 0.0f;
        }

        // Each thread loads one element of the B tile (zero outside the matrix)
        int brow = t * base_tile_size + ty;
        if (brow < k && col < n)
        {
            tileb[ty][tx] = b[brow * n + col];
        }
        else
        {
            tileb[ty][tx] = 0.0f;
        }

        // Wait until the whole tile is in LDS before reading it
        __syncthreads();

        for (int i = 0; i < base_tile_size; i++)
        {
            sum += tilea[ty][i] * tileb[i][tx];
        }

        // Wait until every thread is done before the tiles are overwritten
        __syncthreads();
    }

    if (row < m && col < n)
    {
        out[row * n + col] = sum;
    }
}

Compared with the naive kernel, this kernel uses two nested loops rather than one. The outer loop steps through the tiles along the k dimension, and the inner loop accumulates the products within the current tile. The kernel declares two LDS buffers, tilea and tileb, for the tiles of matrices A and B. In each iteration of the outer loop, every thread loads one element of each tile from main memory into the LDS, writing 0.0f for positions that fall outside the matrices; most of the implementation's complexity lies in this index calculation and boundary checking. The inner loop then performs multiplication and accumulation in the same way as the naive implementation, but reads its operands from the LDS. Barriers (__syncthreads()) are needed before and after the inner loop: the first guarantees that the entire tile is loaded before any thread reads it, and the second guarantees that all threads have finished using the tile before the next iteration overwrites it. Finally, each thread stores its result in the output matrix.
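The same host-side replay works for the tiled kernel, provided the barriers are respected. In this sketch (emulate_lds_tiling is hypothetical, and the runtime parameter tile stands in for base_tile_size), every thread of a workgroup finishes its loads before any thread computes, and all computation on a tile finishes before the next tile is loaded; this is the ordering that the two __syncthreads() calls enforce on the GPU. The vectors tilea, tileb, and sum stand in for the LDS arrays and the per-thread accumulators.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side emulation of matrix_multiply_lds_tiling with a tile x tile workgroup.
void emulate_lds_tiling(const std::vector<float>& a, const std::vector<float>& b,
                        std::vector<float>& out, int m, int n, int k, int tile)
{
    assert(tile > 0);
    assert(a.size() == static_cast<std::size_t>(m) * k);
    assert(b.size() == static_cast<std::size_t>(k) * n);
    out.assign(static_cast<std::size_t>(m) * n, 0.0f);
    std::vector<float> tilea(tile * tile), tileb(tile * tile); // "LDS"
    std::vector<float> sum(tile * tile);                       // one per thread
    const int numtiles = (k + tile - 1) / tile;

    for (int by = 0; by < (m + tile - 1) / tile; ++by)
        for (int bx = 0; bx < (n + tile - 1) / tile; ++bx) {
            std::fill(sum.begin(), sum.end(), 0.0f);
            for (int t = 0; t < numtiles; ++t) {
                // Phase 1: every thread loads one element of each tile (zero-padded).
                for (int ty = 0; ty < tile; ++ty)
                    for (int tx = 0; tx < tile; ++tx) {
                        int row = by * tile + ty, col = bx * tile + tx;
                        int acol = t * tile + tx, brow = t * tile + ty;
                        tilea[ty * tile + tx] = (row < m && acol < k) ? a[row * k + acol] : 0.0f;
                        tileb[ty * tile + tx] = (brow < k && col < n) ? b[brow * n + col] : 0.0f;
                    }
                // First barrier: all loads above complete before any compute below.
                // Phase 2: every thread multiplies a row of tilea by a column of tileb.
                for (int ty = 0; ty < tile; ++ty)
                    for (int tx = 0; tx < tile; ++tx)
                        for (int i = 0; i < tile; ++i)
                            sum[ty * tile + tx] += tilea[ty * tile + i] * tileb[i * tile + tx];
                // Second barrier: all computes finish before the next load overwrites the tiles.
            }
            for (int ty = 0; ty < tile; ++ty)
                for (int tx = 0; tx < tile; ++tx) {
                    int row = by * tile + ty, col = bx * tile + tx;
                    if (row < m && col < n) out[row * n + col] = sum[ty * tile + tx];
                }
        }
}
```

Merging the two phases into a single per-thread loop would let a thread read tile slots that other threads have not written yet, which is the race the barriers prevent on the GPU.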

LDS tiling is an effective way to reduce DRAM access and improve performance. Performance improvements with tiling depend on the specific GPU architecture and matrix dimensions, but significant speedups over the naive implementation are commonly observed on AMD GPUs.
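The reduction in global memory traffic can be quantified by counting requests. The sketch below (helper names hypothetical) counts the elements requested from global memory by the tiled kernel, treating zero-padded slots as loads, so it slightly overestimates for sizes that are not multiples of the tile. When m, n, and k are multiples of base_tile_size, the count simplifies to 2 × m × n × k / base_tile_size, a reduction by a factor of base_tile_size relative to the naive kernel.

```cpp
#include <cassert>
#include <cstdint>

// Global-memory element requests of the LDS tiling kernel: every workgroup
// loads one tile x tile block of A and one of B per outer-loop iteration.
// Zero-padded slots are counted as loads, so this is an upper bound.
std::uint64_t lds_tiled_loads(std::uint64_t m, std::uint64_t n, std::uint64_t k,
                              std::uint64_t tile)
{
    assert(tile > 0);
    const std::uint64_t blocks = ((m + tile - 1) / tile) * ((n + tile - 1) / tile);
    const std::uint64_t numtiles = (k + tile - 1) / tile;
    return blocks * numtiles * 2 * tile * tile;
}

// Naive kernel for comparison: m * n threads, each reading 2 * k elements.
std::uint64_t naive_loads(std::uint64_t m, std::uint64_t n, std::uint64_t k)
{
    return 2 * m * n * k;
}
```

With 1024 × 1024 matrices and base_tile_size = 16, the count drops from 2³¹ to 2²⁷ element requests.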

Register tiling: The next optimization step

After implementing LDS tiling, the next optimization step is register tiling. While LDS tiling reduces global memory traffic by caching tiles in shared memory, register tiling further reduces memory traffic by keeping frequently accessed data in thread-local registers. This approach eliminates the need for many shared memory reads within the inner computation loop.

Register tiling works by having each thread compute multiple output elements and load the necessary input data into registers. The multiply-accumulate loop then operates only on registers, with shared memory accessed just to fill them. This technique is particularly effective because:

  • Register access is significantly faster than shared memory access

  • Each thread has private registers, so register reads cannot cause bank conflicts

  • Synchronization overhead between threads is reduced

  • Better instruction-level parallelism can be achieved
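One way to see why register tiling helps is to count LDS reads per multiply-add in the inner loop. A thread that computes a tile_m × tile_n block reads tile_m values of A and tile_n values of B per k step, and performs tile_m × tile_n multiply-adds. The helper below is hypothetical and counts reads as written in the source, ignoring any further reuse the compiler might find.

```cpp
#include <cassert>

// LDS reads per multiply-add for one k step of a thread that owns a
// tile_m x tile_n block of outputs: (tile_m + tile_n) reads feed
// tile_m * tile_n multiply-adds.
double lds_reads_per_fma(int tile_m, int tile_n)
{
    assert(tile_m > 0 && tile_n > 0);
    return static_cast<double>(tile_m + tile_n) / (tile_m * tile_n);
}
```

The LDS tiling kernel corresponds to a 1 × 1 thread tile (2 reads per multiply-add), while a 4 × 4 thread tile needs 0.5, a fourfold reduction in LDS reads.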

The register tiling kernel uses a more complex thread-to-data mapping: each thread computes a small tile of output elements (4 × 4 in this example, set by thread_tile_m and thread_tile_n) and loads the required input data into registers before performing the multiplication.
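The index arithmetic is easiest to trust after checking it on the host. This sketch copies the kernel's constants and tid-to-output mapping into a hypothetical RegisterTileMapping struct, templated on warp_size because AMD GPUs use 64-wide or 32-wide wavefronts, and verifies that the threads of a workgroup cover the block's output tile exactly once.

```cpp
#include <cassert>
#include <vector>

template <int warp_size>
struct RegisterTileMapping {
    // Constants copied from matrix_multiply_register_tiling
    static constexpr int thread_tile_m = 4, thread_tile_n = 4;
    static constexpr int warp_threads_n = thread_tile_n;
    static constexpr int block_warps_m = 2, block_warps_n = 2;
    static constexpr int warp_threads_m = warp_size / warp_threads_n;
    static constexpr int warp_tile_m = warp_threads_m * thread_tile_m;
    static constexpr int warp_tile_n = warp_threads_n * thread_tile_n;
    static constexpr int block_threads = warp_size * block_warps_m * block_warps_n;
    static constexpr int block_tile_m = block_warps_m * warp_tile_m;
    static constexpr int block_tile_n = block_warps_n * warp_tile_n;

    // Top-left corner (within the block tile) of the outputs owned by tid.
    static void origin(int tid, int& row, int& col)
    {
        assert(tid >= 0 && tid < block_threads);
        int warp_id = tid / warp_size, lane_id = tid % warp_size;
        int warp_row = warp_id / block_warps_n, warp_col = warp_id % block_warps_n;
        int lane_row = lane_id / warp_threads_n, lane_col = lane_id % warp_threads_n;
        row = warp_row * warp_tile_m + lane_row * thread_tile_m;
        col = warp_col * warp_tile_n + lane_col * thread_tile_n;
    }

    // True if every element of the block tile is owned by exactly one thread.
    static bool covers_block_tile_exactly()
    {
        std::vector<int> owners(block_tile_m * block_tile_n, 0);
        for (int tid = 0; tid < block_threads; ++tid) {
            int row = 0, col = 0;
            origin(tid, row, col);
            for (int i = 0; i < thread_tile_m; ++i)
                for (int j = 0; j < thread_tile_n; ++j) {
                    int r = row + i, c = col + j;
                    if (r >= block_tile_m || c >= block_tile_n) return false;
                    ++owners[r * block_tile_n + c];
                }
        }
        for (int count : owners)
            if (count != 1) return false;
        return true;
    }
};
```

With warp_size = 64, a workgroup has 256 threads covering a 128 × 32 block of C; with warp_size = 32, it has 128 threads covering 64 × 32.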

#include <hip/hip_runtime.h>

// Register tiling parameters
// Assumption: warp_size is defined here so the example compiles. 64 matches CDNA
// wavefronts; the "warps" below are only a logical grouping of threads to data.
constexpr int warp_size = 64;
constexpr int thread_tile_m = 4; // Each thread computes 4x4 output
constexpr int thread_tile_n = 4;
constexpr int warp_threads_n = thread_tile_n; // 4
constexpr int block_warps_m = 2;
constexpr int block_warps_n = 2;
constexpr int k_tile_size = 16; // K-dimension tile size

// Register tiling kernel.
// Launch with warp_size * block_warps_m * block_warps_n threads per workgroup
// (256 when warp_size is 64) and a grid of ceil(n / block_tile_n) x ceil(m / block_tile_m)
// workgroups (block_tile_n = 32 and block_tile_m = 128 when warp_size is 64).
__global__ void matrix_multiply_register_tiling(float *a, float *b, float *out, int m, int n, int k)
{
    constexpr int warp_threads_m = warp_size / warp_threads_n;
    constexpr int warp_tile_m = warp_threads_m * thread_tile_m;
    constexpr int warp_tile_n = warp_threads_n * thread_tile_n;
    constexpr int block_threads = warp_size * block_warps_m * block_warps_n;
    constexpr int block_tile_m = block_warps_m * warp_tile_m;
    constexpr int block_tile_n = block_warps_n * warp_tile_n;

    // The cooperative loads below assume each tile divides evenly among the threads
    static_assert((block_tile_m * k_tile_size) % block_threads == 0, "A tile must divide evenly among threads");
    static_assert((k_tile_size * block_tile_n) % block_threads == 0, "B tile must divide evenly among threads");

    // The +4 padding changes tilea's row stride, a common way to reduce LDS bank conflicts
    __shared__ float tilea[block_tile_m][k_tile_size + 4];
    __shared__ float tileb[k_tile_size][block_tile_n];

    int tid = threadIdx.y * blockDim.x + threadIdx.x;
    int warp_id = tid / warp_size;
    int lane_id = tid % warp_size;

    int warp_row = warp_id / block_warps_n;
    int warp_col = warp_id % block_warps_n;

    int lane_row = lane_id / warp_threads_n;
    int lane_col = lane_id % warp_threads_n;

    // Top-left corner of this thread's 4x4 output tile within the block tile
    int thread_row_in_block = warp_row * warp_tile_m + lane_row * thread_tile_m;
    int thread_col_in_block = warp_col * warp_tile_n + lane_col * thread_tile_n;

    int block_row_start = blockIdx.y * block_tile_m;
    int block_col_start = blockIdx.x * block_tile_n;

    // Per-thread accumulators, kept in registers
    float regc[thread_tile_m][thread_tile_n] = {0.0f};

    int numtiles = (k + k_tile_size - 1) / k_tile_size;
    for (int t = 0; t < numtiles; t++)
    {
        // Cooperatively load a block_tile_m x k_tile_size tile of A (zero-padded)
        int elements_to_load_a = (block_tile_m * k_tile_size) / block_threads;
        for (int i = 0; i < elements_to_load_a; i++)
        {
            int idx = tid + i * block_threads;
            int tile_m = idx / k_tile_size;
            int tile_n = idx % k_tile_size;

            int global_m = block_row_start + tile_m;
            int global_n = t * k_tile_size + tile_n;

            if (global_m < m && global_n < k && tile_m < block_tile_m)
            {
                tilea[tile_m][tile_n] = a[global_m * k + global_n];
            }
            else if (tile_m < block_tile_m)
            {
                tilea[tile_m][tile_n] = 0.0f;
            }
        }

        // Cooperatively load a k_tile_size x block_tile_n tile of B (zero-padded)
        int elements_to_load_b = (k_tile_size * block_tile_n) / block_threads;
        for (int i = 0; i < elements_to_load_b; i++)
        {
            int idx = tid + i * block_threads;
            int tile_m = idx / block_tile_n;
            int tile_n = idx % block_tile_n;

            int global_m = t * k_tile_size + tile_m;
            int global_n = block_col_start + tile_n;

            if (global_m < k && global_n < n && tile_m < k_tile_size)
            {
                tileb[tile_m][tile_n] = b[global_m * n + global_n];
            }
            else if (tile_m < k_tile_size)
            {
                tileb[tile_m][tile_n] = 0.0f;
            }
        }

        // Wait until both tiles are in LDS
        __syncthreads();

        for (int kk = 0; kk < k_tile_size; kk++)
        {
            // Copy this thread's slice of column kk of the A tile into registers
            float rega[thread_tile_m];
            for (int i = 0; i < thread_tile_m; i++)
            {
                rega[i] = tilea[thread_row_in_block + i][kk];
            }

            // Copy this thread's slice of row kk of the B tile into registers
            float regb[thread_tile_n];
            for (int j = 0; j < thread_tile_n; j++)
            {
                regb[j] = tileb[kk][thread_col_in_block + j];
            }

            // Accumulate the outer product using registers only
            for (int i = 0; i < thread_tile_m; i++)
            {
                for (int j = 0; j < thread_tile_n; j++)
                {
                    regc[i][j] += rega[i] * regb[j];
                }
            }
        }

        // Wait until every thread is done before the tiles are overwritten
        __syncthreads();
    }

    // Write this thread's 4x4 results, skipping elements outside the matrix
    for (int i = 0; i < thread_tile_m; i++)
    {
        for (int j = 0; j < thread_tile_n; j++)
        {
            int global_row = block_row_start + thread_row_in_block + i;
            int global_col = block_col_start + thread_col_in_block + j;

            if (global_row < m && global_col < n)
            {
                out[global_row * n + global_col] = regc[i][j];
            }
        }
    }
}

The register tiling implementation introduces three changes:

  • Thread tile computation: each thread computes multiple output elements, defined by the thread_tile_m and thread_tile_n parameters

  • Register-based accumulation: intermediate results are stored in registers (regc) instead of shared memory

  • Cooperative data loading: all threads in the workgroup load the tiles into shared memory together, and each thread then copies the values it needs into registers (rega and regb) for computation
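The computation each thread performs can also be sketched on the host. In this simplified version (emulate_register_tiling is hypothetical), each iteration of the outer loops plays one thread that owns a 4 × 4 block of the output. A and B are read directly from host vectors instead of being staged through LDS, which keeps the focus on the register outer-product pattern; out-of-range rows and columns are zero-padded, mirroring the kernel's zero-filled LDS slots.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int thread_tile_m = 4; // outputs per "thread" along m
constexpr int thread_tile_n = 4; // outputs per "thread" along n

// Host-side sketch of register tiling (LDS staging omitted).
void emulate_register_tiling(const std::vector<float>& a, const std::vector<float>& b,
                             std::vector<float>& out, int m, int n, int k)
{
    assert(a.size() == static_cast<std::size_t>(m) * k);
    assert(b.size() == static_cast<std::size_t>(k) * n);
    out.assign(static_cast<std::size_t>(m) * n, 0.0f);
    for (int row0 = 0; row0 < m; row0 += thread_tile_m)
        for (int col0 = 0; col0 < n; col0 += thread_tile_n) {
            float regc[thread_tile_m][thread_tile_n] = {}; // accumulators
            for (int kk = 0; kk < k; ++kk) {
                float rega[thread_tile_m], regb[thread_tile_n];
                for (int i = 0; i < thread_tile_m; ++i)     // column slice of A
                    rega[i] = (row0 + i < m) ? a[(row0 + i) * k + kk] : 0.0f;
                for (int j = 0; j < thread_tile_n; ++j)     // row slice of B
                    regb[j] = (col0 + j < n) ? b[kk * n + col0 + j] : 0.0f;
                for (int i = 0; i < thread_tile_m; ++i)     // outer product:
                    for (int j = 0; j < thread_tile_n; ++j) // 16 multiply-adds
                        regc[i][j] += rega[i] * regb[j];    // from 8 loads
            }
            for (int i = 0; i < thread_tile_m; ++i)
                for (int j = 0; j < thread_tile_n; ++j)
                    if (row0 + i < m && col0 + j < n)
                        out[(row0 + i) * n + col0 + j] = regc[i][j];
        }
}
```

Each k step loads 8 values into rega and regb and performs 16 multiply-adds into regc; this reuse is what the kernel exploits in registers.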

This approach typically provides additional performance improvements beyond LDS tiling. The actual speedup will vary based on GPU architecture, matrix dimensions, and other factors.

Practical implementation considerations

When implementing these optimizations for your specific application, consider the following factors:

  • Matrix dimensions: The effectiveness of tiling depends on how well the matrix dimensions align with tile sizes and GPU compute unit resources

  • Memory requirements: Larger tiles may not fit in available LDS space, requiring careful tuning of tile dimensions

  • GPU architecture: Different AMD GPU architectures have varying LDS capacities and register availability, affecting optimal tile sizes

  • Workload characteristics: Some applications may benefit more from LDS tiling alone, while others justify the added complexity of register tiling
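For the memory-requirements consideration, the LDS footprint of both kernels can be computed directly from their __shared__ declarations. The helpers below are hypothetical and hard-code this section's parameter choices. Compare the results with the per-workgroup limit reported for the target device, for example the sharedMemPerBlock field filled in by hipGetDeviceProperties.

```cpp
#include <cassert>
#include <cstddef>

// LDS bytes declared per workgroup by matrix_multiply_lds_tiling:
// tilea and tileb are both base_tile_size x base_tile_size floats.
std::size_t lds_tiling_bytes(int base_tile_size)
{
    assert(base_tile_size > 0);
    const std::size_t t = static_cast<std::size_t>(base_tile_size);
    return 2 * t * t * sizeof(float);
}

// LDS bytes declared per workgroup by matrix_multiply_register_tiling with this
// section's choices (4 x 4 thread tiles, warp_threads_n = 4, 2 x 2 warps):
// tilea is block_tile_m x (k_tile_size + 4), tileb is k_tile_size x block_tile_n.
std::size_t register_tiling_bytes(int warp_size, int k_tile_size)
{
    assert(warp_size > 0 && warp_size % 4 == 0 && k_tile_size > 0);
    const std::size_t warp_threads_m = static_cast<std::size_t>(warp_size) / 4;
    const std::size_t block_tile_m = 2 * warp_threads_m * 4; // block_warps_m * warp_tile_m
    const std::size_t block_tile_n = 2 * 4 * 4;              // block_warps_n * warp_tile_n
    const std::size_t kt = static_cast<std::size_t>(k_tile_size);
    return (block_tile_m * (kt + 4) + kt * block_tile_n) * sizeof(float);
}
```

With base_tile_size = 16, the LDS kernel uses 2,048 bytes per workgroup. The register tiling kernel uses 12,288 bytes with warp_size = 64 and 7,168 bytes with warp_size = 32, including the 4-element padding on tilea.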

Choosing the right optimization approach depends on your specific performance requirements, target hardware, and implementation complexity constraints.