Kernels

A library of transformer kernels written in Python (PyTorch and Triton), plus small-scale transformer scaling law experiments.

MatMul

Triton

MatMul Speed

Attention

FlashAttention-2

Paper: link

Triton

Forward

Attention Forward Speed

Backward

Attention Backward Speed

Normalization

RMSNorm

Paper: link

Triton

Forward

RMSNorm Forward Speed

The Triton forward pass loads the input vector with shape [batch_size, n_tokens, hidden_size] once to compute the RMS. It then loads the weight vector of shape [hidden_size,] and the input vector again to compute the normalized output, which is written back to memory. The total bytes read and written is therefore element_size * (3 * batch_size * n_tokens * hidden_size + hidden_size). Dividing this by the execution time from the previous plot gives the achieved bandwidth. The RTX 3090 has a maximum memory bandwidth of ~936 GB/s, so the achieved bandwidth at smaller hidden sizes must come from cache hits:

RMSNorm Forward Bandwidth
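
As a worked sketch of this bandwidth estimate (the function name, example shape, and timing below are illustrative assumptions, not values taken from the benchmark):

```python
import torch

def rmsnorm_forward_bandwidth_gbps(batch_size, n_tokens, hidden_size, time_ms,
                                   dtype=torch.bfloat16):
    """Estimate the achieved bandwidth of the RMSNorm forward pass in GB/s.

    The kernel reads the input twice (once for the RMS, once for the
    normalization), reads the weight once, and writes the output once.
    """
    element_size = torch.tensor([], dtype=dtype).element_size()  # bytes per element
    total_bytes = element_size * (3 * batch_size * n_tokens * hidden_size + hidden_size)
    return total_bytes / (time_ms * 1e-3) / 1e9

# Hypothetical example: 0.5 ms for an [8, 2048, 4096] bfloat16 input.
print(f"{rmsnorm_forward_bandwidth_gbps(8, 2048, 4096, 0.5):.0f} GB/s")
```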

Backward

RMSNorm Backward Speed

Position Embedding

Rotary Position Embedding (RoPE)

Paper: link

Triton

Forward

RoPE Forward Speed

The Triton forward pass launches one program (CUDA block/CTA) per [batch, seq_len, n_head] index, and each program applies RoPE across the head_dim. The [batch, seq_len, n_head, head_dim] input is loaded from and stored to global memory in its native data type (e.g. bfloat16). The cosine and sine arrays used in the rotation have shape [seq_len, head_dim // 2] and data type float32, and each program loads the [1, head_dim // 2] slice at its sequence position. The total cosine and sine traffic would therefore be 2 * batch * seq_len * n_head * (head_dim // 2) elements; in practice, however, these arrays fit within cache, so global memory loads are limited.

The lower bound on the plot assumes perfect caching, so no cosine or sine global memory loads; the upper bound assumes no caching, so each program loads the cosine and sine data it needs from global memory.

RoPE Forward Bandwidth
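
A minimal sketch of how these two bounds can be computed from a measured time (the function name, example shape, and timing are illustrative assumptions):

```python
import torch

def rope_forward_bandwidth_bounds_gbps(batch, seq_len, n_head, head_dim, time_ms,
                                       dtype=torch.bfloat16):
    """Return (lower, upper) achieved-bandwidth estimates in GB/s.

    Lower bound: cos/sin are perfectly cached, so only the input is read and
    the output written in the native dtype.
    Upper bound: no caching, so every program also reads its float32
    [1, head_dim // 2] cos and sin slices from global memory.
    """
    element_size = torch.tensor([], dtype=dtype).element_size()
    io_bytes = 2 * batch * seq_len * n_head * head_dim * element_size    # input read + output write
    cos_sin_bytes = 2 * batch * seq_len * n_head * (head_dim // 2) * 4   # float32 cos + sin per program
    seconds = time_ms * 1e-3
    return io_bytes / seconds / 1e9, (io_bytes + cos_sin_bytes) / seconds / 1e9

# Hypothetical example: 0.5 ms for an [8, 2048, 32, 128] bfloat16 input.
lower, upper = rope_forward_bandwidth_bounds_gbps(8, 2048, 32, 128, 0.5)
print(f"lower ~{lower:.0f} GB/s, upper ~{upper:.0f} GB/s")
```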

Backward

The forward pass of RoPE splits the head_dim into pairs of values and rotates each pair before concatenating them back together. The backward pass "un-rotates" the pairs in the gradient with respect to the output to recover the gradient with respect to the input. This "un-rotation" is simply a forward pass that takes the gradient with respect to the output as input and uses the cached sine rotation array multiplied by -1.
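
A minimal PyTorch reference sketch of this relationship; the interleaved pairing convention and the helper names are assumptions for illustration, not the Triton kernel's exact layout:

```python
import torch

def rope_forward(x, cos, sin):
    """Rotate consecutive pairs along head_dim.

    x:        [batch, seq_len, n_head, head_dim]
    cos, sin: [seq_len, head_dim // 2]
    """
    x1, x2 = x[..., 0::2], x[..., 1::2]      # split head_dim into pairs
    c = cos[None, :, None, :]                # broadcast over batch and heads
    s = sin[None, :, None, :]
    out1 = x1 * c - x2 * s                   # rotate each pair by its angle
    out2 = x1 * s + x2 * c
    return torch.stack((out1, out2), dim=-1).flatten(-2)  # interleave back

def rope_backward(grad_out, cos, sin):
    # "Un-rotation": a forward pass on grad_out with the sine negated.
    return rope_forward(grad_out, cos, -sin)

# Check against autograd on random data.
x = torch.randn(2, 16, 4, 8, requires_grad=True)
angles = torch.rand(16, 4)
cos, sin = angles.cos(), angles.sin()
rope_forward(x, cos, sin).sum().backward()
assert torch.allclose(x.grad, rope_backward(torch.ones_like(x), cos, sin), atol=1e-6)
```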

RoPE Backward Speed

Losses

Cross Entropy

Triton

Forward

The "Two Pass" version was an earlier implementation that did a first pass over the logits to compute the maximum value followed by a second over them to compute the logsumexp. The "Online" version is the current implementation uses the online log-sum-exp trick (similar to FlashAttention) and requires only one pass over the logits.

Cross Entropy Forward Speed

Backward

Cross Entropy Backward Speed

Activation Functions

Sigmoid

Triton

Forward

Sigmoid Forward Speed

Backward

Sigmoid Backward Speed

Swish

Triton

Forward

Swish Forward Speed

Backward

Swish Backward Speed

Scaling Law Experiments

IsoFLOPs

IsoFLOPs
