# SmartCane/benchmark_gpu_vs_cpu.py
# 2026-01-06 14:17:37 +01:00
#
# 83 lines
# 2.2 KiB
# Python

import torch
import torch.nn as nn
import time
# Title banner: a rule, the benchmark name, and another rule (one line each).
print("\n".join(("=" * 80, "PYTORCH GPU vs CPU BENCHMARK TEST", "=" * 80)))
# Model definition
class SimpleModel(nn.Module):
    """Fully connected classifier: 784 -> 1000 -> 1000 -> 10.

    A ReLU follows each of the two hidden linear layers. The last layer
    returns raw logits (no softmax); this script feeds them to
    ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order: it fixes how random initialisation draws
        # from the RNG and the order/names of the registered submodules.
        self.fc1 = nn.Linear(784, 1000)
        self.fc2 = nn.Linear(1000, 1000)
        self.fc3 = nn.Linear(1000, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map inputs whose last dimension is 784 to 10 logits."""
        for hidden_layer in (self.fc1, self.fc2):
            x = self.relu(hidden_layer(x))
        return self.fc3(x)
# Dummy data - larger dataset
x = torch.randn(100000, 784)
y = torch.randint(0, 10, (100000,))
# Loss function
criterion = nn.CrossEntropyLoss()

NUM_EPOCHS = 20  # full-batch training steps per device
LOG_EVERY = 5  # print the loss every LOG_EVERY epochs
WARMUP_STEPS = 1  # untimed forward/backward passes before the clock starts


def _sync(device):
    """Block until all queued work on ``device`` has finished (no-op on CPU)."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _train(model, optimizer, inputs, targets):
    """Train ``model`` full-batch for NUM_EPOCHS steps and return elapsed seconds.

    Args:
        model: module already placed on the same device as ``inputs``.
        optimizer: optimizer over ``model``'s parameters.
        inputs: feature tensor fed to ``model`` in one batch.
        targets: class-index tensor for ``criterion``.

    Returns:
        Wall-clock seconds spent in the timed training loop (warm-up excluded).
    """
    device = inputs.device

    # Warm-up: forward/backward only, then clear gradients. No optimizer step
    # is taken, so the weights and Adam state are unchanged, but one-time costs
    # (library handle creation, allocator growth) are paid outside the timer.
    for _ in range(WARMUP_STEPS):
        optimizer.zero_grad()
        criterion(model(inputs), targets).backward()
    optimizer.zero_grad()
    _sync(device)

    start = time.perf_counter()
    for epoch in range(NUM_EPOCHS):
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        if (epoch + 1) % LOG_EVERY == 0:
            print(f" Epoch {epoch+1}/{NUM_EPOCHS} - Loss: {loss.item():.4f}")
    # CUDA kernels run asynchronously; wait for them before reading the clock.
    _sync(device)
    return time.perf_counter() - start


print("\n1. GPU TRAINING")
print("-" * 80)
gpu_time = None  # stays None when no CUDA device is present
if torch.cuda.is_available():
    model_gpu = SimpleModel().cuda()  # Move to GPU
    optimizer_gpu = torch.optim.Adam(model_gpu.parameters())
    x_gpu = x.cuda()
    y_gpu = y.cuda()
    print(f"Device: {next(model_gpu.parameters()).device}")
    gpu_props = torch.cuda.get_device_properties(torch.cuda.current_device())
    print(f"GPU: {gpu_props.name}")
    print(f"GPU total memory: {gpu_props.total_memory / 1e9:.2f} GB")
    gpu_time = _train(model_gpu, optimizer_gpu, x_gpu, y_gpu)
    print(f"\nGPU training time: {gpu_time:.2f} seconds")
else:
    print("CUDA is not available - skipping GPU training.")

print("\n2. CPU TRAINING")
print("-" * 80)
model_cpu = SimpleModel().cpu()  # Stay on CPU
optimizer_cpu = torch.optim.Adam(model_cpu.parameters())
x_cpu = x.cpu()
y_cpu = y.cpu()
print(f"Device: {next(model_cpu.parameters()).device}")
cpu_time = _train(model_cpu, optimizer_cpu, x_cpu, y_cpu)
print(f"\nCPU training time: {cpu_time:.2f} seconds")

print("\n" + "=" * 80)
print("RESULTS")
print("=" * 80)
if gpu_time is None:
    print("GPU time: n/a (CUDA not available)")
    print(f"CPU time: {cpu_time:.2f} seconds")
else:
    print(f"GPU time: {gpu_time:.2f} seconds")
    print(f"CPU time: {cpu_time:.2f} seconds")
    ratio = cpu_time / gpu_time
    # Report honestly in either direction instead of always claiming a speedup.
    if ratio >= 1:
        print(f"Speedup: {ratio:.1f}x faster on GPU")
    else:
        print(f"Slowdown: {1 / ratio:.1f}x slower on GPU")
print("=" * 80)