PyTorch Optimization
Optimize the speed of a simple operation in PyTorch
You can:
- Follow along here
- Check out the files from GitHub
- Run on Google Colab (recommended)
Setup
If you haven't already, follow the Installation guide to install the Weco CLI, or install it directly with one of the options below:
macOS/Linux:
curl -fsSL https://weco.ai/install.sh | sh
Windows (Command Prompt):
powershell -ExecutionPolicy ByPass -c "irm https://weco.ai/install.ps1 | iex"
Windows (PowerShell):
irm https://weco.ai/install.ps1 | iex
pip:
pip install weco
We recommend using a virtual environment when installing with pip.
From source:
git clone https://github.com/wecoai/weco-cli.git
cd weco-cli
pip install -e .
Use this if you want to contribute or modify Weco.
Install the dependencies for the scripts shown in the following sections:
pip install torch
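If you plan to benchmark on a GPU later, you can optionally check which devices your PyTorch build can see:
import torch
print(torch.cuda.is_available())          # True when an NVIDIA GPU is usable
print(torch.backends.mps.is_available())  # True on Apple Silicon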
Create the Baseline to Optimize
Create a file called module.py with the following code:
import torch
import torch.nn as nn
class Model(nn.Module):
"""
Model that performs a matrix multiplication, division, summation, and scaling.
"""
def __init__(self, input_size, hidden_size, scaling_factor):
super(Model, self).__init__()
self.weight = nn.Parameter(torch.randn(hidden_size, input_size))
self.scaling_factor = scaling_factor
def forward(self, x):
"""
Args:
x (torch.Tensor): Input tensor of shape (batch_size, input_size).
Returns:
            torch.Tensor: Output tensor of shape (batch_size, 1).
"""
x = torch.matmul(x, self.weight.T)
x = x / 2
x = torch.sum(x, dim=1, keepdim=True)
x = x * self.scaling_factor
        return x
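To confirm the baseline runs, try a quick smoke test (the sizes match the ones the evaluation script uses below):
import torch
from module import Model

model = Model(input_size=10, hidden_size=20, scaling_factor=1.5)
out = model(torch.randn(128, 10))
print(out.shape)  # torch.Size([128, 1]): the sum over dim=1 collapses the hidden dimension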
Create the Evaluation Script
Create a file called evaluate.py with the following code:
import time
import sys
import os
import pathlib
import importlib
import traceback
import torch
import torch.nn as nn
########################################################
# Baseline
########################################################
class Model(nn.Module):
"""
Model that performs a matrix multiplication, division, summation, and scaling.
"""
def __init__(self, input_size, hidden_size, scaling_factor):
super(Model, self).__init__()
self.weight = nn.Parameter(torch.randn(hidden_size, input_size))
self.scaling_factor = scaling_factor
def forward(self, x):
"""
Args:
x (torch.Tensor): Input tensor of shape (batch_size, input_size).
Returns:
            torch.Tensor: Output tensor of shape (batch_size, 1).
"""
x = torch.matmul(x, self.weight.T)
x = x / 2
x = torch.sum(x, dim=1, keepdim=True)
x = x * self.scaling_factor
return x
########################################################
# Weco Solution
########################################################
def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
    # Dynamically import the candidate module from a file path so the solution can be loaded
module_path = pathlib.Path(module_path)
name = module_path.stem
spec = importlib.util.spec_from_file_location(name, module_path)
mod = importlib.util.module_from_spec(spec) # type: ignore
if add_to_sys_modules:
sys.modules[name] = mod
spec.loader.exec_module(mod) # type: ignore
return mod
########################################################
# Benchmark
########################################################
os.environ["MAX_JOBS"] = "1" # number of workers for building with ninja
def get_inputs(B, N, device):
return torch.randn(B, N, device=device, dtype=torch.float32)
# NOTE: We included this custom benchmark function to avoid adding the triton dependency
# and allow users to run the example on CPUs. However, if you have an NVIDIA GPU,
# you can use the triton.testing.do_bench function instead.
# Refer to examples/triton/evaluate.py for an example.
# triton.testing.do_bench documentation: https://triton-lang.org/main/python-api/generated/triton.testing.do_bench.html
@torch.no_grad()
def bench(f, inputs, n_warmup, n_rep):
device_type = inputs.device.type
# warm up
for _ in range(n_warmup):
f(inputs) # noqa
if device_type == "cuda":
torch.cuda.synchronize()
elif device_type == "mps":
torch.mps.synchronize()
    # benchmark
    t_avg = 0.0
    for _ in range(n_rep):
        # time forward pass
        start_time = time.time()
        f(inputs)
        # synchronize before reading the clock so queued GPU work is included in the measurement
        if device_type == "cuda":
            torch.cuda.synchronize()
        elif device_type == "mps":
            torch.mps.synchronize()
        t_avg += time.time() - start_time
    t_avg /= n_rep * 1e-3  # average per call, converted to milliseconds
return t_avg
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--path", type=str, required=True)
parser.add_argument("--device", default="cpu", type=str)
args = parser.parse_args()
# benchmark parameters
n_correctness_trials = 10
correctness_tolerance = 1e-5
n_warmup = 1000
n_rep = 5000
# init and input parameters
batch_size, input_size, hidden_size, scaling_factor = 128, 10, 20, 1.5
# load solution module
try:
torch.manual_seed(0)
solution_module = load_module_from_path(args.path, add_to_sys_modules=False)
solution_model = solution_module.Model(input_size, hidden_size, scaling_factor).to(args.device)
assert isinstance(solution_model, nn.Module)
assert hasattr(solution_model, "forward")
except Exception:
print(f"Candidate module initialization failed: {traceback.format_exc()}")
exit(1)
torch.manual_seed(0)
baseline_model = Model(input_size, hidden_size, scaling_factor).to(args.device)
# measure correctness
max_diff_avg = 0
for _ in range(n_correctness_trials):
inputs = get_inputs(batch_size, input_size, args.device)
optimized_output = solution_model(inputs)
if torch.isnan(optimized_output).any():
print("Incorrect solution: NaN detected in optimized model output")
if torch.isinf(optimized_output).any():
print("Incorrect solution: Inf detected in optimized model output")
baseline_output = baseline_model(inputs)
        max_diff_avg += torch.max(torch.abs(optimized_output - baseline_output)).item()
max_diff_avg /= n_correctness_trials
print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")
if max_diff_avg > correctness_tolerance:
print("Incorrect solution: max float diff is too high")
# measure performance
inputs = get_inputs(batch_size, input_size, args.device)
t_avg_baseline = bench(baseline_model, inputs, n_warmup, n_rep)
print(f"baseline time: {t_avg_baseline:.2f}ms")
t_avg_optimized = bench(solution_model, inputs, n_warmup, n_rep)
print(f"optimized time: {t_avg_optimized:.2f}ms")
print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")Run Weco
Run Weco
Now run Weco to optimize your code; the --metric speedup flag tells it to track the speedup value that evaluate.py prints:
weco run --source module.py \
--eval-command "python evaluate.py --path module.py" \
--metric speedup \
--goal maximize \
--steps 15 \
--model o4-mini \
--additional-instructions "Fuse operations in the forward method while ensuring the max float deviation remains small. Maintain the same format of the code."weco run --source module.py ^
--eval-command "python evaluate.py --path module.py" ^
--metric speedup ^
--goal maximize ^
--steps 15 ^
--model o4-mini ^
--additional-instructions "Fuse operations in the forward method while ensuring the max float deviation remains small. Maintain the same format of the code."Or in PowerShell:
weco run --source module.py `
--eval-command "python evaluate.py --path module.py" `
--metric speedup `
--goal maximize `
--steps 15 `
--model o4-mini `
--additional-instructions "Fuse operations in the forward method while ensuring the max float deviation remains small. Maintain the same format of the code."Here's what you can expect to see (keep an eye on that Best Solution panel):

Tips:
- If you have an NVIDIA GPU, change the device in the --eval-command to cuda (e.g., python evaluate.py --path module.py --device cuda).
- If you are running this on Apple Silicon, set it to mps.
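For intuition, the kind of fusion the additional instructions push toward might look like the sketch below: summing the weight rows first turns the matmul plus reduction into a single matrix-vector product, and the division and scaling fold into one constant. This is a hand-derived illustration, not necessarily the solution Weco will produce:
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Fused equivalent of the baseline:
    scaling_factor * sum(x @ W.T / 2, dim=1) == x @ W.sum(0).T * (scaling_factor / 2)
    """
    def __init__(self, input_size, hidden_size, scaling_factor):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(hidden_size, input_size))
        self.scaling_factor = scaling_factor

    def forward(self, x):
        # Collapse the reduction into the matmul: sum the rows of W once,
        # then a single (batch, input) x (input, 1) product gives the result.
        w_sum = self.weight.sum(dim=0, keepdim=True)  # shape (1, input_size)
        return torch.matmul(x, w_sum.T) * (self.scaling_factor / 2)
Note that it keeps the same parameter initialization, so the evaluation script's seeded correctness check still passes.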
What's Next?
- Advanced GPU optimization: Try Triton or CUDA kernels
- Different optimization types: Explore Model Development or Prompt Engineering
- Better evaluation scripts: Learn Writing Good Evaluation Scripts
- All command options: Check the CLI Reference