What is the difference between active learning (and active sampling)
and data filtering? And why do we treat data selection differently
during training versus before training?
This post explores the fundamental distinction between active dataset
selection and data filtering, which we will frame as “selection
vs. rejection” by appealing to approximate submodularity. Understanding
this nuance lets us relate the two approaches to each other more
clearly.
Note that this post is not concerned with the labeling aspects of active
learning. Active learning is about labeling the most informative
samples, and active sampling is about training on the most informative
samples when all labels are already known; here, we are only interested
in how both differ from data filtering. Data filtering is usually
performed on already-labeled data to speed up training, but it can also
mean filtering out data early that we are not interested in labeling at
all, similar to active learning. Thus, this post is really about the
distinction between active selection and data filtering.
More generally, the distinction between active selection and data
filtering is also relevant for unsupervised learning, such as
pre-training, where there are no labels in the first place.
We will see that the key difference stems from two aspects:
- when we make selection decisions (operational context); and
- how we use informativeness scores.

Figure 1: Active learning vs. data filtering: A conceptual
comparison of how these approaches evaluate and use sample
informativeness.
The Core Distinction:
Timing and Direction
The operational context, that is when we make selection decisions,
fundamentally shapes how these methods work:
- Active Learning/Sampling makes online/on-policy
selections during training; while
- Data Filtering makes offline/off-policy selections
before training begins.
This timing difference naturally leads to different ways of using
informativeness scores:
- Active Learning/Sampling selects the most
informative samples; while
- Data Filtering rejects the least informative
samples.
While selecting or removing might seem equivalent mathematically
(accepting 10% is the same as rejecting 90%), the operational context
leads to different practical approaches and, more importantly,
different approximations.
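To make the static view concrete: with a single, fixed set of scores, keeping the top k and rejecting the bottom n - k are literally the same operation. Here is a minimal sketch, with hypothetical random scores standing in for informativeness:

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 1_000, 100
scores = rng.random(n)  # hypothetical, fixed informativeness scores

# With one fixed scoring pass, the two views coincide:
top_k = set(np.argsort(scores)[-k:])                    # "select the most informative"
kept = set(range(n)) - set(np.argsort(scores)[:n - k])  # "reject the least informative"
assert top_k == kept

# In active learning, however, scores are recomputed after every acquisition,
# so the set that is eventually selected generally differs from what a single
# offline rejection pass would have kept.
```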
But why should we phrase this as “selection vs. rejection”?
In short, the answer lies in how informativeness of samples behaves
as training data accumulate: namely, it is (approximately)
submodular.
To see why selection and rejection differ, we first need to understand
the properties of “informativeness.” Data subset selection approaches
fundamentally rely on approximating the answer to the question: how much
will training a model on a specific sample benefit its performance?
Crucially, informativeness often exhibits
submodularity: the marginal gain of adding a new data
point decreases as more data points are included. Each new data point we
add to our training set provides less additional information than the
previous one. This happens because new samples often contain information
that overlaps with what we’ve already learned from previous samples. As
the training set grows, new samples offer diminishing returns in terms
of informativeness.
Submodularity is a well-understood property (see, e.g., Nemhauser et
al., 1978 [1]), and the expected information gain, which provides a
natural measure of informativeness, is submodular or at least
approximately submodular in many cases (Das and Kempe, 2018 [2]). (In
the case of the parameter-based expected information gain, it is always
submodular.)
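As a small sanity check (not part of the experiments in this post), here is a sketch for a setting where the parameter-based EIG has a closed form: Bayesian linear regression with a standard normal prior over the weights and Gaussian observation noise. The marginal EIG of a fixed candidate point shrinks as the already-selected set grows, which is exactly the diminishing-returns behavior described above:

```python
import numpy as np

rng = np.random.default_rng(0)
sigma2 = 0.5                    # assumed observation noise variance
X = rng.normal(size=(200, 10))  # hypothetical pool of candidate inputs

def eig(indices):
    """Parameter-based EIG I(theta; y_S) = 0.5 * logdet(I + X_S X_S^T / sigma2)
    for Bayesian linear regression with prior theta ~ N(0, I)."""
    Xs = X[list(indices)]
    K = np.eye(len(indices)) + Xs @ Xs.T / sigma2
    return 0.5 * np.linalg.slogdet(K)[1]

x_star = 0  # a fixed candidate whose marginal gain we track
for size in [0, 1, 2, 5, 10, 20, 50, 100]:
    S = list(range(1, size + 1))  # grow the "already selected" set with other points
    gain = eig(S + [x_star]) - eig(S)
    print(f"|S| = {size:3d}  ->  marginal EIG of x*: {gain:.4f}")
```

Because the sets are nested, submodularity guarantees that the printed marginal gains never increase.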
But why does this matter for active learning versus filtering? Let’s
break it down:
For Active Learning: Submodularity tells us we
need to constantly re-evaluate which samples are most informative. What
was highly informative previously might not be informative anymore,
especially after we’ve trained on similar samples.
For Data Filtering: Submodularity gives us
confidence that samples deemed uninformative at the start will likely
remain so throughout training. This makes early filtering a safe
operation.
Active Learning: Iterative
Selection
Active learning methods operate iteratively, computing
informativeness scores in an online manner, typically at each training
iteration. At each step, the samples that currently appear most
informative for the model’s learning process are selected for
training.
Some active learning strategies, like BatchBALD, take this a step
further. They explicitly use submodularity to efficiently select batches
of samples. But whether explicit or implicit, the key insight remains:
we need to continuously reassess sample importance as our model learns
and evolves.
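Schematically, the loop looks as follows. This is only a sketch: `fit_model` and `informativeness` are hypothetical placeholders for, e.g., retraining with MC dropout and computing (Batch)BALD scores, not real library calls.

```python
def active_learning_loop(pool, labeled, fit_model, informativeness,
                         n_rounds, batch_size=1):
    """Sketch of iterative active selection: re-score the pool every round."""
    for _ in range(n_rounds):
        model = fit_model(labeled)  # retrain (or update) on the current labeled set
        # Scores must be recomputed each round: what was informative before
        # may no longer be informative after training on similar samples.
        scores = {i: informativeness(model, x) for i, x in pool.items()}
        acquired = sorted(scores, key=scores.get, reverse=True)[:batch_size]
        for i in acquired:
            labeled[i] = pool.pop(i)  # move the winners into the labeled set
    return labeled
```

Batch acquisition methods such as BatchBALD replace the per-sample top-k step with a greedy maximization of a joint score, which is where (approximate) submodularity provides the usual near-optimality guarantees for greedy selection.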
Data Filtering: Early
Rejection
What makes data filtering fundamentally different? It’s all about
timing.
Data filtering takes a different approach by evaluating
informativeness offline, before training begins. The goal is to identify
and remove less informative samples to streamline the training process.
But how can we be confident in making such early decisions?
Again, submodularity provides the answer. With (approximate)
submodularity, we can be reasonably confident that samples deemed
uninformative at the start won’t suddenly become highly informative
later. This property gives us the theoretical backing to make these
early filtering decisions.
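In code, the contrast is stark: one scoring pass, one threshold, and the decision is never revisited. The `proxy_informativeness` function below is a hypothetical stand-in for whatever cheap offline score is used (a proxy model’s EIG, a loss- or perplexity-based heuristic, and so on):

```python
import numpy as np

def filter_dataset(dataset, proxy_informativeness, reject_fraction=0.5):
    """Sketch of offline filtering: score once, drop the least informative part."""
    scores = np.array([proxy_informativeness(x) for x in dataset])
    threshold = np.quantile(scores, reject_fraction)  # bottom fraction is rejected
    kept = [x for x, s in zip(dataset, scores) if s >= threshold]
    return kept  # training then runs on `kept`; the scores are never recomputed
```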

Figure 2: Visualization of how sample informativeness evolves during
active learning.
The Bayesian Perspective
Both active learning and data filtering can be understood through the
lens of information theory, which provides a framework to quantify
informativeness. Active learning methods like BALD (Houlsby et al.,
2011 [3]) select samples that maximize the expected information gain,
that is, the mutual information between model parameters and
predictions. Data filtering, meanwhile, can be viewed as
removing samples that would contribute minimally to the posterior
update.
More generally, many non-Bayesian methods can be interpreted as
approximations of these Bayesian approaches (Kirsch et al.,
2022 [4]). It would be interesting to explore this connection
further.
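For reference, the BALD estimator used throughout this post can be written in a few lines. This mirrors the Monte Carlo estimator in the experiment code further below, where `probs` holds MC-dropout predictive probabilities of shape `[n_samples, n_mc, n_classes]`:

```python
import numpy as np

def bald_scores(probs, eps=1e-12):
    """BALD = H(E_w[p(y|x,w)]) - E_w[H(p(y|x,w))], estimated from MC samples."""
    mean_probs = probs.mean(axis=1)                                     # E_w[p(y|x,w)]
    entropy_of_mean = -(mean_probs * np.log(mean_probs + eps)).sum(-1)  # H of the mean
    mean_of_entropy = -(probs * np.log(probs + eps)).sum(-1).mean(-1)   # mean of the H's
    return entropy_of_mean - mean_of_entropy  # mutual information I(y; w | x)
```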
A Practical Example:
MNIST Active Learning
Let’s make this concrete with a real example. We conducted an active
learning experiment on MNIST using a LeNet-5 model with Monte Carlo
dropout. The experiment iteratively selects the most informative samples
based on BALD scores and trains the model on the acquired samples.
The animation in Figure 3 shows the evolution of the training set’s
BALD scores over iterations, smoothed with an Exponential Moving
Average (EMA) and plotted in a fixed order of samples (sorted by their
initial informativeness).
Because the EMA smooths out noise in the informativeness estimates, this
animation is only illustrative of the true underlying dynamics. However,
it is a reasonable practical choice for visualizing trends. For a more
accurate visualization, one would want to use better estimators of the
EIG scores or many more samples.
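For completeness, the smoothing itself is just an exponential moving average over the per-iteration score vectors (decay 0.9, matching the animation code at the end of the post); in the actual animation code, already-acquired samples are additionally pinned to zero:

```python
def ema_smooth(score_history, decay=0.9):
    """Sketch of the EMA over per-iteration score arrays used for the animation."""
    smoothed, running = [], score_history[0].copy()
    for scores in score_history:
        running = decay * running + (1.0 - decay) * scores  # blend the new estimate in
        smoothed.append(running.copy())
    return smoothed
```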
We see that the informativeness of samples decreases over time, as
expected. We also see that the ranking of samples changes considerably
as the previously most informative samples are added to the training
set.

Figure 3: Evolution of BALD scores over time, smoothed using
Exponential Moving Average (EMA).
Conclusion
We have examined a few key insights about the relationship between
active learning and data filtering:
- The timing of selection decisions fundamentally shapes how these
methods work;
- Submodularity provides theoretical backing for both approaches;
and
- Real-world applications often require balancing theoretical ideals
with practical constraints.
While submodularity is often assumed, cases where sample
informativeness increases rather than decreases under certain
conditions still seem underexplored for filtering, at least in a
principled fashion as far as I know.
There’s still much to explore about cases where sample
informativeness behaves in unexpected ways:
- Understanding when and why informativeness patterns deviate from
theory;
- Developing more robust methods that account for these deviations;
and
- Creating hybrid methods that combine selection and rejection.
Conceptual Illustrations
The figures above illustrate the key concepts we’ve discussed. Let me
explain how they were created:
Figure 1 shows a conceptual visualization of the informativeness
curve, where samples are sorted by their informativeness (y-axis) along
the x-axis. This curve demonstrates the diminishing returns property
characteristic of submodularity. The visualization highlights:
- The data filtering region (red shaded area) shows
samples that would be filtered out because their informativeness falls
below a threshold.
- The active learning region (blue shaded area)
represents samples that would be prioritized by active learning
approaches because of their high informativeness.
Figure 2 is an animation that simulates how active learning works
over time:
- We start with a distribution of samples with varying informativeness
levels.
- At each step, the most informative sample (highest point) is
selected for training.
- After selection, that sample’s informativeness drops to zero (it’s
been “used”).
- The informativeness of other samples also decreases (with some
random noise), reflecting how the value of remaining samples changes as
the model learns.
This animation captures the dynamic, iterative nature of active
learning, where the most informative sample is constantly changing as
training progresses. It visually demonstrates why we need to re-evaluate
informativeness at each step rather than making all selection decisions
upfront.
Both visualizations were created using matplotlib with an XKCD-style
aesthetic to make the concepts more approachable and intuitive.
Illustration Code
"""Generate an XKCD-style illustration of submodularity and how active learning vs.
data filtering interact with the informativeness curve.
The script produces both static PNG/SVG files and an animated GIF showing how
informativeness changes over time as data gets trained on.
Usage:
python scripts/submodularity_curve.py # default path
python scripts/submodularity_curve.py --out /tmp/figure.png
"""
# %%
from __future__ import annotations
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation, PillowWriter
# %%
# Synthetic submodular-like curve: diminishing returns.
x = np.linspace(1, 10000, 10000)
y_base = 1 / (0.9 / 99.9 * x + (1 - 0.9 / 99.9))
y_filtering_threshold = 0.05 # filtering threshold: keep samples with y >= threshold.
y_active_threshold = 0.2
filtering_color = "red"
active_color = "blue"
# First create the static plots
with plt.xkcd():
    fig, ax = plt.subplots(figsize=(8, 5))
    # Plot the diminishing-returns curve.
    ax.plot(x, y_base, color="black")
    # Shaded region for samples that would be *filtered out* (y below threshold).
    ax.fill_between(x, 0, y_base, where=y_base < y_filtering_threshold, color=filtering_color, alpha=0.1,
                    label="Data filtering region (offline)")
    # Horizontal dashed line showing the filtering threshold.
    ax.axhline(y_filtering_threshold, color=filtering_color, linestyle="--", linewidth=1)
    # Vertical dashed line showing the active learning region.
    ax.axvline(x[np.argmax(y_base <= y_active_threshold)], color=active_color, linestyle="--", linewidth=1)
    # Shaded vertical region for samples that would be considered for active learning.
    ax.fill_betweenx(y_base, 0, x, where=y_base > y_active_threshold, color=active_color, alpha=0.1,
                     label="Active learning region (online)")
    # Styling.
    ax.set_xlabel("Samples sorted by informativeness")
    ax.set_ylabel("Informativeness")
    ax.set_title("Active Selection vs. Data Filtering")
    ax.set_ylim(0, 1.05)
    ax.set_xlim(0, x.max())
    ax.legend()
    ax.grid(False)

    out_path = Path("active_vs_filtering_xkcd.png")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    out_path = Path("active_vs_filtering_xkcd.svg")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.show()
#%%
# Now create the animated version
n_frames = 40
top_k = 1
current_y = y_base.copy()
with plt.xkcd():
    fig_anim, ax_anim = plt.subplots(figsize=(8, 5))
    scatter = ax_anim.scatter([], [], color="black", s=1)
    selected_scatter = ax_anim.scatter([], [], color=active_color, s=100, alpha=0.1)
    # Horizontal dashed line showing the filtering threshold.
    ax_anim.axhline(y_filtering_threshold, color=filtering_color, linestyle="--", linewidth=1)
    # Styling
    ax_anim.set_xlabel("Samples sorted by informativeness")
    ax_anim.set_ylabel("Informativeness")
    ax_anim.set_title("Simulation of Top-1 Active Selection")
    ax_anim.set_ylim(0, 1.05)
    ax_anim.set_xlim(0, 1000)
    ax_anim.grid(False)

    def init():
        scatter.set_offsets(np.c_[x, y_base])
        return (scatter, selected_scatter)

    def animate(frame):
        print(frame)
        global current_y
        if frame > 1:
            # Zero out the previously selected top-k points.
            top_k_indices = np.argsort(current_y)[-top_k:]
            current_y[top_k_indices] = 0.0
            # Decay the other points with multiplicative noise.
            noise = np.clip(np.random.gumbel(0, 0.1, len(current_y)), 0, 1)
            current_y = current_y * (1 - noise)
        # Find the current top-k most informative points.
        top_k_indices = np.argsort(current_y)[-top_k:]
        # Highlight the current top-k points in the selected scatter.
        selected_scatter.set_offsets(np.c_[x[top_k_indices].copy(), current_y[top_k_indices].copy()])
        scatter.set_offsets(np.c_[x, current_y])
        return (scatter, selected_scatter)

    anim = FuncAnimation(fig_anim, animate, init_func=init,
                         frames=n_frames, interval=200, blit=False, save_count=40)
    fig_anim.tight_layout()
    writer = PillowWriter(fps=5)
    out_path = Path("active_selection_animation.gif")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    anim.save(out_path, writer=writer)
# %%
MNIST Active Learning
Experiment
In our MNIST active learning experiment, we used a LeNet-5 model with
Monte Carlo dropout to iteratively select the most informative samples
based on BALD scores. The experiment started with a small initial
labeled set and, at each iteration, acquired the most informative
unlabeled samples for training. The BALD scores were calculated using
multiple Monte Carlo samples to estimate the mutual information between
model parameters and predictions.
The animation in Figure 3 visualizes the evolution of BALD scores
over iterations. It shows how the informativeness of samples changes as
the model learns, with the most informative samples being selected and
their scores dropping to zero after acquisition. This dynamic process
highlights the importance of re-evaluating sample informativeness during
training.
The experiment was implemented in
mnist_al_experiment.py, which handles the model training,
BALD score calculation, and sample acquisition. The animation was
created using bald_animation.py, which retrieves data from
Weights & Biases (wandb) and generates the visualization of BALD
score evolution.
Experiment Code
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from tqdm import tqdm
import wandb
# --- Configuration ---
DEVICE = torch.device(
"mps"
if torch.backends.mps.is_available()
else "cuda"
if torch.cuda.is_available()
else "cpu"
)
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
# MNIST Hyperparameters
N_CLASSES = 10
IMG_WIDTH = 28
IMG_HEIGHT = 28
N_CHANNELS = 1
# LeNet Model Definition (LeNet-5 architecture with MC Dropout)
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv1_drop = nn.Dropout2d()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(1024, 128)
        self.fc1_drop = nn.Dropout()
        self.fc2 = nn.Linear(128, N_CLASSES)

    def forward(self, input: torch.Tensor):
        input = F.relu(F.max_pool2d(self.conv1_drop(self.conv1(input)), 2))
        input = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(input)), 2))
        input = input.view(-1, 1024)
        input = F.relu(self.fc1_drop(self.fc1(input)))
        input = self.fc2(input)
        return input
def bootstrap_sample(dataset, sample_size=None, replace=True, random_state=None):
    """
    Creates a bootstrap sample from the dataset.
    Args:
        dataset: PyTorch dataset to sample from
        sample_size: Size of the bootstrap sample (defaults to len(dataset))
        replace: Whether to sample with replacement (True for bootstrap)
        random_state: Random seed for reproducibility
    Returns:
        A PyTorch Subset containing the bootstrapped samples
    """
    if sample_size is None:
        sample_size = len(dataset)
    rng = np.random.default_rng(random_state)
    # Generate indices with replacement
    indices = rng.choice(len(dataset), size=sample_size, replace=replace)
    # Return a subset of the dataset with the selected indices
    return torch.utils.data.Subset(dataset, indices)
def create_bootstrap_loader(
    dataset,
    batch_size=64,
    sample_size=None,
    replace=True,
    random_state=None,
    num_workers=0,
    shuffle=True,
):
    """
    Creates a DataLoader with bootstrapped samples from the dataset.
    Args:
        dataset: PyTorch dataset to sample from
        batch_size: Batch size for the DataLoader
        sample_size: Size of the bootstrap sample (defaults to len(dataset))
        replace: Whether to sample with replacement (True for bootstrap)
        random_state: Random seed for reproducibility
        num_workers: Number of worker processes for data loading
        shuffle: Whether to shuffle the data
    Returns:
        A DataLoader containing bootstrapped samples
    """
    bootstrap_dataset = bootstrap_sample(dataset, sample_size, replace, random_state)
    return torch.utils.data.DataLoader(
        bootstrap_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
# EIG (BALD) Calculation
def get_scores(model, data_loader, n_mc_samples, criterion=None):
    """
    Calculates BALD scores for samples in data_loader.
    BALD = H(E[p(y|x,w)]) - E[H(p(y|x,w))]
    """
    model.train()
    all_probs_mc = []  # To store [n_samples_dataset, n_mc_samples, n_classes]
    all_labels = []
    with torch.inference_mode():
        for data, labels in tqdm(data_loader, desc="Calculating BALD Scores"):
            data = data.to(DEVICE)
            batch_probs_mc = []  # Store MC samples for this batch [n_mc_samples, batch_size, n_classes]
            for _ in range(n_mc_samples):
                output = model(data)
                probs = F.softmax(output, dim=1)
                batch_probs_mc.append(probs.unsqueeze(0))
            batch_probs_mc = torch.cat(
                batch_probs_mc, dim=0
            )  # [n_mc_samples, batch_size, n_classes]
            all_probs_mc.append(
                batch_probs_mc.permute(1, 0, 2).cpu()
            )  # [batch_size, n_mc_samples, n_classes]
            all_labels.append(labels)  # [batch_size]
    all_probs_mc_tensor = torch.cat(
        all_probs_mc, dim=0
    )  # [total_samples, n_mc_samples, n_classes]
    all_labels = torch.cat(all_labels, dim=0)  # [total_samples]
    # Entropy of mean predictions: H(E[p(y|x,w)])
    mean_probs = torch.mean(all_probs_mc_tensor, dim=1)  # [total_samples, n_classes]
    ic_probs = mean_probs * torch.log(mean_probs)
    ic_probs[torch.isnan(ic_probs)] = 0.0
    entropy_of_mean = -torch.sum(ic_probs, dim=1)  # [total_samples]
    # Mean of entropy of predictions: E[H(p(y|x,w))]
    ic_probs_mc = all_probs_mc_tensor * torch.log(all_probs_mc_tensor)
    ic_probs_mc[torch.isnan(ic_probs_mc)] = 0.0
    entropy_per_mc_sample = -torch.sum(
        ic_probs_mc, dim=2
    )  # [total_samples, n_mc_samples]
    mean_of_entropy = torch.mean(entropy_per_mc_sample, dim=1)  # [total_samples]
    bald_scores = entropy_of_mean - mean_of_entropy
    acc = (mean_probs.argmax(dim=1) == all_labels).float().mean() * 100.0
    if criterion is not None:
        loss = criterion(mean_probs.log(), all_labels)
    else:
        loss = -1.0
    return bald_scores.numpy(), acc, loss
# Training function
def train_model(model, train_loader, optimizer, criterion, epochs):
    model.train()  # Set to train mode (enables dropout, batchnorm updates etc.)
    pbar = tqdm(range(epochs), desc="Training Epochs")
    for epoch in pbar:
        epoch_loss = 0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(DEVICE), target.to(DEVICE)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
        pbar.set_postfix(loss=epoch_loss/len(train_loader), acc=100.*correct/total)
def plot_scores(scores):
    # Sort scores in descending order
    sorted_scores = -np.sort(-scores)
    # Plot the sorted scores
    plt.plot(sorted_scores)
    plt.show()
# --- Data Loading and Preparation ---
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
# Load full training and test datasets
full_train_dataset = datasets.MNIST(
"./data", train=True, download=True, transform=transform
)
# Subset the full training dataset to 10000 random samples
# Use numpy's default_rng with fixed seed for reproducibility
rng = np.random.default_rng(1)
indices = rng.permutation(len(full_train_dataset))[:10000]
full_train_dataset = torch.utils.data.Subset(full_train_dataset, indices)
all_train_loader_for_scores = DataLoader(
full_train_dataset, batch_size=1024, shuffle=False
)
test_dataset = datasets.MNIST("./data", train=False, transform=transform)
test_loader = DataLoader(
test_dataset, batch_size=1024, shuffle=False
) # For final eval & test scores
# --- Active Learning Setup ---
N_INITIAL_LABELED = 60 # Number of initial labeled samples
N_ACQUIRE_PER_ITER = 1 # Number of samples to acquire each iteration
N_ACTIVE_LEARNING_ITERATIONS = 100
N_MC_SAMPLES_EIG = 32 # Number of MC samples for EIG
TRAIN_EPOCHS_PER_ITER = 3 # Epochs to train model at each AL iteration
LEARNING_RATE = 0.0005
BATCH_SIZE_TRAIN = 64 # Batch size for training
N_MIN_SAMPLES_PER_EPOCH = 10_000
# Initialize wandb
wandb.init(
project="blog-active-learning-vs-filtering",
config={
"initial_labeled_samples": N_INITIAL_LABELED,
"acquire_per_iter": N_ACQUIRE_PER_ITER,
"al_iterations": N_ACTIVE_LEARNING_ITERATIONS,
"mc_samples_eig": N_MC_SAMPLES_EIG,
"epochs_per_iter": TRAIN_EPOCHS_PER_ITER,
"learning_rate": LEARNING_RATE,
"batch_size": BATCH_SIZE_TRAIN,
"model": "LeNet-5 with MC Dropout",
"dataset": "MNIST",
"device": str(DEVICE),
}
)
# Create initial labeled and unlabeled pools
num_train_samples = len(full_train_dataset)
all_indices = np.arange(num_train_samples)
np.random.shuffle(all_indices)
initial_labeled_indices = list(all_indices[:N_INITIAL_LABELED])
current_unlabeled_indices = list(all_indices[N_INITIAL_LABELED:])
# Store all acquisition scores
# Format: list of dicts, each dict is one AL iteration
# {'iteration': i, 'labeled_indices': [...], 'unlabeled_indices': [...],
# 'scores_for_all_train_samples': np.array([...]), 'scores_for_test_samples': np.array([...])}
all_scores_history = []
# --- Active Learning Loop ---
print("Starting Active Learning Loop...")
current_labeled_indices_set = set(initial_labeled_indices)
for al_iteration in tqdm(range(N_ACTIVE_LEARNING_ITERATIONS + 1), desc="Active Learning Iterations"): # +1 for initial state and after last acquisition
    print(f"\n--- Active Learning Iteration: {al_iteration} ---")
    print(f"Currently labeled samples: {len(current_labeled_indices_set)}")
    # 1. Create current labeled dataset and loader
    labeled_subset = Subset(full_train_dataset, list(current_labeled_indices_set))
    labeled_loader = create_bootstrap_loader(
        labeled_subset,
        batch_size=BATCH_SIZE_TRAIN,
        shuffle=True,
        random_state=RANDOM_SEED,
        sample_size=max(N_MIN_SAMPLES_PER_EPOCH, len(labeled_subset)),
        replace=N_MIN_SAMPLES_PER_EPOCH > len(labeled_subset),
    )
    # 2. Initialize or re-initialize model and optimizer
    model = LeNet().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    # 3. Train model on current labeled data (unless it's iteration 0 before any acquisition)
    if N_INITIAL_LABELED > 0:  # Train if there's data
        print("Training model...")
        train_model(
            model, labeled_loader, optimizer, criterion, epochs=TRAIN_EPOCHS_PER_ITER
        )
    # 4. Calculate EIG scores for all original training samples and test samples
    print("Calculating EIG scores for test samples...")
    scores_test_samples, test_acc, test_loss = get_scores(model, test_loader, N_MC_SAMPLES_EIG, criterion)
    print(f"Model trained. Test Acc: {test_acc:.2f}%, Test Loss: {test_loss:.4f}")
    print("Calculating EIG scores for all original training samples...")
    # Create a loader for ALL original training samples to get their scores
    scores_all_train_samples, _, _ = get_scores(
        model, all_train_loader_for_scores, N_MC_SAMPLES_EIG
    )
    # Plot the scores
    plt.figure(figsize=(10, 5))
    plt.title("BALD Scores for all original training samples")
    plt.ylabel("BALD Score")
    plot_scores(scores_all_train_samples)
    plt.show()
    plt.figure(figsize=(10, 5))
    plt.title("BALD Scores for all test samples")
    plt.ylabel("BALD Score")
    plot_scores(scores_test_samples)
    plt.show()
    # Store scores for this iteration
    iteration_scores_data = {
        "iteration": al_iteration,
        "current_labeled_indices": sorted(list(current_labeled_indices_set)),
        "current_unlabeled_indices": sorted(
            current_unlabeled_indices
        ),  # These are indices FROM the original full_train_dataset
        "scores_for_all_train_samples": scores_all_train_samples,  # Order matches full_train_dataset
        "scores_for_test_samples": scores_test_samples,  # Order matches test_dataset
    }
    all_scores_history.append(iteration_scores_data)
    print(f"Scores stored for iteration {al_iteration}.")
    # Log metrics to wandb
    wandb_log_data = {
        "iteration": al_iteration,
        "num_labeled_samples": len(current_labeled_indices_set),
        "test_accuracy": test_acc,
        "test_loss": test_loss,
        "avg_train_bald_score": np.mean(scores_all_train_samples),
        "avg_test_bald_score": np.mean(scores_test_samples),
        "max_train_bald_score": np.max(scores_all_train_samples),
        "max_test_bald_score": np.max(scores_test_samples),
    }
    wandb.log(wandb_log_data, commit=False)
    # Log histograms of BALD scores (optional)
    # Create a table with all the scores
    train_scores_table = wandb.Table(columns=["index", "bald_score"])
    for idx, score in enumerate(scores_all_train_samples):
        train_scores_table.add_data(idx, float(score))
    test_scores_table = wandb.Table(columns=["index", "bald_score"])
    for idx, score in enumerate(scores_test_samples):
        test_scores_table.add_data(idx, float(score))
    # Create plotly figures for BALD scores
    # Create sorted scores for better visualization
    sorted_train_scores = sorted(scores_all_train_samples, reverse=True)
    sorted_test_scores = sorted(scores_test_samples, reverse=True)
    # Create dataframes for plotly
    train_df = pd.DataFrame({
        "index": range(len(sorted_train_scores)),
        "bald_score": sorted_train_scores
    })
    test_df = pd.DataFrame({
        "index": range(len(sorted_test_scores)),
        "bald_score": sorted_test_scores
    })
    # Create plotly figures
    train_fig = px.line(train_df, x="index", y="bald_score",
                        title=f"Sorted BALD Scores for Training Samples (Iteration {al_iteration})")
    train_fig.update_layout(xaxis_title="Sample Index (sorted)", yaxis_title="BALD Score")
    test_fig = px.line(test_df, x="index", y="bald_score",
                       title=f"Sorted BALD Scores for Test Samples (Iteration {al_iteration})")
    test_fig.update_layout(xaxis_title="Sample Index (sorted)", yaxis_title="BALD Score")
    # Log plotly figures to wandb
    wandb.log({
        "train_bald_scores_plot": wandb.Plotly(train_fig),
        "test_bald_scores_plot": wandb.Plotly(test_fig),
        "train_bald_scores": train_scores_table,
        "test_bald_scores": test_scores_table,
    }, commit=False)
    # 5. If it's not the last iteration, perform acquisition
    if al_iteration < N_ACTIVE_LEARNING_ITERATIONS:
        if not current_unlabeled_indices:
            print("No more unlabeled samples to acquire. Stopping.")
            break
        print("Acquiring new samples...")
        # We need scores only for the *currently unlabeled* samples to decide which to pick
        # `scores_all_train_samples` contains scores for *all* original training samples.
        # We filter these down to only the unlabeled ones.
        unlabeled_scores_map = {
            idx: scores_all_train_samples[idx] for idx in current_unlabeled_indices
        }
        # Sort unlabeled samples by their EIG score (descending)
        sorted_unlabeled_by_score = sorted(
            unlabeled_scores_map.items(), key=lambda item: item[1], reverse=True
        )
        # Select top N_ACQUIRE_PER_ITER samples
        num_to_acquire = min(N_ACQUIRE_PER_ITER, len(sorted_unlabeled_by_score))
        acquired_indices_scores = sorted_unlabeled_by_score[:num_to_acquire]
        acquired_indices = [idx for idx, score in acquired_indices_scores]
        acquired_scores = [score for idx, score in acquired_indices_scores]
        if not acquired_indices:
            print(
                "Could not acquire any new samples (perhaps scores were uniform or list was empty). Stopping."
            )
            break
        print(f"Acquired {len(acquired_indices)} new samples ({np.mean(acquired_scores):.2f} avg score).")
        # Log acquisition information
        wandb.log({
            "acquisition_step": al_iteration,
            "acquired_indices": acquired_indices,
            "acquired_scores": acquired_scores,
            "avg_acquired_score": np.mean(acquired_scores) if acquired_scores else 0,
        }, commit=False)
        # Add to labeled set and remove from unlabeled pool
        for idx in acquired_indices:
            current_labeled_indices_set.add(idx)
            current_unlabeled_indices.remove(
                idx
            )  # Keep this list of original indices up to date
    else:
        print("Final iteration reached. No more acquisitions.")
    wandb.log({}, step=al_iteration, commit=True)
# Finish wandb run
wandb.finish()
📊 View Experiment Results on W&B
Animation Code
#!/usr/bin/env python3
"""
Create an animation showing how BALD scores evolve during active learning.
This script retrieves data from wandb and creates a similar animation to the one in visualizations.py.
Uses BALD scores without sorting to show the actual distribution of uncertainty.
"""
#%%
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from pathlib import Path
import json
import pandas as pd
from tqdm.auto import tqdm
from wandb.apis.public import Api as WandbApi
# Initialize wandb API
api = WandbApi()
# Find the latest run in the project
runs = api.runs("blog-active-learning-vs-filtering", order="-created_at")
latest_run = runs[0] # Most recent run
print(f"Analyzing run: {latest_run.name} ({latest_run.id})")
#%%
# Get data from the run - we need to extract BALD scores and acquisition info
scores_by_iter = {}
acquired_indices_by_iter = {}
train_subset_indices = {} # To store which indices were already labeled
# Fetch the history data
print("Fetching history data...")
scores_key = "train_bald_scores"
history = latest_run.history(keys=["iteration", scores_key], pandas=False)
#%%
# Process the history data
for row in tqdm(history):
    if "iteration" not in row:
        continue
    iteration = row["iteration"]
    # Get BALD scores if available in this row
    if scores_key in row:
        try:
            # Try to get scores from the run table
            table_data = row[scores_key]
            # Load the artifact
            table_file = latest_run.file(table_data['path']).download(replace=True).read()
            table_data = json.loads(table_file)
            df = pd.DataFrame(table_data["data"], columns=table_data["columns"])
            scores_by_iter[iteration] = df["bald_score"].values
        except Exception as e:
            print(f"Error extracting {scores_key} for iteration {iteration}: {e}")
    # Get acquired indices if available
    if "acquired_indices" in row:
        acquired_indices_by_iter[iteration] = row["acquired_indices"]
    # Get currently labeled indices
    if "current_labeled_indices" in row:
        train_subset_indices[iteration] = row["current_labeled_indices"]
#%%
# If we still don't have data, exit
if not scores_by_iter:
    print(f"Failed to retrieve {scores_key} from the run. Exiting.")
    raise ValueError("Failed to retrieve BALD scores from the run.")
#%%
# Sort iterations and prepare animation data
sorted_iterations = sorted(scores_by_iter.keys())
all_scores = [scores_by_iter[i].copy() for i in sorted_iterations]
acquired_indices = []
for scores in all_scores:
    for i in acquired_indices:
        scores[i] = 0.
    # Get the argmax score
    max_idx = np.argmax(scores)
    acquired_indices.append(max_idx.item())
# Pop last acquired index
acquired_indices.pop()
assert np.all(all_scores[-1][acquired_indices] == 0.)
print(f"Acquired indices: {acquired_indices}")
print(f"Found {scores_key} for {len(all_scores)} iterations")
print(f"First iteration has {len(all_scores[0])} samples")
#%%
# Create the animation with XKCD style but using EMA of the scores
ema_decay = 0.9
all_scores_ema = []
current_scores = all_scores[0].copy()
for i in range(len(all_scores)):
    current_scores = ema_decay * current_scores + (1.0 - ema_decay) * all_scores[i]
    for j in range(i):
        current_scores[acquired_indices[j]] = 0.
    all_scores_ema.append(current_scores)
#%%
with plt.xkcd():
    fig, ax = plt.subplots(figsize=(10, 6))
    # Set up initial plot - we'll use scatter plot for non-sorted scores
    max_points = min(10000, max([len(scores) for scores in all_scores_ema]))
    scatter = ax.scatter([], [], color="black", s=2)
    # For highlighted points we'll use a different color
    highlight_scatter = ax.scatter([], [], color="blue", s=100, alpha=0.5)
    # Title and labels
    ax.set_title("BALD Score Evolution (Non-sorted w/ EMA)")
    ax.set_xlabel("Sample Index")
    ax.set_ylabel("Informativeness (BALD score)")
    ax.set_xlim(0, max_points)
    # Find max BALD score for consistent y-axis scaling
    max_score = max([max(scores) for scores in all_scores_ema])
    ax.set_ylim(0, max_score * 1.1)
    # Text annotations
    iter_text = ax.text(0.02, 0.95, '', transform=ax.transAxes, fontsize=12)
    labeled_text = ax.text(0.02, 0.90, '', transform=ax.transAxes, fontsize=10)
    # Get the sorted indices of the first iteration
    sorted_indices = np.argsort(all_scores[0])[::-1]
    inv_sorted_indices = np.argsort(sorted_indices)

    def init():
        # Use the raw scores (no sorting)
        current_scores = np.array(all_scores_ema[0])
        x_data = np.arange(len(current_scores))
        scatter.set_offsets(np.c_[x_data, current_scores])
        # Highlight top uncertain points
        highlight_scatter.set_offsets(np.c_[[], []])
        # Set iteration text
        iter_text.set_text(f"Iteration: {sorted_iterations[0]}")
        return scatter, highlight_scatter, iter_text, labeled_text

    def animate(frame_idx):
        if frame_idx < len(sorted_iterations):
            iteration = sorted_iterations[frame_idx]
            # Compute an EMA of the past scores
            current_scores = np.array(all_scores_ema[frame_idx])
            # Use raw scores without sorting
            x_data = inv_sorted_indices[np.arange(len(current_scores))]
            # Update main scatter plot
            scatter.set_offsets(np.c_[x_data, current_scores.copy()])
            # Highlight top uncertain points (these will change each iteration)
            # top_indices = acquired_indices[:frame_idx]
            # highlight_scatter.set_offsets(np.c_[inv_sorted_indices[top_indices], current_scores[top_indices].copy()])
            # assert np.all(current_scores[top_indices[:1]] == 0.), current_scores[top_indices]
            # current_scores[top_indices] = 0.
            # Update text
            iter_text.set_text(f"Iteration: {iteration}")
        return scatter, highlight_scatter, iter_text, labeled_text

    # Create animation
    num_frames = len(sorted_iterations)
    anim = FuncAnimation(fig, animate, init_func=init,
                         frames=num_frames, interval=50, blit=True)
    # Add legend
    ax.legend()
    # Save animation
    writer = PillowWriter(fps=1000/50)
    output_path = Path("bald_scores_ema_animation.gif")
    print(f"Saving animation to {output_path}")
    anim.save(output_path, writer=writer)
print("Done!")
# %%