A very minimal training loop abstraction for PyTorch.
- Minimal: Only PyTorch as a dependency
- Flexible: Use existing PyTorch models, no need to subclass
- Extensible: Callback system for custom behavior
- Automatic: Handles device management, mixed precision, gradient accumulation
- No magic: Simple, readable code that does what you expect
Features: automatic device management (CPU/CUDA), mixed precision training, gradient accumulation, gradient clipping, checkpointing with resume support, a callback system for extensibility, early stopping, learning-rate scheduling, progress tracking, and exponential moving average (EMA) of model weights.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchtl import Trainer
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device='cuda',
    mixed_precision=True
)
# train_loader / val_loader are standard PyTorch DataLoaders
history = trainer.fit(train_loader, val_loader, epochs=10)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchtl import Trainer
# Synthetic training data
X_train = torch.randn(1000, 10)
y_train = torch.randn(1000, 1)
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Any standard nn.Module works; no subclassing required
model = nn.Sequential(
    nn.Linear(10, 64),
    nn.ReLU(),
    nn.Linear(64, 1)
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()
trainer = Trainer(model, optimizer, loss_fn)
trainer.fit(train_loader, epochs=10)
# Add a validation set and train again to track validation loss
X_val = torch.randn(200, 10)
y_val = torch.randn(200, 1)
val_dataset = TensorDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=32)
history = trainer.fit(train_loader, val_loader, epochs=10)
print(f"Train losses: {history['train_loss']}")
print(f"Val losses: {history['val_loss']}")
# Mixed precision training
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    mixed_precision=True
)
# Gradient accumulation: gradients are accumulated over 4 batches before each
# optimizer step (effective batch size = DataLoader batch size x 4)
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    grad_acc_steps=4
)
# Gradient clipping by global norm
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    max_grad_norm=1.0
)
from torchtl import ProgressCallback
trainer = Trainer(model, optimizer, loss_fn)
trainer.add_callback(ProgressCallback(print_every=100))
trainer.fit(train_loader, epochs=10)
from torchtl import CheckpointCallback
# Save a checkpoint every epoch, keeping only the last 3
checkpoint_cb = CheckpointCallback(
    checkpoint_dir='./checkpoints',
    save_every_n_epochs=1,
    keep_last_n=3
)
trainer.add_callback(checkpoint_cb)
trainer.fit(train_loader, val_loader, epochs=10)
# Save only the best model, as measured by validation loss
checkpoint_cb = CheckpointCallback(
    checkpoint_dir='./checkpoints',
    save_best_only=True,
    monitor='val_loss',
    mode='min'
)
trainer.add_callback(checkpoint_cb)
trainer.fit(train_loader, val_loader, epochs=10)
from torchtl import EarlyStoppingCallback, StopTraining
# Stop when val_loss has not improved by at least min_delta for 5 epochs
early_stop_cb = EarlyStoppingCallback(
    patience=5,
    monitor='val_loss',
    mode='min',
    min_delta=0.001
)
trainer.add_callback(early_stop_cb)
try:
    trainer.fit(train_loader, val_loader, epochs=100)
except StopTraining as e:
    print(f"Training stopped: {e}")
from torchtl import LearningRateSchedulerCallback
from torch.optim.lr_scheduler import StepLR
# Multiply the LR by 0.1 every 5 epochs
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
scheduler_cb = LearningRateSchedulerCallback(scheduler)
trainer.add_callback(scheduler_cb)
trainer.fit(train_loader, epochs=20)
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Reduce the LR when the monitored metric stops improving for 3 epochs
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3)
scheduler_cb = LearningRateSchedulerCallback(scheduler)
trainer.add_callback(scheduler_cb)
trainer.fit(train_loader, val_loader, epochs=20)
from torchtl import (
    ProgressCallback,
    CheckpointCallback,
    EarlyStoppingCallback,
    LearningRateSchedulerCallback
)
trainer.add_callback(ProgressCallback(print_every=50))
trainer.add_callback(CheckpointCallback('./checkpoints', save_best_only=True))
trainer.add_callback(EarlyStoppingCallback(patience=5))
trainer.add_callback(LearningRateSchedulerCallback(scheduler))
trainer.fit(train_loader, val_loader, epochs=100)
# Save the training state, reload it, and resume training
trainer.save_checkpoint('./checkpoint.pt')
trainer.load_checkpoint('./checkpoint.pt')
trainer.fit(train_loader, epochs=10)
# Extra keyword arguments are stored with the checkpoint
trainer.save_checkpoint('./checkpoint.pt', best_accuracy=0.95, notes="best model")
# Non-strict loading
trainer.load_checkpoint('./checkpoint.pt', strict=False)
from torchtl import count_params
total_params = count_params(model)
trainable_params = count_params(model, trainable_only=True)
print(f"Total: {total_params}, Trainable: {trainable_params}")
from torchtl import freeze_layers, unfreeze_layers
# Freeze everything, then unfreeze only the named layers
freeze_layers(model)
unfreeze_layers(model, layer_names=['fc', 'classifier'])
# Or freeze specific layers by name
freeze_layers(model, layer_names=['conv1', 'conv2'])
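A typical fine-tuning recipe combines these helpers: freeze the backbone, unfreeze the head, check what remains trainable, and hand the model to the Trainer. A minimal sketch, assuming a model whose head is named 'classifier' (an illustrative name; adjust to your architecture):
# Sketch: train only the head of a pretrained model
freeze_layers(model)
unfreeze_layers(model, layer_names=['classifier'])
print(f"Trainable params: {count_params(model, trainable_only=True)}")
trainer = Trainer(model, optimizer, loss_fn)
trainer.fit(train_loader, epochs=5)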
from torchtl import set_seed
set_seed(42)
from torchtl import get_lr, set_lr
current_lr = get_lr(optimizer)
print(f"Current LR: {current_lr}")
set_lr(optimizer, 0.0001)
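get_lr and set_lr also make it easy to drive the learning rate by hand, for example a simple linear warmup. A sketch, where base_lr and warmup_epochs are illustrative local variables rather than TorchTL parameters:
base_lr = 0.001
warmup_epochs = 5
for epoch in range(20):
    if epoch < warmup_epochs:
        # Ramp the LR linearly from base_lr/5 up to base_lr
        set_lr(optimizer, base_lr * (epoch + 1) / warmup_epochs)
    trainer.train_epoch(train_loader)
    print(f"Epoch {epoch}: lr={get_lr(optimizer)}")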
from torchtl import ExponentialMovingAverage
ema = ExponentialMovingAverage(model, decay=0.999)
# Update the EMA after each training epoch
for epoch in range(epochs):
    trainer.train_epoch(train_loader)
    ema.update()

# Evaluate with the averaged weights, then restore the raw weights
ema.apply_shadow()
val_metrics = trainer.validate(val_loader)
ema.restore()
from torchtl import Callback
class CustomCallback(Callback):
    def on_epoch_start(self, trainer):
        print(f"Starting epoch {trainer.epoch + 1}")

    def on_epoch_end(self, trainer, metrics):
        print(f"Epoch {trainer.epoch} finished with loss: {metrics['loss']:.4f}")

    def on_batch_end(self, trainer, batch_idx, batch, metrics):
        if trainer.global_step % 100 == 0:
            print(f"Step {trainer.global_step}, Loss: {metrics['loss']:.4f}")
trainer.add_callback(CustomCallback())
trainer.fit(train_loader, epochs=10)
TorchTL supports multiple batch formats:
batch = (inputs, targets)                        # plain tuple
batch = {'inputs': inputs, 'targets': targets}   # dict, plural keys
batch = {'input': inputs, 'target': targets}     # dict, singular keys
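A Dataset that yields dictionaries therefore needs no adapter code. A minimal sketch (DictDataset is an illustrative name; it assumes the trainer reads the 'inputs'/'targets' keys listed above):
from torch.utils.data import DataLoader, Dataset

class DictDataset(Dataset):
    def __init__(self, X, y):
        self.X, self.y = X, y
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return {'inputs': self.X[idx], 'targets': self.y[idx]}

# The default collate function batches each key into a tensor
dict_loader = DataLoader(DictDataset(X_train, y_train), batch_size=32, shuffle=True)
trainer.fit(dict_loader, epochs=10)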
best_loss = float('inf')
for epoch in range(10):
    train_metrics = trainer.train_epoch(train_loader)
    val_metrics = trainer.validate(val_loader)
    print(f"Epoch {epoch}: Train Loss={train_metrics['loss']:.4f}, Val Loss={val_metrics['val_loss']:.4f}")
    if val_metrics['val_loss'] < best_loss:
        best_loss = val_metrics['val_loss']
        trainer.save_checkpoint('./best_model.pt')
print(f"Current epoch: {trainer.epoch}")
print(f"Global step: {trainer.global_step}")
print(f"Device: {trainer.device}")
Apache License 2.0.