A PyTorch-like deep learning framework built for educational purposes. The project implements a small neural network library with automatic differentiation, common layer types, optimizers, and a worked MNIST example.
- Tensor operations with automatic differentiation
- Neural network modules (Linear, ReLU, etc.)
- Optimizers (SGD, Adam)
- MNIST classification example with 97.16% accuracy
- Comprehensive evaluation metrics
- Clone the repository:
- Install uv (if not already installed):
- Create and activate a virtual environment with dependencies:
- Download MNIST Dataset: The MNIST dataset files are not included in the repository due to their size. You'll need to:
  a. Download the CSV files from Kaggle (MNIST in CSV format).
  b. Place the files in the `archive/` directory:
     - `archive/mnist_train.csv` (for training data)
     - `archive/mnist_test.csv` (for test data)
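The exact preprocessing lives in the training script, but as a rough sketch of how these files can be read, assuming the common Kaggle layout (a label column followed by 784 pixel columns, with a header row):

```python
import numpy as np

def load_mnist_csv(path):
    # First column is the digit label, the remaining 784 columns are pixel values.
    data = np.loadtxt(path, delimiter=",", skiprows=1)  # skiprows=1 assumes a header row
    labels = data[:, 0].astype(np.int64)
    images = (data[:, 1:] / 255.0).astype(np.float32)   # scale pixels to [0, 1]
    return images, labels

X_train, y_train = load_mnist_csv("archive/mnist_train.csv")
X_test, y_test = load_mnist_csv("archive/mnist_test.csv")
print(X_train.shape, y_train.shape)  # expected: (60000, 784) (60000,)
```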
Tensor Operations (`mytorch/__init__.py`)
- Basic arithmetic operations (+, -, *, /)
- Matrix operations (dot product, transpose)
- Automatic differentiation with backward pass
- Gradient computation and accumulation
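The real `Tensor` works on arrays, but the reverse-mode idea behind the backward pass fits in a short, self-contained scalar sketch (an illustration, not the project's code): each operation records its inputs and a local gradient rule, and `backward()` replays them in reverse topological order while accumulating gradients.

```python
class Value:
    # Minimal scalar autodiff node; the framework's Tensor generalizes this to arrays.
    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._backward_fn = lambda: None

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))
        def backward_fn():                      # d(a+b)/da = d(a+b)/db = 1
            self.grad += out.grad
            other.grad += out.grad
        out._backward_fn = backward_fn
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))
        def backward_fn():                      # d(ab)/da = b, d(ab)/db = a
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward_fn = backward_fn
        return out

    def backward(self):
        # Topologically sort the graph, then apply the chain rule in reverse.
        order, seen = [], set()
        def visit(node):
            if node not in seen:
                seen.add(node)
                for parent in node._parents:
                    visit(parent)
                order.append(node)
        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            node._backward_fn()

x, y = Value(2.0), Value(3.0)
z = x * y + x          # z = xy + x
z.backward()
print(x.grad, y.grad)  # 4.0 (= y + 1) and 2.0 (= x)
```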
Neural Network Modules (`mytorch/nn.py`)
- Linear layer with weights and biases
- Activation functions (ReLU)
- Loss functions (Cross Entropy)
- Forward and backward propagation
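The module names above are defined in `mytorch/nn.py`; the math they implement is standard, and the following self-contained NumPy sketch shows one forward and backward pass through a Linear layer, ReLU, and softmax cross-entropy (an illustration of what the modules compute, not the repo's actual code):

```python
import numpy as np

rng = np.random.default_rng(0)

# A batch of 4 samples, 784 inputs, 10 classes.
x = rng.normal(size=(4, 784)).astype(np.float32)
y = np.array([3, 0, 7, 1])                        # integer class labels

W = rng.normal(scale=0.01, size=(784, 10)).astype(np.float32)
b = np.zeros(10, dtype=np.float32)

# Forward pass
z = x @ W + b                                     # Linear
h = np.maximum(z, 0.0)                            # ReLU
shifted = h - h.max(axis=1, keepdims=True)        # numerically stable softmax
probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
loss = -np.log(probs[np.arange(len(y)), y]).mean()

# Backward pass (the chain rule the autograd engine automates)
dh = probs.copy()
dh[np.arange(len(y)), y] -= 1.0                   # d(loss)/d(h) for softmax + CE
dh /= len(y)
dz = dh * (z > 0)                                 # ReLU mask
dW = x.T @ dz                                     # gradient w.r.t. weights
db = dz.sum(axis=0)                               # gradient w.r.t. bias
print(float(loss), dW.shape, db.shape)
```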
Optimizers
- SGD (Stochastic Gradient Descent)
- Adam optimizer implementation
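The update rules these two optimizers implement are standard; as a reference, here is a hedged NumPy sketch of a single SGD step and a single Adam step (the classes in this repo will expose their own interface):

```python
import numpy as np

def sgd_step(param, grad, lr=0.01):
    # Vanilla SGD: move parameters against the gradient.
    param -= lr * grad

def adam_step(param, grad, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Adam: adaptive per-parameter steps from running first/second moment estimates.
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad ** 2
    m_hat = state["m"] / (1 - beta1 ** state["t"])   # bias correction
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    param -= lr * m_hat / (np.sqrt(v_hat) + eps)

w = np.ones(5)
g = np.full(5, 0.5)
state = {"t": 0, "m": np.zeros_like(w), "v": np.zeros_like(w)}
sgd_step(w, g)
adam_step(w, g, state)
print(w)
```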
The framework follows a modular design similar to PyTorch: the tensor and autograd engine lives in `mytorch/__init__.py`, while the network modules in `mytorch/nn.py` and the optimizers are thin layers built on top of it.
Run the training script:
The script will:
- Load and preprocess MNIST data
- Create a neural network model
- Train for the specified number of epochs
- Save the trained model weights
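The README does not show the script's exact entry point or class names, so the outline below uses assumed names (`Linear`, `ReLU`, `CrossEntropyLoss`, `SGD`, a `mytorch.optim` module, `parameters()`, `zero_grad()`) purely to sketch the loop described above; check the source for the real API before copying it:

```python
# Sketch only: every imported name below is an assumption about mytorch's API.
from mytorch import Tensor
from mytorch.nn import Linear, ReLU, CrossEntropyLoss
from mytorch.optim import SGD

model = [Linear(784, 128), ReLU(), Linear(128, 10)]       # hypothetical module list
loss_fn = CrossEntropyLoss()
params = [p for layer in model for p in layer.parameters()]
optimizer = SGD(params, lr=0.1)

for epoch in range(num_epochs):                           # num_epochs: training setting
    for x_batch, y_batch in train_batches:                # batches from archive/mnist_train.csv
        out = Tensor(x_batch)
        for layer in model:                               # forward pass, module by module
            out = layer(out)
        loss = loss_fn(out, y_batch)
        optimizer.zero_grad()                             # clear accumulated gradients
        loss.backward()                                   # backward pass through the graph
        optimizer.step()                                  # apply the parameter update
```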
Run the evaluation script:
- Accuracy: 97.16%
- Macro Precision: 97.23%
- Macro Recall: 97.14%
- Macro F1-score: 97.15%
| Digit | Precision | Recall | F1-score |
|-------|-----------|--------|----------|
| 0 | 98.28% | 98.98% | 98.63% |
| 1 | 97.84% | 99.56% | 98.69% |
| 2 | 98.87% | 93.60% | 96.17% |
| 3 | 90.64% | 98.81% | 94.55% |
| 4 | 97.95% | 97.45% | 97.70% |
| 5 | 96.26% | 98.09% | 97.17% |
| 6 | 98.11% | 97.49% | 97.80% |
| 7 | 97.86% | 97.67% | 97.76% |
| 8 | 98.28% | 93.63% | 95.90% |
| 9 | 98.18% | 96.13% | 97.15% |
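The macro figures above are the unweighted means of the per-class columns in this table. For reference (not necessarily the evaluation script's code), per-class precision, recall, and F1 can be computed from predictions like this:

```python
import numpy as np

def per_class_metrics(y_true, y_pred, num_classes=10):
    # Precision, recall and F1 for each class, plus their macro (unweighted) averages.
    precision, recall, f1 = [], [], []
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        p = tp / (tp + fp) if tp + fp else 0.0
        r = tp / (tp + fn) if tp + fn else 0.0
        precision.append(p)
        recall.append(r)
        f1.append(2 * p * r / (p + r) if p + r else 0.0)
    return precision, recall, f1, np.mean(precision), np.mean(recall), np.mean(f1)

# Tiny illustration with made-up predictions for 3 classes
y_true = np.array([0, 1, 2, 2, 1, 0])
y_pred = np.array([0, 1, 2, 1, 1, 0])
print(per_class_metrics(y_true, y_pred, num_classes=3))
```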
Strong Overall Performance
- 97.16% accuracy on test set
- Consistent performance across classes
- High precision and recall for most digits
Per-class Analysis
- Best performance on digit 1 (F1: 98.69%)
- Most challenging: digit 3 (F1: 94.55%)
- Very high precision across all classes (>90%)
Error Patterns
- Most confusion between visually similar digits (2↔3, 4↔9)
- Minimal confusion between dissimilar digits
- Balanced error distribution
Feel free to open issues or submit pull requests for improvements or bug fixes. Areas for potential enhancement:
- Additional layer types (Conv2D, MaxPool, etc.)
- More optimization algorithms
- Data augmentation techniques
- Support for other datasets
- Performance optimizations
MIT License - feel free to use this code for educational purposes.