Triton Optimization
Optimize a PyTorch activation function using Triton
This example demonstrates using Weco to optimize a simple activation function implemented in PyTorch. We'll ask Weco to leverage Triton to accelerate the code.
Setup
If you haven't already, follow the Installation guide to install the Weco CLI, or install it using one of the options below:
macOS/Linux:
curl -fsSL https://weco.ai/install.sh | sh
Windows (Command Prompt):
powershell -ExecutionPolicy ByPass -c "irm https://weco.ai/install.ps1 | iex"
Windows (PowerShell):
irm https://weco.ai/install.ps1 | iex
With pip:
pip install weco
From source:
git clone https://github.com/wecoai/weco-cli.git
cd weco-cli
pip install -e .
Install the required dependencies:
pip install numpy torch triton
This example requires an NVIDIA GPU.
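Before proceeding, you can quickly confirm that a CUDA device and Triton are visible to PyTorch. This is an optional sanity check, not part of the example; the printed device name will differ on your machine:

import torch
import triton

# Both checks must pass for this example to run
assert torch.cuda.is_available(), "This example requires an NVIDIA GPU with CUDA available"
print(torch.cuda.get_device_name(0), "| Triton", triton.__version__)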
Create the Baseline to Optimize
Create a file called module.py with the following code:
import torch
import torch.nn as nn


class Model(nn.Module):
    """
    Simple model that performs a Swish activation.
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies Swish activation to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of any shape.

        Returns:
            torch.Tensor: Output tensor with Swish applied, same shape as input.
        """
        return x * torch.sigmoid(x)
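Optionally, you can sanity-check the baseline before writing the benchmark. This quick sketch assumes module.py is in the current directory and uses the same input shape as the evaluation script below:

import torch
from module import Model

model = Model().to("cuda")
x = torch.randn(2000, 16384, device="cuda")  # same shape the benchmark uses
print(model(x).shape)  # expected: torch.Size([2000, 16384])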
Create the Evaluation Script
Create a file called evaluate.py with the following code:
import sys
import pathlib
import importlib
import importlib.util
import traceback

import torch
import torch.nn as nn
from triton.testing import do_bench


########################################################
# Baseline
########################################################
class Model(nn.Module):
    """
    Simple model that performs a Swish activation.
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies Swish activation to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of any shape.

        Returns:
            torch.Tensor: Output tensor with Swish applied, same shape as input.
        """
        return x * torch.sigmoid(x)


########################################################
# Benchmark
########################################################
def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
    # Dynamically load the candidate solution module from its file path
    module_path = pathlib.Path(module_path)
    name = module_path.stem
    spec = importlib.util.spec_from_file_location(name, module_path)
    mod = importlib.util.module_from_spec(spec)  # type: ignore
    if add_to_sys_modules:
        sys.modules[name] = mod
    spec.loader.exec_module(mod)  # type: ignore
    return mod


def get_inputs(batch_size, dim, device):
    return torch.randn(batch_size, dim, device=device, dtype=torch.float32)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--path", type=str, required=True)
    args = parser.parse_args()

    # benchmarking parameters
    n_correctness_trials = 10
    correctness_tolerance = 1e-5
    warmup_ms = 100
    rep_ms = 500

    # input parameters
    batch_size = 2000
    dim = 16384

    # load solution module
    try:
        torch.manual_seed(0)
        solution_module = load_module_from_path(args.path, add_to_sys_modules=False)
        solution_model = solution_module.Model().to("cuda")
        assert isinstance(solution_model, nn.Module)
    except Exception:
        print(f"Candidate module initialization failed: {traceback.format_exc()}")
        exit(1)

    torch.manual_seed(0)
    baseline_model = Model().to("cuda")

    # measure correctness
    max_diff_avg = 0
    for _ in range(n_correctness_trials):
        inputs = get_inputs(batch_size=batch_size, dim=dim, device="cuda")
        with torch.no_grad():
            optimized_output = solution_model(inputs)
        if torch.isnan(optimized_output).any():
            print("Incorrect solution: NaN detected in optimized model output")
        if torch.isinf(optimized_output).any():
            print("Incorrect solution: Inf detected in optimized model output")
        baseline_output = baseline_model(inputs)
        max_diff_avg += torch.max(torch.abs(optimized_output - baseline_output))
    max_diff_avg /= n_correctness_trials
    print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")
    if max_diff_avg > correctness_tolerance:
        print("Incorrect solution: max float diff is too high")

    # measure performance
    inputs = get_inputs(batch_size=batch_size, dim=dim, device="cuda")
    t_avg_baseline = do_bench(lambda: baseline_model(inputs), warmup=warmup_ms, rep=rep_ms)
    print(f"baseline time: {t_avg_baseline:.2f}ms")
    t_avg_optimized = do_bench(lambda: solution_model(inputs), warmup=warmup_ms, rep=rep_ms)
    print(f"optimized time: {t_avg_optimized:.2f}ms")
    print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")
Run Weco
Now run Weco to optimize your code using Triton:
weco run --source module.py \
--eval-command "python evaluate.py --path module.py" \
--metric speedup \
--goal maximize \
--steps 15 \
--model o4-mini \
--additional-instructions "Use a combination of triton and pytorch to optimize the forward pass while ensuring a small max float diff. Maintain the same code interface. Do not use any fallbacks. Assume any required dependencies are installed and data is already on the gpu." \
  --eval-timeout 120
Or in Windows Command Prompt:
weco run --source module.py ^
--eval-command "python evaluate.py --path module.py" ^
--metric speedup ^
--goal maximize ^
--steps 15 ^
--model o4-mini ^
--additional-instructions "Use a combination of triton and pytorch to optimize the forward pass while ensuring a small max float diff. Maintain the same code interface. Do not use any fallbacks. Assume any required dependencies are installed and data is already on the gpu." ^
  --eval-timeout 120
Or in PowerShell:
weco run --source module.py `
--eval-command "python evaluate.py --path module.py" `
--metric speedup `
--goal maximize `
--steps 15 `
--model o4-mini `
--additional-instructions "Use a combination of triton and pytorch to optimize the forward pass while ensuring a small max float diff. Maintain the same code interface. Do not use any fallbacks. Assume any required dependencies are installed and data is already on the gpu." `
  --eval-timeout 120
Explanation
- --source module.py: Specifies the PyTorch Swish activation implementation (module.py) that Weco will optimize.
- --eval-command "python evaluate.py --path module.py": Defines the command to execute the evaluation script. This script benchmarks the generated solution in module.py against a baseline and outputs the speedup.
- --metric speedup: Sets the metric Weco should focus on improving during optimization.
- --goal maximize: Instructs Weco to aim for the highest possible speedup value.
- --steps 15: Determines the number of optimization iterations Weco will perform.
- --model o4-mini: Specifies the large language model that drives the optimization process.
- --additional-instructions "...": Provides specific guidance to the LLM.
- --eval-timeout 120: Stops the evaluation script if it does not complete within 120 seconds.
Weco will iteratively modify module.py, incorporating Triton kernels, guided by the performance feedback (speedup) from the evaluation script and the instructions provided.
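For reference, a common optimization for this workload is to fuse the multiply and sigmoid into a single elementwise Triton kernel. The sketch below is only illustrative of the kind of module.py Weco may converge on, not its actual output, and assumes the input is a contiguous CUDA tensor:

import torch
import torch.nn as nn
import triton
import triton.language as tl


@triton.jit
def swish_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide chunk of the flattened tensor
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # Swish: x * sigmoid(x), computed and stored in one pass over memory
    tl.store(out_ptr + offsets, x * tl.sigmoid(x), mask=mask)


class Model(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        n_elements = x.numel()
        grid = (triton.cdiv(n_elements, 1024),)
        swish_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
        return out

Because the baseline launches separate sigmoid and multiply operations, fusing them into one kernel reduces global memory traffic, which is where most of the speedup on this memory-bound activation comes from.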
What's Next?
- Lower-level GPU programming: Try CUDA Optimization for maximum performance control
- Different optimization types: Explore Model Development or Prompt Engineering
- Simpler GPU optimization: Start with PyTorch Optimization
- Better evaluation scripts: Learn Writing Good Evaluation Scripts
- All command options: Check the CLI Reference