Cartridges: Storing long contexts in tiny caches with self-study

Storing long contexts in tiny KV caches with self-study.

What is this? This repository provides code for training a cartridge, a small KV cache representing a large dump of text, using a test-time training recipe called self-study. The code is based on our paper Cartridges: Lightweight and general-purpose long context representations via self-study.

tl;dr When we put lots of text (e.g. a whole code repo) into a language model's context, generation cost soars because of the KV cache's size. What if we trained a smaller KV cache for our documents offline? Using a test-time training recipe called self-study, we show that this simple idea can improve throughput by 26× while maintaining quality. (See our blogpost for more.)
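The point about KV cache size can be made concrete with a back-of-envelope calculation. The sketch below assumes Llama-3.2-3B-Instruct's published shape (28 layers, 8 key-value heads, head dimension 128) and bf16 storage. The 128k-token corpus and the 2,048-token cartridge are illustrative sizes, and the resulting ratio is a memory comparison, not the throughput figure reported in the paper.

```python
# Back-of-envelope KV cache size. Model shape follows Llama-3.2-3B-Instruct's
# public config; bf16 storage (2 bytes per value) is an assumption.
NUM_LAYERS = 28
NUM_KV_HEADS = 8
HEAD_DIM = 128
BYTES_PER_VALUE = 2  # bf16

GIB = 1024 ** 3
MIB = 1024 ** 2


def kv_cache_bytes(num_tokens: int) -> int:
    """Bytes needed to store keys and values for `num_tokens` tokens."""
    # Factor of 2 accounts for storing both a key and a value vector.
    per_token = 2 * NUM_LAYERS * NUM_KV_HEADS * HEAD_DIM * BYTES_PER_VALUE
    return num_tokens * per_token


full_context = kv_cache_bytes(128 * 1024)  # a 128k-token corpus kept in context
cartridge = kv_cache_bytes(2048)           # a hypothetical 2,048-token cartridge

print(f"full context: {full_context / GIB:.1f} GiB")  # 14.0 GiB
print(f"cartridge:    {cartridge / MIB:.0f} MiB")     # 224 MiB
print(f"ratio:        {full_context // cartridge}x")  # 64x
```

Because the cache grows linearly with the number of tokens, shrinking the stored context shrinks memory by the same factor, which is what frees room for larger batches at serving time.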

Step 1: Clone the repository and install the Python package.

git clone https://github.com/HazyResearch/cartridges && cd cartridges
pip install uv
uv pip install -e .

Step 2: Set some environment variables

The codebase relies on your setting the following variables. We recommend adding them to your ~/.bashrc, ~/.zshrc, Dockerfile, etc.

# path to the directory where you cloned this repo
export CARTRIDGES_DIR=/path/to/cartridges
# path to a directory where you want to store outputs like model checkpoints
export CARTRIDGES_OUTPUT_DIR=/path/to/cartridges/outputs

Synthesizing Training Data with Self-Study

What is self-study? Self-study is a test-time training approach where we generate synthetic conversations about a corpus of text. The process simulates two AI agents: one asks questions or makes requests about the content, and another responds using the provided context. This creates training data that teaches the model to efficiently compress and retrieve information from long contexts.
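The two-agent loop can be sketched in a few lines. This is a minimal, hypothetical illustration, not the SelfStudySynthesizer implementation: `chat` stands in for a call to the inference server, and the seed prompts are made up (their keys borrow names from the `slices` in the example config below).

```python
import random
from typing import Callable, Dict, List

Message = Dict[str, str]
# Hypothetical stand-in for the inference server: chat history in, reply out.
ChatFn = Callable[[List[Message]], str]

# Hypothetical seed prompts, one per kind of request.
SEED_PROMPTS = {
    "question": "Ask a specific question about the text above.",
    "summarization": "Ask for a summary of part of the text above.",
    "creative": "Make a creative request grounded in the text above.",
}


def self_study_example(chunk: str, chat: ChatFn, rng: random.Random) -> List[Message]:
    """One round of the two-agent self-study loop (a sketch)."""
    system = {"role": "system", "content": f"Excerpt from the corpus:\n{chunk}"}
    seed = SEED_PROMPTS[rng.choice(sorted(SEED_PROMPTS))]
    # Agent A: reads the chunk and a seed prompt, then writes a user message.
    request = chat([system, {"role": "user", "content": seed}])
    # Agent B: reads the same chunk and responds to A's message.
    response = chat([system, {"role": "user", "content": request}])
    # The chunk itself is deliberately left out of the returned example.
    return [
        {"role": "user", "content": request},
        {"role": "assistant", "content": response},
    ]
```

Leaving the chunk out of the returned messages is a design assumption of this sketch: when a cartridge is trained on such examples, the cartridge, rather than the raw text, has to supply what agent B could see.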

Quickstart: Take a look at the script at scripts/longhealth_synthesize.py for an example of how to generate training data with self-study. To actually run the script, you will need to spin up an inference server (either Tokasaurus or SGLang) and set the client variable to point to it.

Below we walk through the process of generating synthetic training data for a corpus of text in more detail. As a running example, we'll be training a cartridge on our paper on Cartridges. How meta! Here are the steps:

  1. Create a StructuredContext object that contains the data you want to store in the cartridge
  2. Ensure you have an inference server running (either Tokasaurus or SGLang) and configure your client to point to it
  3. Instantiate a SynthesizeConfig object that contains the parameters for the self-study process
  4. Put it all together in one script and run it!

Note: For configuration, we use Pydantic models. Pydantic models are useful for defining the schema of the config and quickly ensuring that the config is valid at the beginning of the script. We also rely on pydrantic, which provides a few utilities for working with configs.

Step 1: Create a Context Object

A StructuredContext represents your corpus in a format that the self-study process can work with. We provide several built-in context types. For our example, we'll use the TexDocument context type.

from cartridges.contexts.tex import TexDocument

context_config = TexDocument.Config(
    arxiv_src_url="https://arxiv.org/src/2506.06266",
    main_file="main.tex",
)

Other built-in context types include HTMLDocument, and you can also use an arbitrary JSON object as a context.

Step 2: Prepare an Inference Server

Self-study requires an inference server to generate the synthetic conversations. We support two options:

  • Tokasaurus (recommended) - We ran all of our experiments with Tokasaurus, which provides higher throughput generation and is easier to modify.
  • SGLang - We're also providing support for SGLang, but we have not tested it extensively.
Option A: Modal deployment (Tokasaurus)

We found it easy to run data generation with Modal's serverless horizontal scaling.

For cloud deployment, you can deploy on Modal:

modal deploy infra/modal_deploy_tksrs.py

Then configure with the Modal URL:

from cartridges.clients.tokasaurus_batch import TokasaurusBatchClient

client_config = TokasaurusBatchClient.Config(
    url="https://your-modal-deployment-url.modal.run",
    model_name="meta-llama/Llama-3.2-3B-Instruct",
)

Option B: Local deployment (Tokasaurus)

If you have access to GPUs, you can also run a local Tokasaurus server:

  1. Clone and install Tokasaurus:
git clone https://github.com/ScalingIntelligence/tokasaurus
cd tokasaurus
git checkout --track origin/add-top-logprobs
uv pip install -e .
  2. Start the server:
tksrs model=meta-llama/Llama-3.2-3B-Instruct kv_cache_num_tokens='(512 * 1024)' max_top_logprobs=5
  3. Configure your client:
from cartridges.clients.tokasaurus_batch import TokasaurusBatchClient

client_config = TokasaurusBatchClient.Config(
    port=8001,  # Default Tokasaurus port
    model_name="meta-llama/Llama-3.2-3B-Instruct",
)
Option C: Modal deployment (SGLang)

We found it easiest to run data generation with Modal because it provides serverless horizontal scaling.

For cloud deployment, you can deploy on Modal:

modal deploy infra/modal_deploy_sglang.py

Then configure with the Modal URL:

from cartridges.clients.sglang import SGLangClient

client_config = SGLangClient.Config(
    url="https://your-modal-deployment-url.modal.run",
    model_name="meta-llama/Llama-3.2-3B-Instruct",
)
Option D: Local deployment (SGLang)
  1. Install and launch an SGLang server following the instructions here.
  2. Configure your client:
from cartridges.clients.sglang import SGLangClient

client_config = SGLangClient.Config(
    model_name="meta-llama/Llama-3.2-3B-Instruct",
    url="http://localhost:8000",
)

Step 3: Configuring the Synthesizer and Putting it all together

We are now going to put all of the pieces together in a SynthesizeConfig object that configures the entire self-study process.

Core Settings:

  • num_samples: Total number of training examples to generate
  • batch_size: Number of training examples to generate per call to the inference server.
  • max_num_batches_in_parallel: Number of batches to process concurrently. When using Modal, higher values can take advantage of its serverless horizontal scaling.
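A minimal sketch of how these three settings could interact, assuming the synthesizer issues batches concurrently up to a fixed cap. The coroutine `generate_batch` is a hypothetical placeholder for a request to the inference server, not the repo's client.

```python
import asyncio
import math
from typing import List


async def generate_batch(batch_idx: int, batch_size: int) -> List[str]:
    """Hypothetical stand-in for one request to the inference server."""
    await asyncio.sleep(0)  # a real client would await an HTTP call here
    return [f"example-{batch_idx}-{i}" for i in range(batch_size)]


async def synthesize(num_samples: int, batch_size: int, max_num_batches_in_parallel: int) -> List[str]:
    num_batches = math.ceil(num_samples / batch_size)
    limit = asyncio.Semaphore(max_num_batches_in_parallel)

    async def run(batch_idx: int) -> List[str]:
        async with limit:  # at most max_num_batches_in_parallel batches in flight
            # The last batch may be smaller if num_samples is not a multiple of batch_size.
            size = min(batch_size, num_samples - batch_idx * batch_size)
            return await generate_batch(batch_idx, size)

    batches = await asyncio.gather(*(run(i) for i in range(num_batches)))
    return [example for batch in batches for example in batch]


examples = asyncio.run(synthesize(num_samples=512, batch_size=16, max_num_batches_in_parallel=4))
print(len(examples))  # 512
```

With num_samples=512 and batch_size=16 there are 32 requests; capping concurrency at 4 means roughly 8 rounds of requests if each takes about the same time, so raising the cap shortens wall-clock time only as far as the server can scale.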

Here's a complete example script:

import os

import pydrantic

from cartridges.clients.tokasaurus_batch import TokasaurusBatchClient
from cartridges.contexts.tex import TexDocument
from cartridges.synthesize import SynthesizeConfig
from cartridges.synthesizers.self_study import SelfStudySynthesizer, SlicePromptSamplerWithChunks
from cartridges.utils import WandBConfig

client_config = TokasaurusBatchClient.Config(
    # replace with the URL of your own deployment
    url="https://hazyresearch--tksrs-entry-capsules-3b-1xh100-min0-max64-serve.modal.run",
    ports=None,
    model_name="meta-llama/Llama-3.2-3B-Instruct",
)

context_config = TexDocument.Config(
    arxiv_src_url="https://arxiv.org/src/2506.06266",
    main_file="main.tex",
)

config = SynthesizeConfig(
    context=context_config,
    synthesizer=SelfStudySynthesizer.Config(
        client=client_config,
        tokenizer="meta-llama/Llama-3.2-3B-Instruct",
        max_rounds=1,
        prompt_sampler=SlicePromptSamplerWithChunks.Config(
            slices=["structuring", "summarization", "question", "use_case", "creative"],
            min_chunk_size=512,
            max_chunk_size=4096,
            desc="Below is a research paper on test-time training for long contexts.",
        ),
        prob_cot_a=0.2,
        use_tools=False,
        tools=[],
    ),
    output_dir=os.environ.get("CARTRIDGES_OUTPUT_DIR", "."),
    num_samples=512,
    batch_size=16,
    max_num_batches_in_parallel=4,
    handle_exceptions=True,  # Continue if individual batches fail
    save_wandb_artifact=True,
    name="cartridges-tutorial",
    wandb=WandBConfig(project="cartridges", entity="hazy-research"),
)

if __name__ == "__main__":
    pydrantic.main([config])
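The min_chunk_size and max_chunk_size settings bound how much of the corpus each conversation sees. Below is a minimal sketch of one way to sample such chunks, assuming sizes are counted in tokens; `sample_chunk` is a hypothetical helper, not the code behind SlicePromptSamplerWithChunks.

```python
import random
from typing import List


def sample_chunk(token_ids: List[int], min_chunk_size: int, max_chunk_size: int,
                 rng: random.Random) -> List[int]:
    """Pick a random contiguous window whose length lies in [min_chunk_size, max_chunk_size]."""
    if len(token_ids) <= min_chunk_size:
        return list(token_ids)  # corpus shorter than one chunk: use all of it
    # randint is inclusive on both ends, so the window always fits in the corpus.
    size = rng.randint(min_chunk_size, min(max_chunk_size, len(token_ids)))
    start = rng.randint(0, len(token_ids) - size)
    return token_ids[start:start + size]
```

Drawing both the length and the offset at random spreads conversations across the whole corpus instead of always anchoring them at the same positions.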

Step 4: Running the Synthesis

Once you've created the file, run it with:

python your_synthesis_script.py

Once the run is complete, it will save the results to a pickle file and print the path:

Final output saved to /path/to/output/dir/artifact/dataset.pkl
Output format
from typing import Any, Dict, List

from pydantic import BaseModel

# `Message` is the chat-message model from the cartridges package (not shown here).
class TrainingExample(BaseModel):
    messages: List[Message]  # The conversation between agents (system, user, assistant format)
    token_ids: List[int]  # The token IDs for the response
    top_logprob_ids: List[List[int]]  # The top-k token predictions at each position
    top_logprob_logprobs: List[List[float]]  # The corresponding log probabilities
    metadata: Dict[str, Any]  # Information about tool usage, prompts, and generation process
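For every response position, each example stores the teacher model's top-k token ids and their log-probabilities. Below is a sketch of one way such targets could feed a distillation objective; it is an assumption about how settings like loss_type="logits" in the training config might use them, not the repo's loss. The teacher's top-k log-probs are renormalized into a distribution over those k tokens, and the student's log-probs for the same token ids are scored against it.

```python
import math
from typing import List


def topk_target(top_logprobs: List[float]) -> List[float]:
    """Renormalize the teacher's top-k log-probs into a distribution over those k tokens."""
    max_lp = max(top_logprobs)
    weights = [math.exp(lp - max_lp) for lp in top_logprobs]  # shift for numerical stability
    total = sum(weights)
    return [w / total for w in weights]


def topk_cross_entropy(top_logprobs: List[float], student_logprobs: List[float]) -> float:
    """Cross-entropy of the student's log-probs (for the same top-k ids) against the teacher targets."""
    target = topk_target(top_logprobs)
    return -sum(p * s for p, s in zip(target, student_logprobs))
```

Keeping only the top k entries makes the stored dataset far smaller than full-vocabulary logits while still preserving most of the teacher's probability mass at each position.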
Exploring synthesized dataset in a DataFrame
import pickle

import pandas as pd

# Load the dataset
with open("/path/to/output/dir/artifact/dataset.pkl", "rb") as f:
    data = pickle.load(f)

rows = data["rows"]
context = data["context"]

# Convert to DataFrame for exploration
df = pd.DataFrame([
    {
        "num_messages": len(row.messages),
        "num_output_tokens": len(row.token_ids),  # token_ids holds the response tokens
        "seed_prompt": row.metadata.get("seed_prompt", ""),
        "conversation": "\n".join(f"{msg.role}: {msg.content}" for msg in row.messages),
    }
    for row in rows[:10]  # First 10 examples
])

You can enhance the self-study process with tools that allow agents to dynamically retrieve additional context:

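from cartridges.synthesizers.self_study import SelfStudySynthesizer
from cartridges.tools.base import Tool

# Define custom tools for information retrieval.
# SearchTool and SummaryTool are placeholders for your own Tool subclasses;
# they are not defined in this snippet.
tools = [
    SearchTool.Config(description="Search for specific information"),
    SummaryTool.Config(description="Generate summaries of sections"),
]

synthesizer_config = SelfStudySynthesizer.Config(
    # ... other config ...
    use_tools=True,
    tools=tools,
)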
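To make the idea of a retrieval tool concrete, here is a toy, standalone keyword-overlap search over corpus chunks. The function `keyword_search` is hypothetical and does not follow the cartridges Tool interface; it only illustrates the kind of lookup a tool could offer an agent during self-study.

```python
import re
from typing import List


def keyword_search(query: str, chunks: List[str], top_k: int = 2) -> List[str]:
    """Return up to top_k chunks sharing the most lowercase words with the query (toy retrieval)."""
    terms = set(re.findall(r"\w+", query.lower()))
    scored = []
    for idx, chunk in enumerate(chunks):
        overlap = len(terms & set(re.findall(r"\w+", chunk.lower())))
        if overlap:
            # Negative overlap sorts best matches first; idx breaks ties by original order.
            scored.append((-overlap, idx, chunk))
    scored.sort()
    return [chunk for _, _, chunk in scored[:top_k]]
```

A real tool would more likely use embeddings or BM25, but the contract is the same: the agent sends a query and gets back text it can condition on.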

Training a Cartridge

Quickstart: Take a look at the script at scripts/longhealth_train.py for an example of how to train a cartridge on data synthesized with self-study.

See cartridges.train.TrainConfig for the schema of the main config we use for training.

Below we provide an example training config. Its dataset section should point to the pickle file produced by the synthesis script above.
import os
from pathlib import Path

import pydrantic

from cartridges.config import HFModelConfig
from cartridges.datasets import CartridgeTrainDataset
from cartridges.initialization.strategies.first_n_tokens import KVCacheInitFromFirstNTokensOfContext
from cartridges.models.llama import LlamaForCausalLM  # assumed import path for the model class
from cartridges.tasks.longhealth import LongHealthMultipleChoiceGenerateDataset
from cartridges.tasks.longhealth.context import LongHealthStructuredContextConfig
from cartridges.train import EvalDatasetConfig, GenerateDatasetConfig, TrainConfig
from cartridges.utils import WandBConfig

file_name = Path(__file__).stem

# Illustrative placeholders: set these for your own run.
patient_ids = ["patient_01"]  # hypothetical LongHealth patient IDs
patients_str = "_".join(patient_ids)
bs = 32  # hypothetical global batch size

config = TrainConfig(
    model=HFModelConfig(
        pretrained_model_name_or_path="meta-llama/Llama-3.2-3B-Instruct",
        model_cls=LlamaForCausalLM,
        attn_implementation="einsum",
    ),
    kv_cache_initializer=KVCacheInitFromFirstNTokensOfContext.Config(max_tokens=2048),
    lr=2e-2,
    loss_type="logits",
    epochs=2,
    global_batch_size=bs,
    local_batch_size=4,
    use_batch_sampler=True,
    dataset=CartridgeTrainDataset.Config(
        # path should point to the output of the synthesis script we ran above
        data_sources=[("/path/to/output/dir/artifact/dataset.pkl", None)],
        max_sequence_length=1024,
        is_wandb=True,
        label_type="logits",
        top_k_logits=20,
    ),
    context=LongHealthStructuredContextConfig(patient_ids=patient_ids),
    save_every_n_steps=512,
    generate_every_n_steps=512,
    generate_max_new_tokens=512,
    generate_datasets=[
        GenerateDatasetConfig(
            dataset=LongHealthMultipleChoiceGenerateDataset.Config(
                patient_ids=patient_ids,
                cot=True,
            ),
            name_for_wandb="longhealth_mc",
            num_samples=8,
            num_samples_final=8,
            batch_size=16,
            temperature=0.3,
        )
    ],
    eval_every_n_steps=256,
    eval_datasets=[],
    distributed_backend="gloo",
    wandb=WandBConfig(
        project="cartridges",
        tags=["train", "longhealth", f"patients{patients_str}"],
        entity="hazy-research",
    ),
    output_dir=os.environ["CARTRIDGES_OUTPUT_DIR"],
    name="train-cartridges",
)

if __name__ == "__main__":
    pydrantic.main([config])

Distributed data parallel training

To launch a data parallel training run, you can run:

torchrun --standalone --nproc_per_node=2 path/to/file.py
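With more than one process, global_batch_size and local_batch_size in the config have to fit together. Assuming the trainer reaches the global batch through gradient accumulation (an assumption; check cartridges.train.TrainConfig), the arithmetic looks like this, with `grad_accum_steps` a hypothetical helper:

```python
def grad_accum_steps(global_batch_size: int, local_batch_size: int, world_size: int) -> int:
    """Micro-batches each process accumulates before one optimizer step."""
    per_step = local_batch_size * world_size  # examples processed per forward/backward across all GPUs
    if global_batch_size % per_step != 0:
        raise ValueError(
            f"global_batch_size={global_batch_size} is not a multiple of "
            f"local_batch_size * world_size = {per_step}"
        )
    return global_batch_size // per_step


# --nproc_per_node=2, local_batch_size=4, and a hypothetical global batch of 32
print(grad_accum_steps(32, 4, 2))  # 4
```

Under this assumption, adding GPUs reduces accumulation steps rather than changing the effective batch size, so results should stay comparable across different values of --nproc_per_node.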

Serving Cartridges

We describe two ways to serve and chat with a trained Cartridge: a simple but slow way that uses a pure PyTorch generation loop, and a faster one that uses a Tokasaurus server.

Serving with Tokasaurus [Fastest and recommended]

We've implemented (h/t @geoffreyangus) an integration with Tokasaurus, a simple LLM inference server optimized for high throughput.

To run the Tokasaurus server, you will need to install Tokasaurus from source, switch to the branch geoff/cartridges, and then follow the instructions here to make API calls to the server.

client = TokasaurusClient(
    url="https://your-modal-deployment-url.modal.run",
    model_name="meta-llama/Llama-3.2-3B-Instruct",
)
streamlit run cartridges/analysis/dashboards/chat_w_cache.py

Serving with Basic PyTorch [Easiest but slow]

streamlit run cartridges/analysis/dashboards/chat_w_cache.py

Acknowledgments and Citation

There are tons of people and organizations who have supported this project. Below we shout out a few, but check out the paper for a full list.

The compute for this project was provided by Modal — who made it super easy to scale out horizontally when running the synthetic data generation for self-study — and Together — who provided the compute for training the Cartridges on the synthetic data. Prime Intellect, Voltage Park, and Azure through the HAI Grants program also contributed compute towards this project.

@article{eyuboglu2025cartridges,
  title={Cartridges: Lightweight and general-purpose long context representations via self-study},
  author={Eyuboglu, Sabri and Ehrlich, Ryan and Arora, Simran and Guha, Neel and Zinsley, Dylan and Liu, Emily and Tennien, Will and Rudra, Atri and Zou, James and Mirhoseini, Azalia and others},
  journal={arXiv preprint arXiv:2506.06266},
  year={2025}
}