
CUDA Optimization

Optimize a PyTorch self-attention module using custom CUDA kernels

This example shows how to use Weco to optimize a PyTorch causal multi-head self-attention implementation by generating custom CUDA kernels. The goal is low-level optimization beyond standard PyTorch (or even Triton), for potentially higher performance on NVIDIA GPUs. A separate Markdown file (guide.md) provides detailed instructions and context to the LLM.

You can find the complete files for this example here.

Setup

If you haven't already, install the Weco CLI by following the Installation guide, or install it directly with pip:

pip install weco

Google AI Studio has a free API usage quota. Create a key here to use weco for free.

export GEMINI_API_KEY="your_key_here"

Install the required dependency:

pip install torch

Note: This example requires a compatible NVIDIA GPU and the CUDA Toolkit installed on your system for compiling and running the generated CUDA code.
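
To sanity-check your environment before proceeding, a short script like the following (illustrative only, not part of the example files) confirms that PyTorch detects a CUDA device and can locate a CUDA toolkit for compiling inline extensions:

import torch
from torch.utils import cpp_extension

# Check that PyTorch detects a CUDA-capable GPU
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

# CUDA_HOME is the toolkit location PyTorch uses when compiling inline extensions (None if not found)
print("CUDA_HOME:", cpp_extension.CUDA_HOME)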

Create the Baseline to Optimize

Create a file called optimize.py with the following content:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
# ref: https://github.com/karpathy/nanoGPT/blob/93a43d9a5c22450bbf06e78da2cb6eeef084b717/model.py#L29
 
class Model(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """
 
    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, max_seqlen):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(max_seqlen, max_seqlen)).view(1, 1, max_seqlen, max_seqlen))
        self.n_head = n_head
        self.n_embd = n_embd
 
    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
 
        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
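
As a quick smoke test (illustrative only, not part of the example files), you can instantiate the baseline module with the same hyperparameters used by the evaluation script below and run a forward pass on the GPU:

import torch
from optimize import Model

# Hyperparameters mirror those used in evaluate.py
model = Model(n_embd=768, n_head=8, attn_pdrop=0.0, resid_pdrop=0.0, max_seqlen=512).to("cuda")
x = torch.randn(32, 256, 768, device="cuda")  # (batch_size, seq_len, n_embd)
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([32, 256, 768])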

Create the Evaluation Script

Create a file called evaluate.py with the following content:

import sys
import os
import pathlib
import importlib
import importlib.util
import traceback
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
 
########################################################
# Baseline
########################################################
class Model(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """
 
    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, max_seqlen):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(max_seqlen, max_seqlen)).view(1, 1, max_seqlen, max_seqlen))
        self.n_head = n_head
        self.n_embd = n_embd
 
    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
 
        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
 
 
########################################################
# Weco Solution
########################################################
def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
    # Dynamically load a Python module (the candidate solution) from the given file path
    module_path = pathlib.Path(module_path)
    name = module_path.stem
    spec = importlib.util.spec_from_file_location(name, module_path)
    mod = importlib.util.module_from_spec(spec)  # type: ignore
    if add_to_sys_modules:
        sys.modules[name] = mod
    spec.loader.exec_module(mod)  # type: ignore
    return mod
 
 
########################################################
# Benchmark
########################################################
os.environ["MAX_JOBS"] = "1"  # number of workers for building with ninja
 
 
def get_inputs(batch_size, seq_len, n_embd, device):
    return torch.randn(batch_size, seq_len, n_embd, device=device, dtype=torch.float32)
 
 
@torch.no_grad()
def bench(f, inputs, n_warmup, n_rep):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
 
    # warmup
    for _ in range(n_warmup):
        f(inputs)  # noqa
    torch.cuda.synchronize()
 
    # benchmark
    t_avg_ms = 0.0
    for _ in range(n_rep):
        # time the forward pass
        start_event.record()
        f(inputs)
        end_event.record()
        # wait for all computations to complete
        torch.cuda.synchronize()
        t_avg_ms += start_event.elapsed_time(end_event)
    return t_avg_ms / n_rep
 
 
if __name__ == "__main__":
    import argparse
 
    parser = argparse.ArgumentParser()
    parser.add_argument("--solution-path", type=str, required=True)
    args = parser.parse_args()
 
    # benchmarking parameters
    n_correctness_trials = 10
    n_warmup = 1000
    n_rep = 5000
 
    # init parameters
    max_seqlen = 512
    seq_len = 256
    n_embd = 768
    n_head = 8
    # turn off dropout to measure correctness
    attn_pdrop = 0.0
    resid_pdrop = 0.0
 
    # input parameters
    batch_size = 32
 
    # load solution module
    try:
        torch.manual_seed(0)
        solution_module = load_module_from_path(args.solution_path, add_to_sys_modules=False)
        solution_model = solution_module.Model(
            n_embd=n_embd, n_head=n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, max_seqlen=max_seqlen
        ).to("cuda")
        assert isinstance(solution_model, nn.Module)
    except Exception:
        print(f"Candidate module initialization failed: {traceback.format_exc()}")
        exit(1)
 
    torch.manual_seed(0)
    baseline_model = Model(
        n_embd=n_embd, n_head=n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, max_seqlen=max_seqlen
    ).to("cuda")
 
    # measure correctness
    max_diff_avg = 0
    for _ in range(n_correctness_trials):
        inputs = get_inputs(batch_size=batch_size, seq_len=seq_len, n_embd=n_embd, device="cuda")
        with torch.no_grad():
            baseline_output = baseline_model(inputs)
            optimized_output = solution_model(inputs)
            max_diff_avg += torch.max(torch.abs(optimized_output - baseline_output))
    max_diff_avg /= n_correctness_trials
    print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")
 
    # measure performance
    inputs = get_inputs(batch_size=batch_size, seq_len=seq_len, n_embd=n_embd, device="cuda")
    t_avg_baseline = bench(baseline_model, inputs, n_warmup, n_rep)
    print(f"baseline time: {t_avg_baseline:.2f}ms")
    t_avg_optimized = bench(solution_model, inputs, n_warmup, n_rep)
    print(f"optimized time: {t_avg_optimized:.2f}ms")
    print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")

Create the Guidance File

Create a file called guide.md with the following content:

# Writing In-line CUDA Kernels: 101
 
This document outlines a strategy for improving the speedup metric by writing fused, optimized CUDA kernels in a single-file implementation.
 
## Requirements
 
- **Single-File Implementation:** Develop fused CUDA kernels within one file.
- **No Fallback Implementation:** Do not include any alternative or fallback code.
- **Simplicity & Readability:** Write simple, easy-to-understand code and include clear comments.
- **Avoid Templates:** Use plain fused kernel functions without templates.
- **Multiple Kernels Allowed:** You can define more than one kernel in the file if needed.
- **Model Class Requirement:** The solution must include a class `Model` (an instance of `nn.Module`), with the main computation in its `forward` method.
- **Preserve Initialization:** Do not change the initialization of the `Model` class.
- **Focus on Efficiency:** Concentrate solely on efficient PyTorch and CUDA coding without capturing logs.
- **Error Handling:** Any terminal output or errors will be reviewed by an LLM for feedback.
 
## Baseline Code
 
The baseline implementation of the `Model` class simply performs an element-wise addition.
 
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
 
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
 
    def forward(self, a, b):
        return a + b

```

## Optimized Code

The optimized version employs a custom CUDA kernel for fused element-wise addition. The kernel is defined and compiled inline using PyTorch's `load_inline`.

```python

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load_inline
 
# Define the custom CUDA kernel for element-wise addition
elementwise_add_source = '''
#include <torch/extension.h>
#include <cuda_runtime.h>
 
// CUDA kernel for element-wise addition
__global__ void elementwise_add_kernel(const float* a, const float* b, float* out, int size) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        out[idx] = a[idx] + b[idx];
    }
}
 
// Launch function for the CUDA kernel
torch::Tensor elementwise_add_cuda(torch::Tensor a, torch::Tensor b) {
    auto size = a.numel();
    auto out = torch::zeros_like(a);
    const int block_size = 256;
    const int num_blocks = (size + block_size - 1) / block_size;
    elementwise_add_kernel<<<num_blocks, block_size>>>(a.data_ptr<float>(), b.data_ptr<float>(), out.data_ptr<float>(), size);
    return out;
}
'''
 
# C++ function prototype declaration
elementwise_add_cpp_source = "torch::Tensor elementwise_add_cuda(torch::Tensor a, torch::Tensor b);"
 
# Compile the inline CUDA code for element-wise addition
elementwise_add = load_inline(
    name="elementwise_add",
    cpp_sources=elementwise_add_cpp_source,
    cuda_sources=elementwise_add_source,
    functions=["elementwise_add_cuda"],
    verbose=True,
    extra_cflags=[""],
    extra_ldflags=[""],
)
 
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.elementwise_add = elementwise_add
 
    def forward(self, a, b):
        return self.elementwise_add.elementwise_add_cuda(a, b)
```
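
As a quick, illustrative sanity check (assuming a CUDA-capable machine and that the code above has already been run), the compiled extension can be exercised directly:

```python
a = torch.randn(1024, device="cuda")
b = torch.randn(1024, device="cuda")
model = Model().to("cuda")
out = model(a, b)
assert torch.allclose(out, a + b)  # the fused kernel should match plain addition
```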

Run Weco

Now run Weco to optimize your code:

weco run --source optimize.py \
         --eval-command "python evaluate.py --solution-path optimize.py" \
         --metric speedup \
         --maximize true \
         --steps 30 \
         --model gemini-2.5-pro-exp-03-25 \
         --additional-instructions guide.md

Explanation

  • --source optimize.py: The initial PyTorch self-attention code to be optimized with CUDA.
  • --eval-command "python evaluate.py --solution-path optimize.py": Runs the evaluation script, which compiles (if necessary) and benchmarks the CUDA-enhanced code in optimize.py against a baseline, printing the speedup.
  • --metric speedup: The optimization target metric.
  • --maximize true: Weco aims to increase the speedup.
  • --steps 30: The number of optimization iterations.
  • --model gemini-2.5-pro-exp-03-25: The LLM used for code generation.
  • --additional-instructions guide.md: Points Weco to the guidance file created above.

Weco will iteratively modify optimize.py, potentially generating and integrating CUDA C++ code, guided by the evaluation results and the instructions in guide.md.

For more examples, visit the Examples section.
