Active Learning vs. Data Filtering: Selection vs. Rejection


What is the difference between active learning (and active sampling) and data filtering? And why do we treat data selection differently during training versus before training?

This post explores the fundamental distinction between active dataset selection and data filtering, which we will phrase as “selection vs. rejection” by appealing to approximate submodularity. Understanding this nuance lets us relate the two approaches to each other more clearly.

Note that this post is not concerned with the labeling aspects of active learning. Active learning selects the most informative samples to label, and active sampling selects the most informative samples to train on when all labels are already known; here we only care about how both differ from data filtering. Data filtering is usually performed on already labeled data to speed up training, but it could equally be used to filter out, early on, data we are certain we do not want to label, similar to active learning. Thus, this post is really about the distinction between active selection and data filtering.

More generally, the distinction between active selection and data filtering is also relevant for unsupervised learning, such as pre-training, where there are no labels in the first place.

We will see that the key difference stems from two aspects:

  1. when we make selection decisions (operational context); and
  2. how we use informativeness scores.

Figure 1: Active learning vs. data filtering: A conceptual comparison of how these approaches evaluate and use sample informativeness.

The Core Distinction: Timing and Direction

The operational context, that is when we make selection decisions, fundamentally shapes how these methods work:

  • Active Learning/Sampling makes online/on-policy selections during training; while
  • Data Filtering makes offline/off-policy selections before training begins.

This timing difference naturally leads to different ways of using informativeness scores:

  • Active Learning/Sampling selects the most informative samples; while
  • Data Filtering rejects the least informative samples.

While selecting or removing might seem mathematically equivalent (accepting 10% is the same as rejecting 90%), the operational context leads to different practical approaches and, more importantly, different approximations.

But why should we phrase this as “selection vs. rejection”?

In short, the answer lies in how informativeness of samples behaves as training data accumulate: namely, it is (approximately) submodular.

Informativeness and Submodularity

To see why selection and removal differ, we first need to understand the properties of “informativeness.” Data subset selection approaches fundamentally rely on approximating the answer to the question: how much will training a model on a specific sample benefit its performance?

Crucially, informativeness often exhibits submodularity: the marginal gain of adding a new data point decreases as more data points are included. Each new data point we add to our training set provides less additional information than the previous one. This happens because new samples often contain information that overlaps with what we’ve already learned from previous samples. As the training set grows, new samples offer diminishing returns in terms of informativeness.

Submodularity is a well-understood property (see e.g. Nemhauser et al., 1978), and the expected information gain, which provides a natural measure of informativeness, is submodular or at least approximately submodular in many cases (Das and Kempe, 2018). (In the case of the parameter-based expected information gain, it is always submodular.)
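In symbols, this diminishing-returns condition is the textbook definition of submodularity. A quick sketch in my own notation (not taken from the referenced papers), where f(A) could be, for example, the expected information gain of a training set A:

f(A \cup \{x\}) - f(A) \;\ge\; f(B \cup \{x\}) - f(B) \qquad \text{for all } A \subseteq B \text{ and } x \notin B

That is, a new sample x helps at least as much when added to the smaller set A as when added to the larger set B.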

But why does this matter for active learning versus filtering? Let’s break it down:

  1. For Active Learning: Submodularity tells us we need to constantly re-evaluate which samples are most informative. What was highly informative previously might not be informative anymore, especially after we’ve trained on similar samples.

  2. For Data Filtering: Submodularity gives us confidence that samples deemed uninformative at the start will likely remain so throughout training. This makes early filtering a safe operation.

Active Learning: Iterative Selection

Active learning methods operate iteratively, computing informativeness scores online, typically at each training iteration. At each step, the samples that currently appear most informative for the model’s learning process are selected for training.

Some active learning strategies, like BatchBALD, take this a step further. They explicitly use submodularity to efficiently select batches of samples. But whether explicit or implicit, the key insight remains: we need to continuously reassess sample importance as our model learns and evolves.
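To make the loop concrete, here is a minimal sketch of online top-k selection. It is not the experiment code from later in this post: score_fn and train_step are hypothetical placeholders standing in for an acquisition function (e.g. BALD) and a model update.

import numpy as np

def active_selection_loop(pool_x, score_fn, train_step, n_rounds, k=1):
    """Greedy top-k active selection: re-score the remaining pool every round."""
    selected = []
    remaining = list(range(len(pool_x)))
    for _ in range(n_rounds):
        # Re-evaluate informativeness of the *remaining* pool with the current model.
        scores = np.array([score_fn(pool_x[i]) for i in remaining])
        # Pick the k currently most informative samples.
        top = np.argsort(scores)[-k:]
        chosen = [remaining[i] for i in top]
        selected.extend(chosen)
        remaining = [i for i in remaining if i not in chosen]
        # Update the model; scores must be recomputed afterwards (submodularity).
        train_step([pool_x[i] for i in chosen])
    return selected

The crucial detail is that the scores are recomputed after every model update, which is exactly the re-evaluation that submodularity demands.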

Data Filtering: Early Rejection

What makes data filtering fundamentally different? It’s all about timing.

Data filtering takes a different approach by evaluating informativeness offline, before training begins. The goal is to identify and remove less informative samples to streamline the training process. But how can we be confident in making such early decisions?

Again, submodularity provides the answer. With (approximate) submodularity, we can be reasonably confident that samples deemed uninformative at the start won’t suddenly become highly informative later. This property gives us the theoretical backing to make these early filtering decisions.
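For contrast, here is a correspondingly minimal sketch of offline filtering under the same assumptions (a hypothetical score_fn, evaluated once before any training):

import numpy as np

def filter_dataset(pool_x, score_fn, keep_fraction=0.9):
    """One-shot rejection: drop the least informative samples before training starts."""
    scores = np.array([score_fn(x) for x in pool_x])  # computed once, offline
    threshold = np.quantile(scores, 1.0 - keep_fraction)
    keep_indices = np.where(scores >= threshold)[0]
    # Train only on pool_x[keep_indices]; rejected samples are never revisited.
    return keep_indices

Because of (approximate) submodularity, the rejected tail is unlikely to become valuable later, which is what justifies computing the scores only once.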

Figure 2: Visualization of how sample informativeness evolves during active learning.

The Bayesian Perspective

Both active learning and data filtering can be understood through the lens of information theory, which provides a framework to quantify informativeness. Active learning methods like BALD (Houlsby et al., 2011) select samples that maximize the expected information gain, that is, the mutual information between model parameters and predictions. Data filtering, meanwhile, can be viewed as removing samples that would contribute minimally to the posterior update.
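As a reference point, here is a small sketch of how the BALD score can be estimated from Monte Carlo dropout predictions. It mirrors the estimator used in the experiment code further below, but the array shapes and function name here are my own illustrative choices.

import numpy as np

def bald_scores(probs, eps=1e-12):
    """BALD = H(E_w[p(y|x,w)]) - E_w[H(p(y|x,w))], estimated from MC samples.

    probs: array of shape [N, K, C] with softmax outputs from K stochastic forward passes.
    """
    mean_probs = probs.mean(axis=1)                                      # [N, C]
    entropy_of_mean = -(mean_probs * np.log(mean_probs + eps)).sum(-1)   # H(E[p])
    mean_of_entropy = -(probs * np.log(probs + eps)).sum(-1).mean(-1)    # E[H(p)]
    return entropy_of_mean - mean_of_entropy                             # mutual information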

More generally, many non-Bayesian methods can be interpreted as approximations of these Bayesian approaches (Kirsch et al., 2022). It would be interesting to explore this connection further.

A Practical Example: MNIST Active Learning

Let’s make this concrete with a real example. We conducted an active learning experiment on MNIST using a LeNet-5 model with Monte Carlo dropout. The experiment iteratively selects the most informative samples based on BALD scores and trains the model on the acquired samples.

The animation in Figure 3 shows the evolution of BALD scores on the training set, smoothed with an Exponential Moving Average (EMA) over iterations and plotted in a fixed sample order (sorted by initial informativeness). Because the EMA smooths out noise in the informativeness estimates, this animation is only illustrative of the true underlying dynamics; it is, however, a reasonable practical choice for visualizing trends. For a more accurate visualization, one would want better estimators of the EIG scores or many more samples.
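For completeness, the EMA is just the standard recursive smoother, sketched below; the actual animation script further down uses a decay of 0.9.

def ema_smooth(score_sequences, decay=0.9):
    """Exponentially smooth a list of per-iteration score arrays."""
    smoothed = []
    current = score_sequences[0].copy()
    for scores in score_sequences:
        current = decay * current + (1.0 - decay) * scores
        smoothed.append(current.copy())
    return smoothed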

We see that the informativeness of samples decreases over time, as expected. We also see that the order of samples changes a lot as the previously most informative samples are added to the training set.

Figure 3: Evolution of BALD scores over time, smoothed using Exponential Moving Average (EMA).

Conclusion

We have examined a few key insights about the relationship between active learning and data filtering:

  1. The timing of selection decisions fundamentally shapes how these methods work;
  2. Submodularity provides theoretical backing for both approaches; and
  3. Real-world applications often require balancing theoretical ideals with practical constraints.

While submodularity is often assumed, cases where sample informativeness increases rather than decreases under certain conditions still seem underexplored for filtering, at least in a principled fashion as far as I know.

There’s still much to explore about cases where sample informativeness behaves in unexpected ways:

  • Understanding when and why informativeness patterns deviate from theory;
  • Developing more robust methods that account for these deviations; and
  • Creating hybrid methods that combine selection and rejection.

Conceptual Illustrations

The figures above illustrate the key concepts we’ve discussed. Let me explain how they were created:

Figure 1 shows a conceptual visualization of the informativeness curve, where samples are sorted by their informativeness (y-axis) along the x-axis. This curve demonstrates the diminishing returns property characteristic of submodularity. The visualization highlights:

  • The data filtering region (red shaded area) shows samples that would be filtered out because their informativeness falls below a threshold.
  • The active learning region (blue shaded area) represents samples that would be prioritized by active learning approaches because of their high informativeness.

Figure 2 is an animation that simulates how active learning works over time:

  1. We start with a distribution of samples with varying informativeness levels.
  2. At each step, the most informative sample (highest point) is selected for training.
  3. After selection, that sample’s informativeness drops to zero (it’s been “used”).
  4. The informativeness of other samples also decreases (with some random noise), reflecting how the value of remaining samples changes as the model learns.

This animation captures the dynamic, iterative nature of active learning, where the most informative sample is constantly changing as training progresses. It visually demonstrates why we need to re-evaluate informativeness at each step rather than making all selection decisions upfront.

Both visualizations were created using matplotlib with an XKCD-style aesthetic to make the concepts more approachable and intuitive.

Illustration Code
"""Generate an XKCD-style illustration of submodularity and how active learning vs. data filtering interact with the informativeness curve. The script produces both static PNG/SVG files and an animated GIF showing how informativeness changes over time as data gets trained on. Usage: python scripts/submodularity_curve.py # default path python scripts/submodularity_curve.py --out /tmp/figure.png """ # %% from __future__ import annotations import argparse from pathlib import Path import matplotlib.pyplot as plt import numpy as np from matplotlib.animation import FuncAnimation, PillowWriter # %% # Synthetic submodular-like curve: diminishing returns. x = np.linspace(1, 10000, 10000) y_base = 1 / (0.9 / 99.9 * x + (1 - 0.9 / 99.9)) y_filtering_threshold = 0.05 # filtering threshold: keep samples with y >= threshold. y_active_threshold = 0.2 filtering_color = "red" active_color = "blue" # First create the static plots with plt.xkcd(): fig, ax = plt.subplots(figsize=(8, 5)) # Plot the diminishing-returns curve. ax.plot(x, y_base, color="black") # Shaded region for samples that would be *filtered out* (y below threshold). ax.fill_between(x, 0, y_base, where=y_base < y_filtering_threshold, color=filtering_color, alpha=0.1, label="Data filtering region (offline)") # Horizontal dashed line showing the filtering threshold. ax.axhline(y_filtering_threshold, color=filtering_color, linestyle="--", linewidth=1) # Vertical dashed line showing the active learning region. ax.axvline(x[np.argmax(y_base <= y_active_threshold)], color=active_color, linestyle="--", linewidth=1) # Shaded vertical region for samples that would be considered for active learning. ax.fill_betweenx(y_base, 0, x, where=y_base > y_active_threshold, color=active_color, alpha=0.1, label="Active learning region (online)") # Styling. ax.set_xlabel("Samples sorted by informativeness") ax.set_ylabel("Informativeness") ax.set_title("Active Selection vs. Data Filtering") ax.set_ylim(0, 1.05) ax.set_xlim(0, x.max()) ax.legend() ax.grid(False) out_path = Path("active_vs_filtering_xkcd.png") out_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(out_path, dpi=150, bbox_inches="tight") out_path = Path("active_vs_filtering_xkcd.svg") out_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(out_path, dpi=150, bbox_inches="tight") plt.show() #%% # Now create the animated version n_frames = 40 top_k = 1 current_y = y_base.copy() with plt.xkcd(): fig_anim, ax_anim = plt.subplots(figsize=(8, 5)) scatter = ax_anim.scatter([], [], color="black", s=1) selected_scatter = ax_anim.scatter([], [], color=active_color, s=100, alpha=0.1) # Horizontal dashed line showing the filtering threshold. 
ax_anim.axhline(y_filtering_threshold, color=filtering_color, linestyle="--", linewidth=1) # Styling ax_anim.set_xlabel("Samples sorted by informativeness") ax_anim.set_ylabel("Informativeness") ax_anim.set_title("Simulation of Top-1 Active Selection") ax_anim.set_ylim(0, 1.05) ax_anim.set_xlim(0, 1000) ax_anim.grid(False) def init(): scatter.set_offsets(np.c_[x, y_base]) return (scatter, selected_scatter) def animate(frame): print(frame) global current_y # Add noise to the decay if frame > 1: top_k_indices = np.argsort(current_y)[-top_k:] current_y[top_k_indices] = 0.0 # Decay the other points noise = np.clip(np.random.gumbel(0, 0.1, len(current_y)), 0, 1) current_y = current_y * (1 - noise) # Set the top 10 points to 0 top_k_indices = np.argsort(current_y)[-top_k:] # Append the top 10 points to the selected scatter selected_scatter.set_offsets(np.c_[x[top_k_indices].copy(), current_y[top_k_indices].copy()]) scatter.set_offsets(np.c_[x, current_y]) return (scatter, selected_scatter) anim = FuncAnimation(fig_anim, animate, init_func=init, frames=n_frames, interval=200, blit=False, save_count=40) fig_anim.tight_layout() writer = PillowWriter(fps=5) out_path = Path("active_selection_animation.gif") out_path.parent.mkdir(parents=True, exist_ok=True) anim.save(out_path, writer=writer) # %%

MNIST Active Learning Experiment

In our MNIST active learning experiment, we used a LeNet-5 model with Monte Carlo dropout to iteratively select the most informative samples based on BALD scores. The experiment started with a small initial labeled set and, at each iteration, acquired the most informative unlabeled samples for training. The BALD scores were calculated using multiple Monte Carlo samples to estimate the mutual information between model parameters and predictions.

The animation in Figure 3 visualizes the evolution of BALD scores over iterations. It shows how the informativeness of samples changes as the model learns, with the most informative samples being selected and their scores dropping to zero after acquisition. This dynamic process highlights the importance of re-evaluating sample informativeness during training.

The experiment was implemented in mnist_al_experiment.py, which handles the model training, BALD score calculation, and sample acquisition. The animation was created using bald_animation.py, which retrieves data from Weights & Biases (wandb) and generates the visualization of BALD score evolution.

Experiment Code
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from tqdm import tqdm

import wandb

# --- Configuration ---
DEVICE = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# MNIST Hyperparameters
N_CLASSES = 10
IMG_WIDTH = 28
IMG_HEIGHT = 28
N_CHANNELS = 1


# LeNet Model Definition (LeNet-5 architecture with MC Dropout)
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv1_drop = nn.Dropout2d()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(1024, 128)
        self.fc1_drop = nn.Dropout()
        self.fc2 = nn.Linear(128, N_CLASSES)

    def forward(self, input: torch.Tensor):
        input = F.relu(F.max_pool2d(self.conv1_drop(self.conv1(input)), 2))
        input = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(input)), 2))
        input = input.view(-1, 1024)
        input = F.relu(self.fc1_drop(self.fc1(input)))
        input = self.fc2(input)
        return input


def bootstrap_sample(dataset, sample_size=None, replace=True, random_state=None):
    """
    Creates a bootstrap sample from the dataset.

    Args:
        dataset: PyTorch dataset to sample from
        sample_size: Size of the bootstrap sample (defaults to len(dataset))
        replace: Whether to sample with replacement (True for bootstrap)
        random_state: Random seed for reproducibility

    Returns:
        A PyTorch Subset containing the bootstrapped samples
    """
    if sample_size is None:
        sample_size = len(dataset)
    rng = np.random.default_rng(random_state)
    # Generate indices with replacement
    indices = rng.choice(len(dataset), size=sample_size, replace=replace)
    # Return a subset of the dataset with the selected indices
    return torch.utils.data.Subset(dataset, indices)


def create_bootstrap_loader(
    dataset,
    batch_size=64,
    sample_size=None,
    replace=True,
    random_state=None,
    num_workers=0,
    shuffle=True,
):
    """
    Creates a DataLoader with bootstrapped samples from the dataset.

    Args:
        dataset: PyTorch dataset to sample from
        batch_size: Batch size for the DataLoader
        sample_size: Size of the bootstrap sample (defaults to len(dataset))
        replace: Whether to sample with replacement (True for bootstrap)
        random_state: Random seed for reproducibility
        num_workers: Number of worker processes for data loading
        shuffle: Whether to shuffle the data

    Returns:
        A DataLoader containing bootstrapped samples
    """
    bootstrap_dataset = bootstrap_sample(dataset, sample_size, replace, random_state)
    return torch.utils.data.DataLoader(
        bootstrap_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


# EIG (BALD) Calculation
def get_scores(model, data_loader, n_mc_samples, criterion=None):
    """
    Calculates BALD scores for samples in data_loader.

    BALD = H(E[p(y|x,w)]) - E[H(p(y|x,w))]
    """
    model.train()
    all_probs_mc = []  # To store [n_samples_dataset, n_mc_samples, n_classes]
    all_labels = []
    with torch.inference_mode():
        for data, labels in tqdm(data_loader, desc="Calculating BALD Scores"):
            data = data.to(DEVICE)
            batch_probs_mc = []  # Store MC samples for this batch [n_mc_samples, batch_size, n_classes]
            for _ in range(n_mc_samples):
                output = model(data)
                probs = F.softmax(output, dim=1)
                batch_probs_mc.append(probs.unsqueeze(0))
            batch_probs_mc = torch.cat(
                batch_probs_mc, dim=0
            )  # [n_mc_samples, batch_size, n_classes]
            all_probs_mc.append(
                batch_probs_mc.permute(1, 0, 2).cpu()
            )  # [batch_size, n_mc_samples, n_classes]
            all_labels.append(labels)  # [batch_size]

    all_probs_mc_tensor = torch.cat(
        all_probs_mc, dim=0
    )  # [total_samples, n_mc_samples, n_classes]
    all_labels = torch.cat(all_labels, dim=0)  # [total_samples]

    # Entropy of mean predictions: H(E[p(y|x,w)])
    mean_probs = torch.mean(all_probs_mc_tensor, dim=1)  # [total_samples, n_classes]
    ic_probs = mean_probs * torch.log(mean_probs)
    ic_probs[torch.isnan(ic_probs)] = 0.0
    entropy_of_mean = -torch.sum(ic_probs, dim=1)  # [total_samples]

    # Mean of entropy of predictions: E[H(p(y|x,w))]
    ic_probs_mc = all_probs_mc_tensor * torch.log(all_probs_mc_tensor)
    ic_probs_mc[torch.isnan(ic_probs_mc)] = 0.0
    entropy_per_mc_sample = -torch.sum(
        ic_probs_mc, dim=2
    )  # [total_samples, n_mc_samples]
    mean_of_entropy = torch.mean(entropy_per_mc_sample, dim=1)  # [total_samples]

    bald_scores = entropy_of_mean - mean_of_entropy

    acc = (mean_probs.argmax(dim=1) == all_labels).float().mean() * 100.0
    if criterion is not None:
        loss = criterion(mean_probs.log(), all_labels)
    else:
        loss = -1.0

    return bald_scores.numpy(), acc, loss


# Training function
def train_model(model, train_loader, optimizer, criterion, epochs):
    model.train()  # Set to train mode (enables dropout, batchnorm updates etc.)
    pbar = tqdm(range(epochs), desc="Training Epochs")
    for epoch in pbar:
        epoch_loss = 0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(DEVICE), target.to(DEVICE)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
        pbar.set_postfix(loss=epoch_loss / len(train_loader), acc=100. * correct / total)


def plot_scores(scores):
    # Sort scores descending
    sorted_scores = -np.sort(-scores)
    # Plot the sorted scores
    plt.plot(sorted_scores)
    plt.show()


# --- Data Loading and Preparation ---
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Load full training and test datasets
full_train_dataset = datasets.MNIST(
    "./data", train=True, download=True, transform=transform
)

# Subset the full training dataset to 10000 random samples
# Use numpy's default_rng with fixed seed for reproducibility
rng = np.random.default_rng(1)
indices = rng.permutation(len(full_train_dataset))[:10000]
full_train_dataset = torch.utils.data.Subset(full_train_dataset, indices)

all_train_loader_for_scores = DataLoader(
    full_train_dataset, batch_size=1024, shuffle=False
)

test_dataset = datasets.MNIST("./data", train=False, transform=transform)
test_loader = DataLoader(
    test_dataset, batch_size=1024, shuffle=False
)  # For final eval & test scores

# --- Active Learning Setup ---
N_INITIAL_LABELED = 60  # Number of initial labeled samples
N_ACQUIRE_PER_ITER = 1  # Number of samples to acquire each iteration
N_ACTIVE_LEARNING_ITERATIONS = 100
N_MC_SAMPLES_EIG = 32  # Number of MC samples for EIG
TRAIN_EPOCHS_PER_ITER = 3  # Epochs to train model at each AL iteration
LEARNING_RATE = 0.0005
BATCH_SIZE_TRAIN = 64  # Batch size for training
N_MIN_SAMPLES_PER_EPOCH = 10_000

# Initialize wandb
wandb.init(
    project="blog-active-learning-vs-filtering",
    config={
        "initial_labeled_samples": N_INITIAL_LABELED,
        "acquire_per_iter": N_ACQUIRE_PER_ITER,
        "al_iterations": N_ACTIVE_LEARNING_ITERATIONS,
        "mc_samples_eig": N_MC_SAMPLES_EIG,
        "epochs_per_iter": TRAIN_EPOCHS_PER_ITER,
        "learning_rate": LEARNING_RATE,
        "batch_size": BATCH_SIZE_TRAIN,
        "model": "LeNet-5 with MC Dropout",
        "dataset": "MNIST",
        "device": str(DEVICE),
    }
)

# Create initial labeled and unlabeled pools
num_train_samples = len(full_train_dataset)
all_indices = np.arange(num_train_samples)
np.random.shuffle(all_indices)

initial_labeled_indices = list(all_indices[:N_INITIAL_LABELED])
current_unlabeled_indices = list(all_indices[N_INITIAL_LABELED:])

# Store all acquisition scores
# Format: list of dicts, each dict is one AL iteration
# {'iteration': i, 'labeled_indices': [...], 'unlabeled_indices': [...],
#  'scores_for_all_train_samples': np.array([...]), 'scores_for_test_samples': np.array([...])}
all_scores_history = []

# --- Active Learning Loop ---
print("Starting Active Learning Loop...")
current_labeled_indices_set = set(initial_labeled_indices)

for al_iteration in tqdm(
    range(N_ACTIVE_LEARNING_ITERATIONS + 1), desc="Active Learning Iterations"
):  # +1 for initial state and after last acquisition
    print(f"\n--- Active Learning Iteration: {al_iteration} ---")
    print(f"Currently labeled samples: {len(current_labeled_indices_set)}")

    # 1. Create current labeled dataset and loader
    labeled_subset = Subset(full_train_dataset, list(current_labeled_indices_set))
    labeled_loader = create_bootstrap_loader(
        labeled_subset,
        batch_size=BATCH_SIZE_TRAIN,
        shuffle=True,
        random_state=RANDOM_SEED,
        sample_size=max(N_MIN_SAMPLES_PER_EPOCH, len(labeled_subset)),
        replace=N_MIN_SAMPLES_PER_EPOCH > len(labeled_subset),
    )

    # 2. Initialize or re-initialize model and optimizer
    model = LeNet().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()

    # 3. Train model on current labeled data (unless it's iteration 0 before any acquisition)
    if N_INITIAL_LABELED > 0:  # Train if there's data
        print("Training model...")
        train_model(
            model, labeled_loader, optimizer, criterion, epochs=TRAIN_EPOCHS_PER_ITER
        )

    # 4. Calculate EIG scores for all original training samples and test samples
    print("Calculating EIG scores for test samples...")
    scores_test_samples, test_acc, test_loss = get_scores(model, test_loader, N_MC_SAMPLES_EIG, criterion)
    print(f"Model trained. Test Acc: {test_acc:.2f}%, Test Loss: {test_loss:.4f}")

    print("Calculating EIG scores for all original training samples...")
    # Create a loader for ALL original training samples to get their scores
    scores_all_train_samples, _, _ = get_scores(
        model, all_train_loader_for_scores, N_MC_SAMPLES_EIG
    )

    # Plot the scores
    plt.figure(figsize=(10, 5))
    plt.title("BALD Scores for all original training samples")
    plt.ylabel("BALD Score")
    plot_scores(scores_all_train_samples)
    plt.show()

    plt.figure(figsize=(10, 5))
    plt.title("BALD Scores for all test samples")
    plt.ylabel("BALD Score")
    plot_scores(scores_test_samples)
    plt.show()

    # Store scores for this iteration
    iteration_scores_data = {
        "iteration": al_iteration,
        "current_labeled_indices": sorted(list(current_labeled_indices_set)),
        "current_unlabeled_indices": sorted(
            current_unlabeled_indices
        ),  # These are indices FROM the original full_train_dataset
        "scores_for_all_train_samples": scores_all_train_samples,  # Order matches full_train_dataset
        "scores_for_test_samples": scores_test_samples,  # Order matches test_dataset
    }
    all_scores_history.append(iteration_scores_data)
    print(f"Scores stored for iteration {al_iteration}.")

    # Log metrics to wandb
    wandb_log_data = {
        "iteration": al_iteration,
        "num_labeled_samples": len(current_labeled_indices_set),
        "test_accuracy": test_acc,
        "test_loss": test_loss,
        "avg_train_bald_score": np.mean(scores_all_train_samples),
        "avg_test_bald_score": np.mean(scores_test_samples),
        "max_train_bald_score": np.max(scores_all_train_samples),
        "max_test_bald_score": np.max(scores_test_samples),
    }
    wandb.log(wandb_log_data, commit=False)

    # Log histograms of BALD scores (optional)
    # Create a table with all the scores
    train_scores_table = wandb.Table(columns=["index", "bald_score"])
    for idx, score in enumerate(scores_all_train_samples):
        train_scores_table.add_data(idx, float(score))

    test_scores_table = wandb.Table(columns=["index", "bald_score"])
    for idx, score in enumerate(scores_test_samples):
        test_scores_table.add_data(idx, float(score))

    # Create plotly figures for BALD scores
    # Create sorted scores for better visualization
    sorted_train_scores = sorted(scores_all_train_samples, reverse=True)
    sorted_test_scores = sorted(scores_test_samples, reverse=True)

    # Create dataframes for plotly
    train_df = pd.DataFrame({
        "index": range(len(sorted_train_scores)),
        "bald_score": sorted_train_scores
    })
    test_df = pd.DataFrame({
        "index": range(len(sorted_test_scores)),
        "bald_score": sorted_test_scores
    })

    # Create plotly figures
    train_fig = px.line(train_df, x="index", y="bald_score",
                        title=f"Sorted BALD Scores for Training Samples (Iteration {al_iteration})")
    train_fig.update_layout(xaxis_title="Sample Index (sorted)", yaxis_title="BALD Score")

    test_fig = px.line(test_df, x="index", y="bald_score",
                       title=f"Sorted BALD Scores for Test Samples (Iteration {al_iteration})")
    test_fig.update_layout(xaxis_title="Sample Index (sorted)", yaxis_title="BALD Score")

    # Log plotly figures to wandb
    wandb.log({
        "train_bald_scores_plot": wandb.Plotly(train_fig),
        "test_bald_scores_plot": wandb.Plotly(test_fig),
        "train_bald_scores": train_scores_table,
        "test_bald_scores": test_scores_table,
    }, commit=False)

    # 5. If it's not the last iteration, perform acquisition
    if al_iteration < N_ACTIVE_LEARNING_ITERATIONS:
        if not current_unlabeled_indices:
            print("No more unlabeled samples to acquire. Stopping.")
            break

        print("Acquiring new samples...")
        # We need scores only for the *currently unlabeled* samples to decide which to pick.
        # `scores_all_train_samples` contains scores for *all* original training samples.
        # We filter these down to only the unlabeled ones.
        unlabeled_scores_map = {
            idx: scores_all_train_samples[idx] for idx in current_unlabeled_indices
        }

        # Sort unlabeled samples by their EIG score (descending)
        sorted_unlabeled_by_score = sorted(
            unlabeled_scores_map.items(), key=lambda item: item[1], reverse=True
        )

        # Select top N_ACQUIRE_PER_ITER samples
        num_to_acquire = min(N_ACQUIRE_PER_ITER, len(sorted_unlabeled_by_score))
        acquired_indices_scores = sorted_unlabeled_by_score[:num_to_acquire]
        acquired_indices = [idx for idx, score in acquired_indices_scores]
        acquired_scores = [score for idx, score in acquired_indices_scores]

        if not acquired_indices:
            print(
                "Could not acquire any new samples (perhaps scores were uniform or list was empty). Stopping."
            )
            break

        print(f"Acquired {len(acquired_indices)} new samples ({np.mean(acquired_scores):.2f} avg score).")

        # Log acquisition information
        wandb.log({
            "acquisition_step": al_iteration,
            "acquired_indices": acquired_indices,
            "acquired_scores": acquired_scores,
            "avg_acquired_score": np.mean(acquired_scores) if acquired_scores else 0,
        }, commit=False)

        # Add to labeled set and remove from unlabeled pool
        for idx in acquired_indices:
            current_labeled_indices_set.add(idx)
            current_unlabeled_indices.remove(
                idx
            )  # Keep this list of original indices up to date
    else:
        print("Final iteration reached. No more acquisitions.")

    wandb.log({}, step=al_iteration, commit=True)

# Finish wandb run
wandb.finish()

📊 View Experiment Results on W&B

Animation Code
#!/usr/bin/env python3
"""
Create an animation showing how BALD scores evolve during active learning.
This script retrieves data from wandb and creates a similar animation to the one in
visualizations.py. Uses BALD scores without sorting to show the actual distribution
of uncertainty.
"""
#%%
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from pathlib import Path
import json
import pandas as pd
from tqdm.auto import tqdm
from wandb.apis.public import Api as WandbApi

# Initialize wandb API
api = WandbApi()

# Find the latest run in the project
runs = api.runs("blog-active-learning-vs-filtering", order="-created_at")
latest_run = runs[0]  # Most recent run
print(f"Analyzing run: {latest_run.name} ({latest_run.id})")

#%%
# Get data from the run - we need to extract BALD scores and acquisition info
scores_by_iter = {}
acquired_indices_by_iter = {}
train_subset_indices = {}  # To store which indices were already labeled

# Fetch the history data
print("Fetching history data...")
scores_key = "train_bald_scores"
history = latest_run.history(keys=["iteration", scores_key], pandas=False)

#%%
# Process the history data
for row in tqdm(history):
    if "iteration" not in row:
        continue
    iteration = row["iteration"]

    # Get BALD scores if available in this row
    if scores_key in row:
        try:
            # Try to get scores from the run table
            table_data = row[scores_key]
            # Load the artifact
            table_file = latest_run.file(table_data['path']).download(replace=True).read()
            table_data = json.loads(table_file)
            df = pd.DataFrame(table_data["data"], columns=table_data["columns"])
            scores_by_iter[iteration] = df["bald_score"].values
        except Exception as e:
            print(f"Error extracting {scores_key} for iteration {iteration}: {e}")

    # Get acquired indices if available
    if "acquired_indices" in row:
        acquired_indices_by_iter[iteration] = row["acquired_indices"]

    # Get currently labeled indices
    if "current_labeled_indices" in row:
        train_subset_indices[iteration] = row["current_labeled_indices"]

#%%
# If we still don't have data, exit
if not scores_by_iter:
    print(f"Failed to retrieve {scores_key} from the run. Exiting.")
    raise ValueError("Failed to retrieve BALD scores from the run.")

#%%
# Sort iterations and prepare animation data
sorted_iterations = sorted(scores_by_iter.keys())
all_scores = [scores_by_iter[i].copy() for i in sorted_iterations]

acquired_indices = []
for scores in all_scores:
    for i in acquired_indices:
        scores[i] = 0.
    # Get the argmax score
    max_idx = np.argmax(scores)
    acquired_indices.append(max_idx.item())
# Pop last acquired index
acquired_indices.pop()

assert np.all(all_scores[-1][acquired_indices] == 0.)

print(f"Acquired indices: {acquired_indices}")
print(f"Found {scores_key} for {len(all_scores)} iterations")
print(f"First iteration has {len(all_scores[0])} samples")

#%%
# Create the animation with XKCD style but using EMA of the scores
ema_decay = 0.9
all_scores_ema = []
current_scores = all_scores[0].copy()
for i in range(len(all_scores)):
    current_scores = ema_decay * current_scores + (1.0 - ema_decay) * all_scores[i]
    for j in range(i):
        current_scores[acquired_indices[j]] = 0.
    all_scores_ema.append(current_scores)

#%%
with plt.xkcd():
    fig, ax = plt.subplots(figsize=(10, 6))

    # Set up initial plot - we'll use scatter plot for non-sorted scores
    max_points = min(10000, max([len(scores) for scores in all_scores_ema]))
    scatter = ax.scatter([], [], color="black", s=2)

    # For highlighted points we'll use a different color
    highlight_scatter = ax.scatter([], [], color="blue", s=100, alpha=0.5)

    # Title and labels
    ax.set_title("BALD Score Evolution (Non-sorted w/ EMA)")
    ax.set_xlabel("Sample Index")
    ax.set_ylabel("Informativeness (BALD score)")
    ax.set_xlim(0, max_points)

    # Find max BALD score for consistent y-axis scaling
    max_score = max([max(scores) for scores in all_scores_ema])
    ax.set_ylim(0, max_score * 1.1)

    # Text annotations
    iter_text = ax.text(0.02, 0.95, '', transform=ax.transAxes, fontsize=12)
    labeled_text = ax.text(0.02, 0.90, '', transform=ax.transAxes, fontsize=10)

    # Get the sorted indices of the first iteration
    sorted_indices = np.argsort(all_scores[0])[::-1]
    inv_sorted_indices = np.argsort(sorted_indices)

    def init():
        # Use the raw scores (no sorting)
        current_scores = np.array(all_scores_ema[0])
        x_data = np.arange(len(current_scores))
        scatter.set_offsets(np.c_[x_data, current_scores])
        # Highlight top uncertain points
        highlight_scatter.set_offsets(np.c_[[], []])
        # Set iteration text
        iter_text.set_text(f"Iteration: {sorted_iterations[0]}")
        return scatter, highlight_scatter, iter_text, labeled_text

    def animate(frame_idx):
        if frame_idx < len(sorted_iterations):
            iteration = sorted_iterations[frame_idx]
            # Compute an EMA of the past scores
            current_scores = np.array(all_scores_ema[frame_idx])
            # Use raw scores without sorting
            x_data = inv_sorted_indices[np.arange(len(current_scores))]
            # Update main scatter plot
            scatter.set_offsets(np.c_[x_data, current_scores.copy()])
            # Highlight top uncertain points (these will change each iteration)
            # top_indices = acquired_indices[:frame_idx]
            # highlight_scatter.set_offsets(np.c_[inv_sorted_indices[top_indices], current_scores[top_indices].copy()])
            # assert np.all(current_scores[top_indices[:1]] == 0.), current_scores[top_indices]
            # current_scores[top_indices] = 0.
            # Update text
            iter_text.set_text(f"Iteration: {iteration}")
        return scatter, highlight_scatter, iter_text, labeled_text

    # Create animation
    num_frames = len(sorted_iterations)
    anim = FuncAnimation(fig, animate, init_func=init, frames=num_frames,
                         interval=50, blit=True)

    # Add legend
    ax.legend()

    # Save animation
    writer = PillowWriter(fps=1000/50)
    output_path = Path("bald_scores_ema_animation.gif")
    print(f"Saving animation to {output_path}")
    anim.save(output_path, writer=writer)
    print("Done!")

# %%