Breaking the Quadratic Wall of Transformer Attention: WERSA Paper and Code Released as Open Source


This repository provides the official implementation of WERSA, a novel attention mechanism with linear O(n) time complexity, designed to scale Transformer models to very long sequences without a performance trade-off.

Our paper, "Scaling Attention to Very Long Sequences in Linear Time with Wavelet-Enhanced Random Spectral Attention (WERSA)", is available on arXiv:2507.08637.

🔬 The Science Behind WERSA

Standard softmax attention scores every token against every other token, so its time and memory grow quadratically (O(n²)) with sequence length n, which makes processing long sequences impractical. WERSA combines the following techniques to reach linear (O(n)) cost while maintaining high performance:

  • Multi-Resolution Analysis: Uses Haar wavelet transforms to decompose the input into multiple scales, capturing both local details and global context.
  • Adaptive Filtering: An MLP generates input-dependent filters and learnable scale_weights modulate each wavelet level, allowing the model to dynamically prioritize the most informative frequency components.
  • Linear Complexity via Random Features: Uses random feature projection to approximate the softmax kernel, avoiding the computation of the full quadratic attention matrix.
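The Multi-Resolution Analysis and Adaptive Filtering bullets can be illustrated with a minimal NumPy sketch: a multi-level orthonormal Haar transform along the sequence axis, with one scalar weight per band standing in for scale_weights. Decomposing along the token axis, the function names, and the plain per-band scalars are assumptions for illustration only; WERSA's input-dependent MLP filters and the way the bands feed into attention are specified in the paper and code.

```python
import numpy as np


def haar_decompose(x, levels):
    """Multi-level orthonormal Haar transform along axis 0 (the sequence axis).

    Returns [detail_1, ..., detail_L, approx_L]; len(x) must be divisible by 2**levels.
    """
    coeffs = []
    approx = x
    for _ in range(levels):
        even, odd = approx[0::2], approx[1::2]
        coeffs.append((even - odd) / np.sqrt(2))  # detail band: local, high-frequency changes
        approx = (even + odd) / np.sqrt(2)        # approximation: coarser, smoother trend
    coeffs.append(approx)
    return coeffs


def haar_reconstruct(coeffs):
    """Inverse of haar_decompose."""
    approx = coeffs[-1]
    for detail in reversed(coeffs[:-1]):
        out = np.empty((2 * approx.shape[0],) + approx.shape[1:])
        out[0::2] = (approx + detail) / np.sqrt(2)
        out[1::2] = (approx - detail) / np.sqrt(2)
        approx = out
    return approx


def reweight_levels(coeffs, scale_weights):
    """Hypothetical stand-in for per-level scale_weights: scale each band by one scalar."""
    return [w * c for w, c in zip(scale_weights, coeffs)]


rng = np.random.default_rng(0)
x = rng.standard_normal((16, 4))        # 16 tokens, 4 channels
coeffs = haar_decompose(x, levels=3)    # band lengths: 8, 4, 2 (details) and 2 (approx)

# All-ones weights leave the signal unchanged; zeroing the detail bands keeps only the coarse trend.
identity = haar_reconstruct(reweight_levels(coeffs, [1.0, 1.0, 1.0, 1.0]))
coarse_only = haar_reconstruct(reweight_levels(coeffs, [0.0, 0.0, 0.0, 1.0]))
```

Because the transform is orthonormal, the bands together carry exactly the energy of the input, so scaling individual bands is a clean way to emphasize or suppress particular scales.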
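The Random Features bullet can be illustrated with a minimal NumPy sketch. It uses Performer-style positive random features, φ(x) = exp(ωᵀx - ‖x‖²/2)/√m with ω drawn from N(0, I), whose dot products approximate exp(q·k) in expectation. That feature map and all function names here are assumptions chosen for illustration, not necessarily the projection WERSA uses.

```python
import numpy as np


def positive_random_features(x, omega):
    """Performer-style positive random features (an assumed feature map).

    phi(x) = exp(x @ omega - |x|^2 / 2) / sqrt(m), so E[phi(q) . phi(k)] = exp(q . k)
    when the columns of omega are drawn from N(0, I).
    """
    m = omega.shape[1]
    half_sq_norm = 0.5 * np.sum(x ** 2, axis=-1, keepdims=True)
    return np.exp(x @ omega - half_sq_norm) / np.sqrt(m)


def softmax_attention(q, k, v):
    """Exact attention: materializes the full n x n weight matrix (O(n^2) time and memory)."""
    scores = q @ k.T                               # temperature assumed folded into q and k
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def random_feature_attention(q, k, v, omega):
    """Approximate attention in O(n * m * d): the n x n matrix is never formed."""
    q_f = positive_random_features(q, omega)   # (n, m)
    k_f = positive_random_features(k, omega)   # (n, m)
    kv = k_f.T @ v                             # (m, d_v): summary of all keys and values
    normalizer = q_f @ k_f.sum(axis=0)         # (n,): approximates each softmax denominator
    return (q_f @ kv) / normalizer[:, None]


rng = np.random.default_rng(0)
n, d, m = 64, 16, 4096
q = 0.25 * rng.standard_normal((n, d))
k = 0.25 * rng.standard_normal((n, d))
v = rng.standard_normal((n, d))
omega = rng.standard_normal((d, m))

exact = softmax_attention(q, k, v)
approx = random_feature_attention(q, k, v, omega)
print("max abs error:", np.max(np.abs(exact - approx)))
```

Computing φ(K)ᵀV first costs O(n·m·d) instead of the O(n²·d) needed to build the full weight matrix. This sketch is non-causal; a causal language model would instead accumulate φ(K)ᵀV as a running prefix sum over positions.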

⚙️ Installation

First, make sure PyTorch and Hugging Face Transformers are installed (the torch command below pulls CUDA 12.1 wheels; choose the index URL that matches your CUDA setup). Then install the wersa package directly from this repository.

```bash
pip install torch --index-url https://download.pytorch.org/whl/cu121
pip install transformers
pip install git+https://github.com/vincenzodentamaro/wersa.git
```

🚀 Quickstart: Building a Qwen-like Model with WERSA

You can easily build a Qwen-style causal language model with WERSA attention by importing the WersaConfig and WersaForCausalLM classes from the package.

Building an 8B Parameter Model

This snippet creates a model of roughly 8B parameters, sized like 7B-class models such as Qwen2-7B, and reuses the Qwen2-7B tokenizer.

```python
from transformers import AutoTokenizer
from wersa import WersaConfig, WersaForCausalLM

# Reuse the Qwen2 tokenizer so vocab_size and pad_token_id match its vocabulary.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")

config_8b = WersaConfig(
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    hidden_size=4096,
    num_hidden_layers=32,
    num_attention_heads=32,        # 4096 / 32 = 128 dimensions per head
    intermediate_size=11008,
    max_position_embeddings=4096,
)

# Weights are randomly initialized; in float32 (PyTorch's default) that is ~4 bytes per parameter.
model_8b = WersaForCausalLM(config_8b)
print(f"8B Model created with ~{model_8b.num_parameters() / 1e9:.2f}B parameters.")
```

Building a 0.6B Parameter Model

This snippet creates a smaller model of roughly 0.6B parameters, suited to faster experiments or deployment on more constrained hardware.

```python
from transformers import AutoTokenizer
from wersa import WersaConfig, WersaForCausalLM

# Tokenizer from a Qwen2 checkpoint; it sets vocab_size and pad_token_id below.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")

config_0_6b = WersaConfig(
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    hidden_size=1024,
    num_hidden_layers=24,
    num_attention_heads=16,        # 1024 / 16 = 64 dimensions per head
    intermediate_size=2816,
    max_position_embeddings=1024,
)

model_0_6b = WersaForCausalLM(config_0_6b)
print(f"0.6B Model created with ~{model_0_6b.num_parameters() / 1e9:.2f}B parameters.")
```
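The printed figures come from num_parameters(). As a rough cross-check, the hypothetical helper below counts parameters for a standard Llama/Qwen-style decoder: full multi-head q/k/v/o projections without biases, a three-matrix SwiGLU MLP, RMSNorm weights, and an untied output head. That layout and the vocabulary size of 151,646 are assumptions, and the count ignores any WERSA-specific parameters (such as the filter MLP and scale_weights), so the real totals will differ somewhat.

```python
def estimate_params(vocab_size, hidden_size, num_layers, intermediate_size, tie_embeddings=False):
    """Back-of-envelope count for an assumed Llama/Qwen-style decoder layout."""
    attention = 4 * hidden_size * hidden_size      # q, k, v, o projections, no biases
    mlp = 3 * hidden_size * intermediate_size      # gate, up, down projections (SwiGLU)
    norms = 2 * hidden_size                        # two RMSNorm weight vectors per layer
    per_layer = attention + mlp + norms
    embeddings = vocab_size * hidden_size
    lm_head = 0 if tie_embeddings else vocab_size * hidden_size
    final_norm = hidden_size
    return num_layers * per_layer + embeddings + lm_head + final_norm


VOCAB = 151_646  # assumed len(tokenizer) for the Qwen2 tokenizer; verify with len(tokenizer)

est_8b = estimate_params(VOCAB, hidden_size=4096, num_layers=32, intermediate_size=11008)
est_0_6b = estimate_params(VOCAB, hidden_size=1024, num_layers=24, intermediate_size=2816)
print(f"8B config:   ~{est_8b / 1e9:.2f}B")    # ~7.72B
print(f"0.6B config: ~{est_0_6b / 1e9:.2f}B")  # ~0.62B
```

Under these assumptions, the embedding and output matrices alone account for about half of the 0.6B model, which is typical of small models with a vocabulary of roughly 150k tokens.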

📖 Training and Examples

This repository includes complete scripts to demonstrate how to pre-train these models from scratch and test their generation capabilities.

  • train_and_generate_1b.py: A full example for training a ~1B parameter model.
  • train_and_generate_8b.py: A full example for training the 8B parameter model.
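Pre-training a causal language model from scratch usually means minimizing next-token cross-entropy. Hugging Face causal-LM classes compute this loss when you pass labels=input_ids, shifting the labels internally and ignoring positions labelled -100; whether WersaForCausalLM follows the same convention is an assumption to confirm in the scripts. The NumPy sketch below shows what that loss computes; the function name is hypothetical.

```python
import numpy as np


def causal_lm_loss(logits, labels, ignore_index=-100):
    """Mean next-token cross-entropy; logits (batch, seq, vocab), labels (batch, seq)."""
    # Position t predicts token t+1: drop the last prediction and the first label.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    mask = shift_labels != ignore_index                 # e.g. padding marked with -100
    safe_labels = np.where(mask, shift_labels, 0)       # placeholder index for masked slots
    # Numerically stable log-softmax over the vocabulary.
    z = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    token_log_probs = np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    return -(token_log_probs * mask).sum() / mask.sum()


rng = np.random.default_rng(0)
vocab, batch, seq = 10, 2, 6
input_ids = rng.integers(0, vocab, size=(batch, seq))

# Uniform logits give loss log(vocab); logits that favor the true next token give ~0.
uniform_loss = causal_lm_loss(np.zeros((batch, seq, vocab)), input_ids)
perfect_logits = np.full((batch, seq, vocab), -1e4)
b_idx, t_idx = np.meshgrid(np.arange(batch), np.arange(seq - 1), indexing="ij")
perfect_logits[b_idx, t_idx, input_ids[:, 1:]] = 1e4   # logit at position t favors token t+1
perfect_loss = causal_lm_loss(perfect_logits, input_ids)
```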

📜 Citation

If you find WERSA useful in your research, please consider citing our paper:

```bibtex
@misc{dentamaro2025scaling,
  title={Scaling Attention to Very Long Sequences in Linear Time with Wavelet-Enhanced Random Spectral Attention (WERSA)},
  author={Vincenzo Dentamaro},
  year={2025},
  eprint={2507.08637},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```

📄 License

This project is licensed under the Apache License 2.0.
