TorchTL – A minimal training loop abstraction for PyTorch


A very minimal training loop abstraction for PyTorch.

  • Minimal: PyTorch is the only dependency
  • Flexible: Use existing PyTorch models, no need to subclass
  • Extensible: Callback system for custom behavior
  • Automatic: Handles device management, mixed precision, gradient accumulation
  • No magic: Simple, readable code that does what you expect

Features include automatic device management (CPU/CUDA), mixed-precision training, gradient accumulation, gradient clipping, checkpoints with resume support, a callback system for extensibility, early stopping, learning-rate scheduling, progress tracking, exponential moving average (EMA) of weights, and more.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchtl import Trainer

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device='cuda',
    mixed_precision=True
)

# train_loader / val_loader are ordinary PyTorch DataLoader objects
history = trainer.fit(train_loader, val_loader, epochs=10)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchtl import Trainer

X_train = torch.randn(1000, 10)
y_train = torch.randn(1000, 1)
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

model = nn.Sequential(
    nn.Linear(10, 64),
    nn.ReLU(),
    nn.Linear(64, 1)
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

trainer = Trainer(model, optimizer, loss_fn)
trainer.fit(train_loader, epochs=10)
X_val = torch.randn(200, 10)
y_val = torch.randn(200, 1)
val_dataset = TensorDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=32)

history = trainer.fit(train_loader, val_loader, epochs=10)
print(f"Train losses: {history['train_loss']}")
print(f"Val losses: {history['val_loss']}")
# Enable mixed-precision training
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    mixed_precision=True
)
# Accumulate gradients over 4 batches before each optimizer step
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    grad_acc_steps=4
)
# Clip gradients to a maximum norm of 1.0
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    max_grad_norm=1.0
)
from torchtl import ProgressCallback

trainer = Trainer(model, optimizer, loss_fn)
trainer.add_callback(ProgressCallback(print_every=100))
trainer.fit(train_loader, epochs=10)
from torchtl import CheckpointCallback

# Save a checkpoint every epoch, keeping only the 3 most recent
checkpoint_cb = CheckpointCallback(
    checkpoint_dir='./checkpoints',
    save_every_n_epochs=1,
    keep_last_n=3
)
trainer.add_callback(checkpoint_cb)
trainer.fit(train_loader, val_loader, epochs=10)
# Keep only the checkpoint with the lowest validation loss
checkpoint_cb = CheckpointCallback(
    checkpoint_dir='./checkpoints',
    save_best_only=True,
    monitor='val_loss',
    mode='min'
)
trainer.add_callback(checkpoint_cb)
trainer.fit(train_loader, val_loader, epochs=10)
from torchtl import EarlyStoppingCallback, StopTraining

# Stop when val_loss has not improved by at least 0.001 for 5 consecutive epochs
early_stop_cb = EarlyStoppingCallback(
    patience=5,
    monitor='val_loss',
    mode='min',
    min_delta=0.001
)
trainer.add_callback(early_stop_cb)

try:
    trainer.fit(train_loader, val_loader, epochs=100)
except StopTraining as e:
    print(f"Training stopped: {e}")
from torchtl import LearningRateSchedulerCallback
from torch.optim.lr_scheduler import StepLR

scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
scheduler_cb = LearningRateSchedulerCallback(scheduler)
trainer.add_callback(scheduler_cb)
trainer.fit(train_loader, epochs=20)
from torch.optim.lr_scheduler import ReduceLROnPlateau

scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3)
scheduler_cb = LearningRateSchedulerCallback(scheduler)
trainer.add_callback(scheduler_cb)
trainer.fit(train_loader, val_loader, epochs=20)
from torchtl import (
    ProgressCallback,
    CheckpointCallback,
    EarlyStoppingCallback,
    LearningRateSchedulerCallback
)

trainer.add_callback(ProgressCallback(print_every=50))
trainer.add_callback(CheckpointCallback('./checkpoints', save_best_only=True))
trainer.add_callback(EarlyStoppingCallback(patience=5))
trainer.add_callback(LearningRateSchedulerCallback(scheduler))
trainer.fit(train_loader, val_loader, epochs=100)
# Save a checkpoint, load it back, and continue training from the restored state
trainer.save_checkpoint('./checkpoint.pt')
trainer.load_checkpoint('./checkpoint.pt')
trainer.fit(train_loader, epochs=10)
trainer.save_checkpoint('./checkpoint.pt', best_accuracy=0.95, notes="best model")
trainer.load_checkpoint('./checkpoint.pt', strict=False)
from torchtl import count_params

total_params = count_params(model)
trainable_params = count_params(model, trainable_only=True)
print(f"Total: {total_params}, Trainable: {trainable_params}")
from torchtl import freeze_layers, unfreeze_layers

# Freeze all layers
freeze_layers(model)

# Unfreeze specific layers by name
unfreeze_layers(model, layer_names=['fc', 'classifier'])

# Freeze specific layers by name
freeze_layers(model, layer_names=['conv1', 'conv2'])
from torchtl import set_seed

set_seed(42)
from torchtl import get_lr, set_lr

current_lr = get_lr(optimizer)
print(f"Current LR: {current_lr}")

set_lr(optimizer, 0.0001)

Exponential moving average

from torchtl import ExponentialMovingAverage

ema = ExponentialMovingAverage(model, decay=0.999)

for epoch in range(epochs):
    trainer.train_epoch(train_loader)
    ema.update()

    # Validate with the EMA weights, then restore the original weights
    ema.apply_shadow()
    val_metrics = trainer.validate(val_loader)
    ema.restore()
from torchtl import Callback

class CustomCallback(Callback):
    def on_epoch_start(self, trainer):
        print(f"Starting epoch {trainer.epoch + 1}")

    def on_epoch_end(self, trainer, metrics):
        print(f"Epoch {trainer.epoch} finished with loss: {metrics['loss']:.4f}")

    def on_batch_end(self, trainer, batch_idx, batch, metrics):
        if trainer.global_step % 100 == 0:
            print(f"Step {trainer.global_step}, Loss: {metrics['loss']:.4f}")

trainer.add_callback(CustomCallback())
trainer.fit(train_loader, epochs=10)

TorchTL supports multiple batch formats.

# Tuple/list format
batch = (inputs, targets)

# Dictionary format (either key naming is accepted)
batch = {'inputs': inputs, 'targets': targets}
batch = {'input': inputs, 'target': targets}
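
For dictionary batches, a dataset only needs to return dicts using one of the key pairs above; the DictDataset below is a minimal illustrative sketch (not part of TorchTL) showing how such a loader plugs into trainer.fit.

import torch
from torch.utils.data import Dataset, DataLoader

class DictDataset(Dataset):
    """Illustrative dataset (not part of TorchTL) yielding dict-format samples."""
    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # The default collate function stacks these into
        # {'inputs': tensor, 'targets': tensor} batches
        return {'inputs': self.X[idx], 'targets': self.y[idx]}

dict_loader = DataLoader(DictDataset(torch.randn(1000, 10), torch.randn(1000, 1)),
                         batch_size=32, shuffle=True)
trainer.fit(dict_loader, epochs=10)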
best_loss = float('inf')

for epoch in range(10):
    train_metrics = trainer.train_epoch(train_loader)
    val_metrics = trainer.validate(val_loader)
    print(f"Epoch {epoch}: Train Loss={train_metrics['loss']:.4f}, Val Loss={val_metrics['val_loss']:.4f}")

    if val_metrics['val_loss'] < best_loss:
        best_loss = val_metrics['val_loss']
        trainer.save_checkpoint('./best_model.pt')
print(f"Current epoch: {trainer.epoch}") print(f"Global step: {trainer.global_step}") print(f"Device: {trainer.device}")

Licensed under the Apache License 2.0.
