"""Benchmark full-batch training of a small MLP on GPU versus CPU."""

import copy
import time

import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    """Three-layer MLP: 784 input features -> 1000 -> 1000 -> 10 logits."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(784, 1000)
        self.fc2 = nn.Linear(1000, 1000)
        self.fc3 = nn.Linear(1000, 10)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)


def _synchronize(device: torch.device) -> None:
    """Block until all queued work on *device* has finished (no-op off CUDA)."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def train_and_time(
    model: nn.Module,
    x: torch.Tensor,
    y: torch.Tensor,
    criterion: nn.Module,
    *,
    epochs: int = 20,
    log_every: int = 5,
    warmup_steps: int = 1,
) -> float:
    """Train *model* full-batch on (x, y) with Adam and return elapsed seconds.

    Each epoch is a single optimizer step over the whole of ``x``. ``model``,
    ``x`` and ``y`` must already live on the same device. The loss is printed
    every ``log_every`` epochs; ``log_every <= 0`` disables logging.

    Before the clock starts, ``warmup_steps`` untimed forward/backward passes
    are run and their gradients discarded. No optimizer step is taken during
    warm-up, so the weights are left unchanged.
    """
    device = next(model.parameters()).device
    optimizer = torch.optim.Adam(model.parameters())

    # The first CUDA passes pay one-off costs (context/cuBLAS setup, kernel
    # selection) that would otherwise be billed to the GPU timing.
    for _ in range(warmup_steps):
        criterion(model(x), y).backward()
    optimizer.zero_grad()

    _synchronize(device)
    start = time.perf_counter()
    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        if log_every > 0 and epoch % log_every == 0:
            print(f"  Epoch {epoch}/{epochs} - Loss: {loss.item():.4f}")
    # CUDA kernels launch asynchronously; wait for them before stopping the clock.
    _synchronize(device)
    return time.perf_counter() - start


def main(num_samples: int = 100_000, epochs: int = 20, log_every: int = 5):
    """Run the GPU (when available) and CPU benchmarks and print a comparison.

    Returns ``(gpu_time, cpu_time)`` in seconds; ``gpu_time`` is ``None`` when
    CUDA is not available.
    """
    print("=" * 80)
    print("PYTORCH GPU vs CPU BENCHMARK TEST")
    print("=" * 80)

    # Dummy data - larger dataset, generated once and shared by both runs.
    x = torch.randn(num_samples, 784)
    y = torch.randint(0, 10, (num_samples,))
    criterion = nn.CrossEntropyLoss()
    # Both runs start from identical weights so their losses are comparable.
    base_model = SimpleModel()

    print("\n1. GPU TRAINING")
    print("-" * 80)
    gpu_time = None
    if torch.cuda.is_available():
        gpu = torch.device("cuda", torch.cuda.current_device())
        model_gpu = copy.deepcopy(base_model).to(gpu)
        print(f"Device: {next(model_gpu.parameters()).device}")
        # total_memory is the card's capacity, not what is currently free.
        total_gb = torch.cuda.get_device_properties(gpu).total_memory / 1e9
        print(f"GPU total memory: {total_gb:.2f} GB")
        gpu_time = train_and_time(
            model_gpu, x.to(gpu), y.to(gpu), criterion,
            epochs=epochs, log_every=log_every,
        )
        print(f"\nGPU training time: {gpu_time:.2f} seconds")
    else:
        print("CUDA is not available - skipping GPU training.")

    print("\n2. CPU TRAINING")
    print("-" * 80)
    cpu = torch.device("cpu")
    model_cpu = copy.deepcopy(base_model).to(cpu)
    print(f"Device: {next(model_cpu.parameters()).device}")
    cpu_time = train_and_time(
        model_cpu, x.to(cpu), y.to(cpu), criterion,
        epochs=epochs, log_every=log_every,
    )
    print(f"\nCPU training time: {cpu_time:.2f} seconds")

    print("\n" + "=" * 80)
    print("RESULTS")
    print("=" * 80)
    if gpu_time is None:
        print("GPU time: n/a (CUDA unavailable)")
    else:
        print(f"GPU time: {gpu_time:.2f} seconds")
    print(f"CPU time: {cpu_time:.2f} seconds")
    if gpu_time is not None:
        if gpu_time <= cpu_time:
            print(f"Speedup: {cpu_time / gpu_time:.1f}x faster on GPU")
        else:
            print(f"Slowdown: {gpu_time / cpu_time:.1f}x slower on GPU")
    print("=" * 80)
    return gpu_time, cpu_time


if __name__ == "__main__":
    main()