
CUDA Optimization

Optimize a PyTorch self-attention module using custom CUDA kernels

This example shows how to use Weco to optimize a PyTorch causal multi-head self-attention implementation by generating custom CUDA kernels. This approach targets low-level optimizations beyond what standard PyTorch, or even Triton, can express, with the potential for higher performance on NVIDIA GPUs.

You can find the complete files for this example here.

Setup

If you haven't already, follow the Installation guide to install the Weco CLI. Otherwise, install the CLI using pip:

pip install weco

Install the required dependencies:

pip install torch ninja triton

Note: This example requires a compatible NVIDIA GPU and the CUDA Toolkit installed on your system for compiling and running the generated CUDA code.
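
You can quickly verify that PyTorch can see your GPU and that the CUDA compiler is on your PATH before proceeding:

python -c "import torch; print(torch.cuda.is_available())"
nvcc --version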

Create the Baseline to Optimize

Create a file called optimize.py with the following content:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
# ref: https://github.com/karpathy/nanoGPT/blob/93a43d9a5c22450bbf06e78da2cb6eeef084b717/model.py#L29
 
class Model(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """
 
    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, max_seqlen):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(max_seqlen, max_seqlen)).view(1, 1, max_seqlen, max_seqlen))
        self.n_head = n_head
        self.n_embd = n_embd
 
    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
 
        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
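
Optionally, sanity-check the baseline before optimizing. This is a minimal smoke test, not required for the optimization run; it assumes a CUDA-capable GPU and uses the same hyperparameters as the evaluation script below:

import torch
from optimize import Model

# instantiate the baseline with dropout disabled and run one forward pass
model = Model(n_embd=768, n_head=8, attn_pdrop=0.0, resid_pdrop=0.0, max_seqlen=512).to("cuda")
x = torch.randn(32, 256, 768, device="cuda")
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([32, 256, 768])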

Create the Evaluation Script

Create a file called evaluate.py with the following content:

import sys
import os
import pathlib
import importlib
import importlib.util
import traceback
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from triton.testing import do_bench
 
 
########################################################
# Baseline
########################################################
class Model(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """
 
    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, max_seqlen):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(max_seqlen, max_seqlen)).view(1, 1, max_seqlen, max_seqlen))
        self.n_head = n_head
        self.n_embd = n_embd
 
    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
 
        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
 
 
########################################################
# Weco Solution
########################################################
def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
    # Dynamically load a Python module from a file path so the candidate solution can be imported
    module_path = pathlib.Path(module_path)
    name = module_path.stem
    spec = importlib.util.spec_from_file_location(name, module_path)
    mod = importlib.util.module_from_spec(spec)  # type: ignore
    if add_to_sys_modules:
        sys.modules[name] = mod
    spec.loader.exec_module(mod)  # type: ignore
    return mod
 
 
########################################################
# Benchmark
########################################################
os.environ["MAX_JOBS"] = "1"  # number of workers for building with ninja
 
 
def get_inputs(batch_size, seq_len, n_embd, device):
    return torch.randn(batch_size, seq_len, n_embd, device=device, dtype=torch.float32)
 
 
if __name__ == "__main__":
    import argparse
 
    parser = argparse.ArgumentParser()
    parser.add_argument("--solution-path", type=str, required=True)
    args = parser.parse_args()
 
    # benchmarking parameters
    n_correctness_trials = 10
    correctness_tolerance = 1e-5
    warmup_ms = 1e3
    rep_ms = 5 * 1e3
 
    # init parameters
    max_seqlen = 512
    seq_len = 256
    n_embd = 768
    n_head = 8
    # turn off dropout to measure correctness
    attn_pdrop = 0.0
    resid_pdrop = 0.0
 
    # input parameters
    batch_size = 32
 
    # load solution module
    try:
        torch.manual_seed(0)
        solution_module = load_module_from_path(args.solution_path, add_to_sys_modules=False)
        solution_model = solution_module.Model(
            n_embd=n_embd, n_head=n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, max_seqlen=max_seqlen
        ).to("cuda")
        assert isinstance(solution_model, nn.Module)
    except Exception:
        print(f"Candidate module initialization failed: {traceback.format_exc()}")
        sys.exit(1)
 
    torch.manual_seed(0)
    baseline_model = Model(
        n_embd=n_embd, n_head=n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, max_seqlen=max_seqlen
    ).to("cuda")
 
    # measure correctness
    max_diff_avg = 0
    for _ in range(n_correctness_trials):
        inputs = get_inputs(batch_size=batch_size, seq_len=seq_len, n_embd=n_embd, device="cuda")
        with torch.no_grad():
            optimized_output = solution_model(inputs)
            if torch.isnan(optimized_output).any():
                print("Incorrect solution: NaN detected in optimized model output")
            if torch.isinf(optimized_output).any():
                print("Incorrect solution: Inf detected in optimized model output")
            baseline_output = baseline_model(inputs)
            max_diff_avg += torch.max(torch.abs(optimized_output - baseline_output)).item()
    max_diff_avg /= n_correctness_trials
    print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")
    if max_diff_avg > correctness_tolerance:
        print("Incorrect solution: max float diff is too high")
 
    # measure performance: triton's do_bench handles warmup, repetition, and CUDA
    # synchronization internally and returns the mean runtime in milliseconds
    inputs = get_inputs(batch_size=batch_size, seq_len=seq_len, n_embd=n_embd, device="cuda")
    t_avg_baseline = do_bench(lambda: baseline_model(inputs), warmup=warmup_ms, rep=rep_ms)
    print(f"baseline time: {t_avg_baseline:.2f}ms")
    t_avg_optimized = do_bench(lambda: solution_model(inputs), warmup=warmup_ms, rep=rep_ms)
    print(f"optimized time: {t_avg_optimized:.2f}ms")
    print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")

Run Weco

Now run Weco to optimize your code:

weco run --source optimize.py \
     --eval-command "python evaluate.py --solution-path optimize.py" \
     --metric speedup \
     --goal maximize \
     --steps 50 \
     --model o4-mini \
     --additional-instructions "Write in-line CUDA using pytorch's load_inline() to optimize the code while ensuring a small max float diff. Maintain the same code format. Do not use any fallbacks. Assume any required dependencies are installed and data is already on the gpu."

Explanation

  • --source optimize.py: The initial PyTorch self-attention code to be optimized with CUDA.
  • --eval-command "python evaluate.py --solution-path optimize.py": Runs the evaluation script, which compiles (if necessary) and benchmarks the CUDA-enhanced code in optimize.py against a baseline, printing the speedup.
  • --metric speedup: The optimization target metric (see the note after this list).
  • --goal maximize: Weco aims to increase the speedup.
  • --steps 50: The number of optimization iterations.
  • --model o4-mini: The LLM used for code generation.
  • --additional-instructions "...": Provides guidance to the LLM on the optimization approach.
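
Note that Weco reads the metric value from the evaluation command's output, so the script must print it by name. evaluate.py does this in its final line, e.g.:

speedup: 1.00x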

Weco will iteratively modify optimize.py, potentially generating and integrating CUDA C++ code, guided by the evaluation results and the additional instructions provided.
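
The sketch below is not Weco's actual output; it is a minimal illustration of the load_inline() pattern the additional instructions ask for, using a trivial elementwise scaling kernel (with hypothetical names such as scale_ext and scale_cuda) in place of a fused attention kernel:

import torch
from torch.utils.cpp_extension import load_inline

# CUDA source: a kernel plus a C++ wrapper that launches it
cuda_src = r"""
#include <torch/extension.h>

__global__ void scale_kernel(const float* x, float* y, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = x[i] * s;
}

torch::Tensor scale_cuda(torch::Tensor x, double s) {
    auto y = torch::empty_like(x);
    const int n = x.numel();
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;
    scale_kernel<<<blocks, threads>>>(
        x.data_ptr<float>(), y.data_ptr<float>(), static_cast<float>(s), n);
    return y;
}
"""

# C++ source: declaration only; load_inline generates the Python binding
cpp_src = "torch::Tensor scale_cuda(torch::Tensor x, double s);"

scale_ext = load_inline(
    name="scale_ext",
    cpp_sources=cpp_src,
    cuda_sources=cuda_src,
    functions=["scale_cuda"],
)

x = torch.randn(1024, device="cuda")
y = scale_ext.scale_cuda(x, 0.5)  # y == 0.5 * x

If a generated kernel compiles and passes the correctness check, its measured speedup is fed back to Weco and used to guide the next optimization step.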
