- Remove custom quant cache layer stuff for now (cache quant needs to be tested with all the new changes)
- Move preprocessing to separate util module
- Replace dedicated Gemma4 modules with existing generic modules, make necessary adjustments:
- SDPA fallback triggers whenever head_dim > 512 (xformers also added, but its GQA impl. is buggy and needs an annoying workaround that slows it down a lot)
- Add necessary extra norms, new transpose args and second residual channel to BlockSparseMLP (dense_mlp becomes shared expert instead)
- Add layer scalar per decoder block
- Don't apply embedding multiplier to embedded MM tokens
- Ensure embedding scaling exactly matches HF bfloat16 version
Vision stuff:
- Handle non-causal attention in multimodal spans with multiple (flash) attn passes rather than custom mask.
- Avoid extending chunk size past the first MM span (allow small amount of redundant processing to keep VRAM overhead relatively constant.)
- Fold Gemma4VisionStandardize into Gemma4VisionPooler
- Replace Gemma4VisionProjector with RMSNorm+Linear modules
- Use 2D RoPE in kernel instead of precomputed sin,cos tensors
- Use non-causal attention with no mask (HF reference pads all embeddings to the same size of 280 tokens and then has to apply a custom attn mask to make that work, but the padding tokens are discarded anyway so there's no point)