
Triton Optimization

Optimize a PyTorch activation function using Triton

This example demonstrates using Weco to optimize a simple activation function implemented in PyTorch. We'll ask Weco to leverage Triton to accelerate the code. The function in question is Swish (also known as SiLU), defined as swish(x) = x * sigmoid(x).

Setup

If you haven't already, install the Weco CLI by following the Installation guide, or install it directly with pip:

pip install weco

Install the required dependencies:

pip install numpy torch triton

This example requires an NVIDIA GPU.
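
You can confirm that PyTorch can see your GPU before going further:

python -c "import torch; print(torch.cuda.is_available())"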

Create the Baseline to Optimize

Create a file called module.py with the following code:

import torch
import torch.nn as nn
 
 
class Model(nn.Module):
    """
    Simple model that performs a Swish activation.
    """
    def __init__(self):
        super(Model, self).__init__()
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies Swish activation to the input tensor.
 
        Args:
            x (torch.Tensor): Input tensor of any shape.
 
        Returns:
            torch.Tensor: Output tensor with Swish applied, same shape as input.
        """
        return x * torch.sigmoid(x)
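
Optionally, give the baseline a quick sanity check on the GPU (the input shape below matches what the evaluation script uses):

import torch
from module import Model

model = Model().to("cuda")
x = torch.randn(2000, 16384, device="cuda")
print(model(x).shape)  # torch.Size([2000, 16384]); Swish is elementwise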

Create the Evaluation Script

Create a file called evaluate.py with the following code:

import sys
import pathlib
import importlib.util
import traceback
import torch
import torch.nn as nn
from triton.testing import do_bench
 
 
########################################################
# Baseline
########################################################
class Model(nn.Module):
    """
    Simple model that performs a Swish activation.
    """
 
    def __init__(self):
        super(Model, self).__init__()
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies Swish activation to the input tensor.
 
        Args:
            x (torch.Tensor): Input tensor of any shape.
 
        Returns:
            torch.Tensor: Output tensor with Swish applied, same shape as input.
        """
        return x * torch.sigmoid(x)
 
 
########################################################
# Benchmark
########################################################
def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
    # Dynamically import the candidate module from its file path
    module_path = pathlib.Path(module_path)
    name = module_path.stem
    spec = importlib.util.spec_from_file_location(name, module_path)
    mod = importlib.util.module_from_spec(spec)  # type: ignore
    if add_to_sys_modules:
        sys.modules[name] = mod
    spec.loader.exec_module(mod)  # type: ignore
    return mod
 
 
def get_inputs(batch_size, dim, device):
    return torch.randn(batch_size, dim, device=device, dtype=torch.float32)
 
 
if __name__ == "__main__":
    import argparse
 
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", type=str, required=True)
    args = parser.parse_args()
 
    # benchmarking parameters
    n_correctness_trials = 10
    correctness_tolerance = 1e-5
    warmup_ms = 100
    rep_ms = 500
 
    # input parameters
    batch_size = 2000
    dim = 16384
 
    # load solution module
    try:
        torch.manual_seed(0)
        solution_module = load_module_from_path(args.path, add_to_sys_modules=False)
        solution_model = solution_module.Model().to("cuda")
        assert isinstance(solution_model, nn.Module)
    except Exception:
        print(f"Candidate module initialization failed: {traceback.format_exc()}")
        exit(1)
 
    torch.manual_seed(0)
    baseline_model = Model().to("cuda")
 
    # measure correctness
    max_diff_avg = 0
    for _ in range(n_correctness_trials):
        inputs = get_inputs(batch_size=batch_size, dim=dim, device="cuda")
        with torch.no_grad():
            optimized_output = solution_model(inputs)
            if torch.isnan(optimized_output).any():
                print("Incorrect solution: NaN detected in optimized model output")
            if torch.isinf(optimized_output).any():
                print("Incorrect solution: Inf detected in optimized model output")
            baseline_output = baseline_model(inputs)
            max_diff_avg += torch.max(torch.abs(optimized_output - baseline_output)).item()
    max_diff_avg /= n_correctness_trials
    print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")
    if max_diff_avg > correctness_tolerance:
        print("Incorrect solution: max float diff is too high")
 
    # measure performance
    inputs = get_inputs(batch_size=batch_size, dim=dim, device="cuda")
    t_avg_baseline = do_bench(lambda: baseline_model(inputs), warmup=warmup_ms, rep=rep_ms)
    print(f"baseline time: {t_avg_baseline:.2f}ms")
    t_avg_optimized = do_bench(lambda: solution_model(inputs), warmup=warmup_ms, rep=rep_ms)
    print(f"optimized time: {t_avg_optimized:.2f}ms")
    print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")

Run Weco

Now run Weco to optimize your code using Triton:

weco run --source module.py \
     --eval-command "python evaluate.py --path module.py" \
     --metric speedup \
     --goal maximize \
     --steps 15 \
     --model o4-mini \
     --additional-instructions "Use a combination of triton and pytorch to optimize the forward pass while ensuring a small max float diff. Maintain the same code interface. Do not use any fallbacks. Assume any required dependencies are installed and data is already on the gpu." \
     --eval-timeout 120

Explanation

  • --source module.py: Specifies the PyTorch Swish activation implementation (module.py) that Weco will optimize.
  • --eval-command "python evaluate.py --path module.py": Defines the command to execute the evaluation script. This script benchmarks the generated solution in module.py against a baseline and outputs the speedup.
  • --metric speedup: Sets the metric Weco should focus on improving during optimization. Weco reads this value from the evaluation script's output (see the note after this list).
  • --goal maximize: Instructs Weco to aim for the highest possible speedup value.
  • --steps 15: Determines the number of optimization iterations Weco will perform.
  • --model o4-mini: Specifies the large language model to drive the optimization process.
  • --additional-instructions "...": Provides specific guidance to the LLM.
  • --eval-timeout 120: Stops the evaluation run if it does not complete within 120 seconds.
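
As noted above, Weco looks for the metric named by --metric in the evaluation output, so the evaluation script must print it. In this script it comes from the final line:

print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")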

Weco will iteratively modify module.py, incorporating Triton kernels, guided by the performance feedback (speedup) from the evaluation script and the instructions provided.
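
For reference, an optimized module.py often ends up looking something like the sketch below. This is only an illustration of the kind of fused elementwise kernel Weco tends to produce, not its exact output (generated code varies from run to run), and it assumes tl.sigmoid is available in your Triton version:

import torch
import torch.nn as nn
import triton
import triton.language as tl


@triton.jit
def swish_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous block of elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # Fused Swish: x * sigmoid(x) in a single pass over memory, instead of
    # the separate sigmoid and multiply kernels the baseline launches.
    tl.store(out_ptr + offsets, x * tl.sigmoid(x), mask=mask)


class Model(nn.Module):
    """Swish activation backed by a Triton kernel."""

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.contiguous()
        out = torch.empty_like(x)
        n_elements = x.numel()
        # One program per BLOCK_SIZE-element chunk of the flattened tensor.
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        swish_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
        return out

Because the class name and forward signature are unchanged, evaluate.py can load this file through --path and compare it against the baseline exactly as before.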
