I constantly find myself printing raw tensors (either working through them by hand or dumping them to stdout), especially when reading PyTorch / JAX code, just to understand what each transformation does. For example, something like:
`x = torch.randn(32, 3, 224, 224).unfold(2, 16, 16).unfold(3, 16, 16).reshape(32, 3, 196, 256).transpose(1, 2).reshape(32, 196, 768).view(32, 196, 12, 64).transpose(1, 2)`
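A minimal sketch of a step-by-step shape trace for a chain like this, assuming PyTorch is installed. `trace_shapes` and the `steps` list are hypothetical names made up for illustration, not part of any library. The shape comments follow from how `unfold` works: it replaces the unfolded dimension with the window count and appends the window size as a new last dimension. The second half reruns only the patch-extraction ops on a tiny labelled input, because shapes alone don't show which values end up where.

```python
import torch

# Hypothetical helper: apply a chain of ops one at a time and record
# (and print) the shape after every step.
def trace_shapes(x, steps):
    shapes = [tuple(x.shape)]
    print(f"{'input':<26}{shapes[-1]}")
    for name, fn in steps:
        x = fn(x)
        shapes.append(tuple(x.shape))
        print(f"{name:<26}{shapes[-1]}")
    return x, shapes

# The chain from the post, split into named steps (B = batch, C = channels).
steps = [
    # (B, C, 14, 224, 16): H split into 14 windows; window size 16 appended last
    ("unfold(2, 16, 16)", lambda t: t.unfold(2, 16, 16)),
    # (B, C, 14, 14, 16, 16): W split into 14 windows; window size 16 appended last
    ("unfold(3, 16, 16)", lambda t: t.unfold(3, 16, 16)),
    # (B, C, 196, 256): 14*14 patches per image, 16*16 pixels per patch
    ("reshape(32, 3, 196, 256)", lambda t: t.reshape(32, 3, 196, 256)),
    # (B, 196, C, 256): patch axis moved in front of the channel axis
    ("transpose(1, 2)", lambda t: t.transpose(1, 2)),
    # (B, 196, 768): each patch flattened to 3*256 values, channel-major;
    # reshape copies here because the transposed tensor is non-contiguous
    ("reshape(32, 196, 768)", lambda t: t.reshape(32, 196, 768)),
    # (B, 196, 12, 64): 768 split into 12 chunks of 64 (looks like a 12-head split);
    # view is allowed because the previous reshape produced a contiguous tensor
    ("view(32, 196, 12, 64)", lambda t: t.view(32, 196, 12, 64)),
    # (B, 12, 196, 64): chunk axis moved in front of the patch axis
    ("transpose(1, 2)", lambda t: t.transpose(1, 2)),
]

x = torch.randn(32, 3, 224, 224)
out, shapes = trace_shapes(x, steps)

# Same patch extraction on a 4x4 image holding 0..15, with 2x2 patches,
# so each output row can be checked against the grid by eye.
tiny = torch.arange(16).reshape(1, 1, 4, 4)
patches = tiny.unfold(2, 2, 2).unfold(3, 2, 2).reshape(1, 1, 4, 4)
print(patches[0, 0].tolist())  # [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
```

Each row of the tiny result is one patch, in row-major patch order: the top-left 2x2 block `0 1 / 4 5` comes first, then the top-right block `2 3 / 6 7`, and so on.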
How do folks visualize tensors to quickly understand the data flow through complex NNs?