HloModule jit_splash_attention_kernel, entry_computation_layout={(bf16[8,2048,128]{2,1,0:T(8,128)(2,1)}, bf16[8,2048,128]{2,1,0:T(8,128)(2,1)}, bf16[8,2048,128]{2,1,0:T(8,128)(2,1)})->bf16[8,2048,128]{2,1,0:T(8,128)(2,1)}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}
ENTRY %main.28 (Arg_0.1: bf16[8,2048,128], Arg_1.2: bf16[8,2048,128], Arg_2.3: bf16[8,2048,128]) -> bf16[8,2048,128] {
%constant.4 = s8[1,2,2]{2,1,0} constant({ { { 0, 0 }, { 0, 1 } } })
%constant.5 = s8[1,2,2]{2,1,0} constant({ { { 1, 0 }, { 2, 1 } } })
%Arg_0.1 = bf16[8,2048,128]{2,1,0} parameter(0), metadata={op_name="q"}
%Arg_1.2 = bf16[8,2048,128]{2,1,0} parameter(1), metadata={op_name="k"}
%Arg_2.3 = bf16[8,2048,128]{2,1,0} parameter(2), metadata={op_name="v"}
%constant.6 = s32[2048]{0} constant({...})
%broadcast.0 = s32[2048,128]{1,0} broadcast(%constant.6), dimensions={0}, metadata={op_name="jit(splash_attention_kernel)/jit(main)/splash_attention_kernel/splash_kernel_b1024_h8_s2048/jit(_splash_attention)/broadcast_in_dim" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=1045}
%custom-call.0 = (f32[1024,128]{1,0}, f32[1024,128]{1,0}, f32[1024,128]{1,0}, bf16[8,2048,128]{2,1,0}) custom-call(%constant.4, %constant.5, %Arg_0.1, %Arg_1.2, %Arg_2.3, /*index=5*/%broadcast.0), custom_call_target="tpu_custom_call", operand_layout_constraints={s8[1,2,2]{2,1,0}, s8[1,2,2]{2,1,0}, bf16[8,2048,128]{2,1,0}, bf16[8,2048,128]{2,1,0}, bf16[8,2048,128]{2,1,0}, s32[2048,128]{1,0}}, metadata={op_name="jit(splash_attention_kernel)/jit(main)/splash_attention_kernel/splash_kernel_b1024_h8_s2048/jit(_splash_attention)/splash_mha_fwd_c602b16d/pallas_call" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=1100}, backend_config={"custom_call_config": {"body": "TUzvUgFNTElSMjEuMC4wZ2l0AAFRCwEDBQcJAQMLAzsNDxETFRcZGx0fISMlJykrLS8xMzU3OTs9P0FDRQMKCWYILwH7BwsXCxMXCxMLCwsbCwsLCysLCxMTExcLCwsLCw8PDw8XGxMTFxNzUwsLExcXFxcXCxMLExMTExMTCxMTExcPCxMTExMTEwsTDxcTCxMTExMPCxcPDxMXCxMXExMTExMPpYUPCwsLCwsLCwsLFxMPExcPExMLExMPExMTGwsFD2GBkY1hKgIqAgFGBgsLExMTExcTFwsLCwsTExMTHxcXCwszFwsLFxMXExMTIwsLawsbCwtzCw8LC0sfCx8LHwsfCycLJwsnCx8LGxsbGxsbGxsPDw8fCycLJxMLJxMLJxMLJxMLGxMLGxMfFwsfExMLHw8fExMTEw8PHw8fExMPHw8LJxMTDw8PExMPExMfDxMfDw8nExMPDw8TEw8PHxMTHw8THxMTHxcfCxMTHwsTExMfExMfExMTHxMTEx8TEx8rUw8TEx8TEx9TExMfExMfExMfExMfCw8LHxMTEx8PCxMTHxMTHxMTHxMTHxcfExMfMxMTHxMTHxMTEx8TFwsTEycfExMTExMTDxMTJxMTHw8TExMnFw8XCxcXCyMLFxMfFxMTExMTDx8TCxMfDxMLFx8LExMfKxMTEx8LExMfExMTHxcTEx8XExMfFx8LExMfKxMTHxcTEx8XExMfExMTHxMTEx8TEx8TEx8TEx8TExMfFx8TEx8zExMTHxMTHxMTEx8TExMfExMfExMfExMfExMfExMfExMfExMfBwVZWQkFXUkBLw8HHycHCw8rHwsnFx8jHwcnG0MvKx8fAlovHwVHAwMTLgMFSRV7XgUDAxMqAwVLFW4EJwVNBU8FUQMDmgRWCAVTBVUFVwVZAwUSAgoFFgIaAgVbBV0VV0oDFSIEJx0b2gQdogOmAwVfBWEFYwVlBWcdoykdo6kRAQUdow8DAxPuBAMDYgVaCB0NfgUdDYoFFU4C8gUdDZIHIw0HMQEAAAAAAAAAAAQAAAAAAACAAAAAAAAAACMNBSEABAAAAAAAAIAAAAAAAAAADScNKR0NMgMdQgNGAx1OA1IDHVoDXgMdZgNqAx1yA3YDBWkVV7IDBWsdDc4DHQ3mAx0NAgQdDS4EHQ1GBB0vWgQFbR0b/gQdGxoFHRsyBR1GBUoFEQsABW8dG
3oHHRuGBx0bAggdGxoIHRsyCB0bSggFcR03NgMV2ScdfgOCAx03rgMFcx03ygMDA2W/HTfSAx034gMV4ScFdQMDZSoEFe0nFXEnHQ12BB2KBI4EBXcdDa4EAwMTTgUdDWYFHQ1yBR0NGgYVkgZJHQ3WBxENEWFmZmluZV9tYXA8KGQwLCBkMSwgZDIpIC0+IChkMCwgZDEsIGQyKT4AYWZmaW5lX21hcDwoZDAsIGQxKSAtPiAoZDAsIGQxKT4AEQ0BBXkFewV9BX8FgQWDBYUFhwWJHToDPgMdx4oDFdljHZfGAx3qA+4DFeFjHZf+Ax3pEgQFiwMDZcUdLzIEFe1jHZdCBAMDEz0d6V4EAwWvPfk9BY0jdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAjdHB1LnBpcGVsaW5lX21vZGU8c3luY2hyb25vdXM+ACN0cHUuZGltZW5zaW9uX3NlbWFudGljczxhcmJpdHJhcnk+ACN0cHUuZGltZW5zaW9uX3NlbWFudGljczxwYXJhbGxlbD4AI3RwdS5tZW1vcnlfc3BhY2U8c21lbT4AI3RwdS5kb3RfZGltZW5zaW9uX251bWJlcnM8WzFdLCBbMV0sIFswXSwgWzBdLCBbMCwgMCwgMSwgMF0sIFtdLCBbXT4AI3RwdS5kb3RfZGltZW5zaW9uX251bWJlcnM8WzFdLCBbMF0sIFswXSwgWzFdLCBbMCwgMCwgMSwgMV0sIFtdLCBbXT4ABY8FkSMBAQEdH/IEHR8OBR0fJgUdPgVCBR0xUgUdNgKeBQWTBZUFlwWZHTGuBR0xugUdB+4FHX8OBgMFrz35JgYDAxOGBh2uBrIGBZsFnSMNAxEBAAAAAAAAAB3SBtYGBZ8FoR0aBx4HHR8uBx02ArYHHR8OCB0fJggdHz4IAwWWAr8ZmgIFowWlAw+iAqYCHaoCrgKyArYCugK+AsUZx8ICxgIFpwEHAgL//w0lBakjDQcxCAAAAAAAAAACAAAAAAAAAAIAAAAAAAAABasRDQkFrQWvARHKAtIC2gLiAuoC8gL6AgIDAwUjzgIlTQnJAwUj1gIlTQnLAwUj3gIlTQnNAwUj5gIlTwnPAweN/SPuAiVPCdEDB439I/YCJU8J0wMHjf0j/gIlTwnVAwUjBgMlTQnXAwUdURnJAwUdURnLAwUdURnNAwUdUxnPAwUdUxnRAwUdUxnTAwUdUxnVAwUdURnXEQEBEQMBFY+RLQMH5ggtWwWxLQMJRg8jTg8LBbMtAwkyER3KEQsVWVYDBbUtAwmiEhPaEgcVW2IDBbctAwkiJBNiJAcVXW4DBbktAwmuJBfKJAsVX3oDBbstYQeBJ08Vk4YDBb0tYQeNI0EV244DLWEHZgJBvR2SA5YDBb8tYQf+AgkVHRWeAxUtqgMFwS0DB94IK1EVld0tAwfmCB9dFVm2AxVbugMVXb4DFV/CAxWT2xWZkS0DB+YIH2UVnZEtAwcCCSVRHRXaAxUt3gMVn90tAwcCCRdTFY+hBcMtAwkmDyMuDwsdFfYDFS36AxWV4xWZoRWdoR0VCgQVLQ4EFZ/jFRYEJx0vGgQtAwdKCxcjHRUpHS8mBC0DB0oLByURDQUVj6ctAwliC2l+CwcdFToEFS0+BBWV7xWZpxWdpx0VTgQVLVIEFZ/vHRWpLQMHmgwHLRViBCcdL2YELQMHugwXPR0VDx0vcgQtAwe6DAc/FXoEDx0RfgQtAwfCDBE1AwMThgQTCRAAAOAPBcUVkgQPHRGWBC0DB8YMM0EFxx1zogQVpgQPHRGqBC0DB8YMGXsVsgQPHRG2BC0DB8oMJUkdMb4EFcIEDx0RxgQtAwfKDCVZHRXOBBXSBA8dEdYELQMHygwjgxXeBA8dEeIELQMHygwJHQMFEgLqBBYCGgIjAQkhAQAAAAEAAAADAAAAAAAAABMJARX2BA8dEfoELQMH5gwzbRUCBQ8dEQYFLQMH5gwJLSMBCSEBAAAAAQAAA
AIAAAAAAAAAFRIFDx0RFgUtAwfqDDNtFR4FDx0RIgUtAwfqDAktFSoFDx0RLgUtAwfuDDNtFTYFDx0ROgUtAwfuDAktBckVe6kFyy0DB7IMCXERAQIgFVYFCR0HWgUtAweKCyllFXFjBc0VagUJHQduBS0DB44LK08VdgUJHQd6BS0DB44LU3cVggUJHQeGBS0DB54LESUVjgUJHQeSBS0DB6oLFTcDAxOaBSURCQAAAAAVogUJHQemBS0DB7YLE48DBzoCCgI+An1CAn0VsgUJHQe2BS0DB/4LI00VvgUJHQfCBS0DB/4LU48df8oFFc4FCR0H0gUtAwf+CyOPAwOvPR3eBeIFBc8V5gVJHTPqBS0DCaoJPbIJDy0DBxIME0UVe/YFFXH6BRVX/gUVWQIGFVsGBhVdCgYVX5MVEgZJHTMWBi0DCaoJJ7IJDxUeBkkdMyIGLQMHxgkVOxEBIR1zLgYVMgZJHTM2Bi0DCcIJJ8oJDwMDZT4GEQ0VHUYGSgYF0RVOBl4GHVIGVgYF0y1aBgcCBR89BdUVYgZqBh0zZgYtAwf2CSlzFU4CbgYVe3IGFXF2BhVXegYVWX4GFVuCBhVdXxMJkMzMzD8djga7BdcdM5YGLQMHhgoTUR0fux2iBrsF2QMDE6oGJRcJAACA/wXbFbYGCR0HugYtAwcaDBs5AwViAl4IZgJqAh0fxgYVygYJHQfOBi0DBxoMG0sF3RXaBgkdB94GLQMHIgwbUR1z5gYV6gYJHQfuBi0DB0YMNYcdcgL2BhX6BgkdB/4GLQMHRgwrhx12AgYHFQoHCR0HDgctAwdGDBuJAwMTFgclFwkAAAAABd8VIgcJHQcmBy0DB1IMTXMDBWICYghmAmoCFTIHCR0HNgctAwdSDBudHXICPgcVQgcJHQdGBy0DB14MKUcddgJOBxVSBwkdB1YHLQMHXgwZSR0xXgcVYgcJHQdmBy0DB2IMLUkdf24HFXIHCR0HdgctAwdiDBtJFX4HCR0HggctAwdmDAktFYoHCR0HjgctAwdmDDFVFZYHCR0HmgctAwd2DBU3HRWiBxWmBwkdB6oHLQMHggwRMwMDE7IHJQUJAAAAABW6BwkdB74HLQMHhgwbYQMHOgIOAj4CfUICfR1zygcVzgcJHQfSBy0DB44MHXsV2gcJHQfeBy0DB5IMQ2MdMeYHFeoHCR0H7gctAweSDC9jHX/2BxX6BwkdB/4HLQMHkgwvdRUGCAkdBwoILQMHkgwJKRUSCCkdNRYILQMHUgszbRUeCCkdNSIILQMHUgsJLRUqCCkdNS4ILQMHVgszgxU2CCkdNToILQMHVgsJLRVCCCkdNUYILQMHWgszbRVOCCkdNVIILQMHWgsJLSNhcml0aC5mYXN0bWF0aDxub25lPgAjYXJpdGgub3ZlcmZsb3c8bm9uZT4AI3ZlY3Rvci5raW5kPG1heGltdW1mPgAjdmVjdG9yLmtpbmQ8YWRkPgABAgIDJwUCIAIECRcGAgcFCQkTwQsBCQECBBf7BwUCIAIEH8EnBQIgAiAJAUEX+wUCIAIECcMnAwIgCScFAiACBB8nBwUCIAIEHycFAiACIAEHF/sFAiACBAHDJwUCIAUJBRsBAQEHBw8PDyEVFRUPAQULAQEBBwcHAQEBBQsBAQEHBwUBAScFAiACBAEnBQIgAiALBEIdBQERAZICBwMBJQ0RAZ4CBwNNdxsBAQEBAQEHAQcBDwEPAQ8BIQEVARUBFQEPAQMD5wsDAREH5+sDCwUFGx0GHgQDAQMdAwM5CwMBEQc5pQMLBR8hHxQ5AyMJAx9NAwOGAkEDCQkGhgIDBQNNAwOHBQMDAwOHBQMDBQaHAwUHF1FTCwWHIQlPF1FTAwOKAloCAwkJBooCAwUDVwMDiQUDAwMDiQUDAwUGiQMFBxNbXQsFiSEJWRNbXQMDjgJBAwkJBo4CAwUDYQMDiwUDAwMDiwUDAwUGiwMFBxVlZwsFiyEJYxVlZxkAO
QMBBRkAOQMDbQUDAwcGbQMDAwMHBm0DAwMFFQZtAxMJCSUnKRcGNgQDAQMrAwPxCwMBEQfxmwMLBS0vAwNvBQMDBwZvAwMDAwcGbwMDAwUVBm8DEwkHMzU3FwZKBAMBAzkdBlYEAwEDMQMDOwsDAREHO6UDCwU9Px8UOwNBCQObigIDAyoCCwMBAwMuArMDASMHLgJDAwEFTU8DA7UFAwMDA7UFAwMFBrUDBQcTU1UDA7cFAwMDA7cFAwMFBrcDBQcVWVsDA0UFAwMDA0UFAwMDA0UFAwMFBkUDGwkLX2FjEwZFAxkDZQMDRwUDAwcGRwMDA1EDA0cFAwMFBkcDGwkNaWttEwZHAxkDbwMDMgKWBQMRJQcyAqoFAxEHZ3FzAwNGArMDASMHRgJDAwEFO3cDA0oCswMBIwdKAkMDAQVNeycHxgVDAwEFeX01A9oF1gUDHQkGUgIDHQN/JwdSAkMDHQWDgQMDuQUDAwMDuQUDAwUGuQMrBxGHiRsHKgZWAgMdA4sRB0IGOgYDLQWNhQMDigZaAgMJCQaaBgMRA5E3Bp4GAxEHj3WTAwNeAqYGAxcpB14CvgYDFwWVlxMGwgYDIwOZCQZuAgMFA5s5B24CFwMFBVedGwfiBlYCAxEDnysH8gYXAxEFlaEtBwIHFwMRA6MDA3oCEgcDFykHegIqBwMXBaWnEwZ+AgMjA6kJBn4CAwUDqysHOgcXAwUFV58tB0oHFwMFA68hB1oHFwMFBbFdLwdqBxcDBQWtswMDgQUDAwMDgQUDAwUGgQMFBxO3uQsFgSEJnxO3uQMDgwUDAwMDgwUDAwUGgwMFBxW9vwsFgyEJtRW9vwMDSwUDAwcGSwMDA1EDA0sFAwMFBksDGwkPw8XHEwZLAxkDyTsGngcDBQPLAwOCAq4HAwUlB4ICwgcDBQelzc8bB8YH9wMFA7EDA70FAwMDA70FAwMFBr0DBQcX1dchB+IHFwMFBdPZLwfyBxcDBQXb0QMDhQUDAwMDhQUDAwUGhQMFBxff4QsFhSEJ3Rff4QMDKgLzAwEZADsDAQUZADsDA/XzAwERB/XrAwsFBUMdBmoEAwEDRQMDPwsDAREHP6UDCwVHSR8UPwNLCQNDmQMDqwUDAwMDqwUDAwUGqwMFBxVNTwMDrYIEAwkJBq0DBQNTMQetFwMFBVVRGweeBPcDBQNXAwOxBQMDAwOxBQMDBQaxAwUHF1tdIQe6BBcDBQVfWTMGygQDGQNhAwMrBQMDAwMrBQMDAwMrBQMDBQYrAxsJGWVnaRMGKwMZA2sTBisDGwNjCwUr5gQLbxllZ2kDAx4CQQMJCQYeAgMFA3EDA3UFAwMDA3UFAwMFBnUDBQcTdXcLBXUhCXMTdXcDAyICQQMJCQYiAgMFA3sDA3cFAwMDA3cFAwMFBncDBQcVf4ELBXchCX0Vf4EDAyYCQQMJCQYmAgMFA4UDA3kFAwMDA3kFAwMFBnkDBQcXiYsLBXkhCYcXiYsZAD8DAQUZAD8PAAENEQEKAwcDDw8LAQEBAQEBBwEHAQMDAQsDAQMDAQsDAQ8EAQcBAwsNEQEOAwcDJz8LAQEBAQEBBwEHAQMDaQUDAwcGaQMDAwMHBmkDAwMFFQZpAxMJCQsNDxcG8gMDAQMRAwPlCwMBEQflmwMLBRMVAwNrBQMDBwZrAwMDAwcGawMDAwUVBmsDEwkHGRsdFwYGBAMBAx8DAwELAwEDAwELAwEPBAEHASEjDREBEgMHAyc/CwEBAQEBAQcBBwEDA1UFAwMHBlUDAwMDBwZVAwMDBRUGVQMTCQkLDQ8XBpoDAwEDEQMD3wsDAREH35sDCwUTFQMDZwUDAwcGZwMDAwMHBmcDAwMFFQZnAxMJBxkbHRcG1gMDAQMfAwMBCwMBAwMBCwMBDwQBBwEhIw0RARYDBwMPDwsBAQEBAQEHAQcBAwMBCwMBAwMBCwMBDwQBBQMLDREBGgMHAxETCwEBAQEBAQcBBwEDAwELAwEDAwELA
wEDAwELAwEPBAEFCw0NEQEeAwcDERMLAQEBAQEBBwEHAQMDAQsDAQMDAQsDAQMDAQsDAQ8EAQULDQ0RASIDBwMREwsBAQEBAQEHAQcBAwMBCwMBAwMBCwMBAwMBCwMBDwQBBQsNDREBJgMHAw8PCwEBAQEBAQcBBwEDAwELAwEDAwELAwEPBAEHAQMLBgMBBQEA6h/hGQsZFQ0qAmUJDR1JDRMLX0ETI2U/JTM1Xx0jISMpMS0LCx8LHR0lGxEpDQkZGRkZGRkZGQsVDQkdCxEVdx1LMwsvHSUlHQ0TLQ1JC0syAhcfGxMbFxcTFy8XFxcXDxkXFRkZJRcZFSMjIxkfDw8NCR0RYnVpbHRpbgBzdGFibGVfbW9zYWljAHRwdQBhcml0aAB2ZWN0b3IAbW9kdWxlAGFyaXRoLmNvbnN0YW50AHZlY3Rvci5sb2FkAGFyaXRoLmluZGV4X2Nhc3QAdmVjdG9yLmJyb2FkY2FzdAB0cHUudmVjdG9yX3N0b3JlAGZ1bmMuZnVuYwBmdW5jLnJldHVybgBhcml0aC5jbXBpAHZlY3Rvci5zaGFwZV9jYXN0AG1lbXJlZi5sb2FkAGFyaXRoLmV4dHNpAHNjZi55aWVsZAB0cHUucmVwZWF0AGFyaXRoLmV4dHVpAHNjZi5pZgBhcml0aC5tdWxmAGFyaXRoLm11bGkAdHB1Lm1hdG11bABhcml0aC5hZGRpAHZlY3Rvci5tdWx0aV9yZWR1Y3Rpb24AYXJpdGguc3ViZgBtYXRoLmV4cABhcml0aC5hZGRmAGFyaXRoLmRpdmYAYXJpdGgudHJ1bmNmAHRwdS5pb3RhAGFyaXRoLnNlbGVjdABhcml0aC5tYXhpbXVtZgBhcml0aC5leHRmAC9ob21lL3B0b3VsbWUvbWluaWNvbmRhMy9lbnZzL3ZsbG0vbGliL3B5dGhvbjMuMTIvc2l0ZS1wYWNrYWdlcy9qYXgvZXhwZXJpbWVudGFsL3BhbGxhcy9vcHMvdHB1L3NwbGFzaF9hdHRlbnRpb24vc3BsYXNoX2F0dGVudGlvbl9rZXJuZWwucHkAZmxhc2hfYXR0ZW50aW9uX2tlcm5lbC48bG9jYWxzPi5ib2R5AC9nZXQAZmxhc2hfYXR0ZW50aW9uX2tlcm5lbC48bG9jYWxzPi5lbmQAdmFsdWUAL2NvbnZlcnRfZWxlbWVudF90eXBlAHN5bV9uYW1lAC9zd2FwAGZ1bmN0aW9uX3R5cGUAL2Jyb2FkY2FzdF9pbl9kaW0AdHJhbnNmb3JtX2luZGljZXMAd2luZG93X2JvdW5kcwBmbGFzaF9hdHRlbnRpb25fa2VybmVsAC9tdWwAX2FwcGx5X21hc2tfYW5kX3NvZnRfY2FwAGZsYXNoX2F0dGVudGlvbl9rZXJuZWwuPGxvY2Fscz4uaW5pdABfbmV4dF9ub256ZXJvAC9ob21lL3B0b3VsbWUvanVzdGFieXRlL3RwdV9wYWxsYXNfcG9zdC8uL3BhbGxhc19rZXJuZWwucHkAcHJlZGljYXRlAC9yZXBlYXQAL2FkZABwaXBlbGluZV9tb2RlAC9ndAAvY29uZABkaW1lbnNpb24AbWFpbgB0cmFuc2Zvcm1fMAB0cmFuc2Zvcm1fMQB0cmFuc2Zvcm1fMgB0cmFuc2Zvcm1fMwB0cmFuc2Zvcm1fNAB0cmFuc2Zvcm1fNQB0cmFuc2Zvcm1fNgB0cmFuc2Zvcm1fNwAvZXEAdGltZXMAb3BlcmFuZFNlZ21lbnRTaXplcwBzdHJpZGVzAC9kb3RfZ2VuZXJhbABkaW1lbnNpb25fbnVtYmVycwB0cmFuc3Bvc2VfbGhzAHRyYW5zcG9zZV9yaHMAa2luZAByZWR1Y3Rpb25fZGltcwAvc3ViAC9leHAAc3RhYmxlX21vc2FpYy52ZXJzaW9uAHNwbGFzaF9taGFfZndkX2M2MDJiMTZkAGRpbWVuc2lvb
l9zZW1hbnRpY3MAaXRlcmF0aW9uX2JvdW5kcwBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwB3aW5kb3dfcGFyYW1zAF9zcGxhc2hfYXR0ZW50aW9uX2ZvcndhcmQuPGxvY2Fscz4udl9pbmRleF9tYXAAX3NwbGFzaF9hdHRlbnRpb25fZm9yd2FyZABfc3BsYXNoX2F0dGVudGlvbl9jdXN0b20AX3NwbGFzaF9hdHRlbnRpb24AU3BsYXNoQXR0ZW50aW9uS2VybmVsLl9fY2FsbF9fAGJlbmNobWFya19rZXJuZWwuPGxvY2Fscz4uc3BsYXNoX2F0dGVudGlvbl9rZXJuZWwAYmVuY2htYXJrX2tlcm5lbAA8bW9kdWxlPgBfbmV4dF9ub256ZXJvLjxsb2NhbHM+LjxsYW1iZGE+AF9zcGxhc2hfYXR0ZW50aW9uX2ZvcndhcmQuPGxvY2Fscz4ua19pbmRleF9tYXAAL2RpdgBmYXN0bWF0aAAvc2NhbgBmbGFzaF9hdHRlbnRpb25fa2VybmVsLjxsb2NhbHM+LnJ1bgBvdmVyZmxvd0ZsYWdzAC9pb3RhAC9nZQBDYXVzYWxNYXNrLl9faW5pdF9fLjxsb2NhbHM+LmNhdXNhbF9tYXNrX2Z1bmN0aW9uAC9ob21lL3B0b3VsbWUvbWluaWNvbmRhMy9lbnZzL3ZsbG0vbGliL3B5dGhvbjMuMTIvc2l0ZS1wYWNrYWdlcy9qYXgvZXhwZXJpbWVudGFsL3BhbGxhcy9vcHMvdHB1L3NwbGFzaF9hdHRlbnRpb24vc3BsYXNoX2F0dGVudGlvbl9tYXNrLnB5AC9waml0AC9zZWxlY3RfbgAvcmVkdWNlX21heAAvbWF4AC9yZWR1Y2Vfc3VtAA==", "serialization_format": 1, "needs_layout_passes": true}}
ROOT %get-tuple-element.0 = bf16[8,2048,128]{2,1,0} get-tuple-element(%custom-call.0), index=3, metadata={op_name="jit(splash_attention_kernel)/jit(main)/splash_attention_kernel/splash_kernel_b1024_h8_s2048/jit(_splash_attention)/splash_mha_fwd_c602b16d/pallas_call" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=1100}
}
When XLA Isn't Enough: From Pallas to VLIW with Splash Attention on TPU
When does XLA hit its limits? How do you write the TPU Pallas kernel that the compiler cannot automatically find? Why can't XLA generate Splash Attention?
Full Code / IR Dumps: https://github.com/patrick-toulme/justabyte/tree/main/tpu_pallas_post
Connect On Linkedin: https://www.linkedin.com/in/patrick-toulme-150b041a5/
Follow on X: https://x.com/PatrickToulme
Motivation
Last post, the TPU compiler did everything: 8 lines of JAX became 250 VLIW bundles across 5 fused kernels, with automatic memory orchestration, dual-MXU scheduling, and async DMA overlap. The thesis was simple—write high-level code, let the compiler figure out the rest.
This post is about where that breaks down.
Same hardware, same compiler: naive attention compiles to 10,245 VLIW bundles and moves 526MB through HBM. Splash Attention (block_size=512) compiles to 1,651 bundles and moves 14MB. XLA optimized both aggressively. The 6× gap in bundle count (and nearly 38× gap in HBM traffic) isn’t a compiler failure; it’s a compiler limitation. XLA can fuse ops and schedule instructions, but it can’t rewrite your algorithm. It can’t infer that a numerically stable streaming (online-softmax) formulation is hiding inside your einsum, or that you therefore never need to materialize the full attention matrix.
In principle, a compiler could learn this. But no one has taught one yet. FlashAttention exists because a human figured out the trick and wrote a kernel. Until compilers get smarter, we need an escape hatch.
Pallas is that escape hatch. This post traces Splash Attention—JAX’s FlashAttention for TPUs—through the TPU Pallas Compiler, and compares it to naive attention at every layer of IR.
Setup:
All experiments were performed on a TPU v6e (Trillium) from Google Cloud. We use the same dump flags as in the previous post, plus the xla_mosaic flags, which dump the Mosaic compiler IR for Pallas.
import os

# HLO_PATH, LLO_PATH, and MOSAIC_PATH are dump directories defined elsewhere in the script.
# Set these variables before JAX initializes its backend.
os.environ["XLA_FLAGS"] = (
    f"--xla_dump_hlo_as_text "
    f"--xla_dump_to={HLO_PATH} "
    f"--xla_dump_hlo_pass_re=.* "
)
os.environ["LIBTPU_INIT_ARGS"] = (
    f"--xla_jf_dump_to={LLO_PATH} "
    f"--xla_jf_dump_hlo_text=true "
    f"--xla_jf_dump_llo_text=true "
    f"--xla_jf_dump_llo_html=false "
    f"--xla_jf_dump_llo_static_gaps=true "
    f"--xla_jf_emit_annotations=true "
    f"--xla_jf_debug_level=2 "
    f"--xla_mosaic_dump_to={MOSAIC_PATH} "
    f"--xla_mosaic_enable_dump_debug_info=true "
    f"--xla_mosaic_enable_llo_source_annotations=true"
)
What this post covers:
The Compiler’s Limitation: Why XLA can’t fuse across the attention matrix
Splash Attention Kernel: How Pallas expresses online softmax with grid, BlockSpec, and scratch
Pallas Compiler Pipeline: The path from Pallas code → Mosaic → LLO → VLIW
HLO/Mosaic/LLO Walkthrough: What each IR layer reveals
Reference Attention Comparison: Same algorithm through XLA’s standard path
Analysis: Bundle counts, HBM traffic, block size tradeoffs
The Compiler’s Limitation
Why can’t XLA automatically achieve Pallas-level performance in every case? Look at what XLA does fuse for the reference JAX attention:
fusion.5: Q @ K^T + mask + reduce_max → (max, scores)
fusion.2: exp(scores - max) + reduce_sum → sum
fusion: normalize + S @ V → output
XLA fuses the matmul with the mask and max reduction. It fuses the exp with the sum reduction. It’s doing real work here. But the 128MB attention matrix still gets written to HBM between fusion.5 and fusion.2.
The problem is that standard softmax is inherently multi-pass: you need the max over all values before you can compute any exp. XLA can fuse along dataflow edges, but it can’t restructure the algorithm. The attention matrix has to exist somewhere because multiple operations need to read it.
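The 128MB figure can be reproduced from the shapes used in this post. A back-of-the-envelope sketch, assuming the scores are materialized as f32 (my assumption; the element type of the intermediate is not quoted above):

```python
# Size of the full attention-score tensor for the shapes in this post.
num_heads = 8
seq_len = 2048
bytes_per_score = 4  # assumption: scores are stored as f32

score_elements = num_heads * seq_len * seq_len
score_bytes = score_elements * bytes_per_score
score_mib = score_bytes / (1024 * 1024)

print(f"{score_elements:,} scores, {score_mib:.0f} MiB")
```

Storing the scores in bf16 would halve the size, but it would not remove the round trip through HBM between the two fusions.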
Online softmax sidesteps this by maintaining running statistics — you compute a tile’s contribution to max/sum/output, then update your estimates as you process the next tile. This is an algorithmic transformation, not a fusion pattern. XLA would need to:
Recognize that the softmax + matmul pattern permits streaming
Prove the online reformulation is numerically equivalent
Restructure the computation to process tiles incrementally
Concretely: it would have to replace scores = QKᵀ; p = softmax(scores); o = pV with a loop that carries (m, l, o) and updates them per KV tile.
No production compiler does this today (to my knowledge). Pallas lets you write the streaming algorithm directly.
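The loop that carries (m, l, o) fits in a few lines. Below is a minimal pure-Python sketch (lists instead of arrays, no masking, no scaling) that compares the two-pass formulation with the streaming one. The function names and the tiny random inputs are illustrative and are not taken from Splash Attention:

```python
import math
import random

def reference_attention(q, k, v):
    """Two-pass softmax attention: materializes a full row of scores."""
    out = []
    for q_row in q:
        scores = [sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        m = max(scores)                          # needs ALL scores first
        p = [math.exp(s - m) for s in scores]
        l = sum(p)
        out.append([sum(p_j * v_row[d] for p_j, v_row in zip(p, v)) / l
                    for d in range(len(v[0]))])
    return out

def streaming_attention(q, k, v, block_kv):
    """Online softmax: visits K/V one tile at a time, carrying (m, l, o)."""
    head_dim = len(v[0])
    out = []
    for q_row in q:
        m = -math.inf            # running max
        l = 0.0                  # running sum of exp(score - m)
        o = [0.0] * head_dim     # running unnormalized output
        for start in range(0, len(k), block_kv):
            k_tile = k[start:start + block_kv]
            v_tile = v[start:start + block_kv]
            scores = [sum(a * b for a, b in zip(q_row, k_row)) for k_row in k_tile]
            m_new = max(m, max(scores))
            correction = math.exp(m - m_new)     # rescales the old statistics
            p = [math.exp(s - m_new) for s in scores]
            l = correction * l + sum(p)
            o = [correction * o_d + sum(p_j * v_row[d] for p_j, v_row in zip(p, v_tile))
                 for d, o_d in enumerate(o)]
            m = m_new
        out.append([o_d / l for o_d in o])       # normalize once, at the end
    return out

rng = random.Random(0)
q = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(3)]
k = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(8)]
v = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(8)]

ref = reference_attention(q, k, v)
stream = streaming_attention(q, k, v, block_kv=2)
max_err = max(abs(a - b) for r, s in zip(ref, stream) for a, b in zip(r, s))
print(f"max abs difference: {max_err:.2e}")
```

The two functions agree up to floating-point rounding, but only the first one ever holds a full row of scores. That structural change is what a fusion pass has no way to derive.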
Splash Attention Kernel
Splash Attention is JAX’s FlashAttention for TPUs, built on Pallas. The algorithm itself is well-documented elsewhere — here we focus on how Pallas expresses it.
// Standard softmax (requires full matrix):
max = reduce_max(scores) // Need ALL scores first
exp_scores = exp(scores - max) // Then process all
sum = reduce_sum(exp_scores) // Need ALL exp values
output = exp_scores / sum
// Online softmax (streaming):
for each KV_block:
    m_new = max(m_prev, local_max) // Update running max
    correction = exp(m_prev - m_new) // Rescale factor
    l_new = correction * l_prev + local_sum // Update running sum
    o_new = correction * o_prev + local_out // Update running output
Pallas Kernel Structure
A Pallas kernel is a function that operates on Refs — mutable views into memory regions. The runtime calls your kernel once per grid point, with Refs pointing to the appropriate tiles:
def flash_attention_kernel(
    # Prefetched scalar inputs (live in SMEM)
    data_next_ref,   # Which KV block to load
    block_mask_ref,  # Skip mask
    mask_next_ref,   # Partial mask index
    # Tiled inputs (sliced per grid point)
    q_ref,  # [block_q, head_dim] slice of Q
    k_ref,  # [block_kv, head_dim] slice of K
    v_ref,  # [block_kv, head_dim] slice of V
    ...
    # Scratch space (persists across grid iterations)
    m_scratch_ref,  # [block_q, 128] running max
    l_scratch_ref,  # [block_q, 128] running sum
    o_scratch_ref,  # [block_q, head_dim] running output
    # Output tile
    o_ref,
    *,
    # Static config (compiled into kernel)
    bq: int,
    bkv: int,
    ...
):
Inside the kernel, you read/write Refs with ref[...] syntax. Pallas traces these operations and compiles them to Mosaic IR.
BlockSpec: Mapping Grid to Memory
BlockSpec defines how each grid point maps to a tile of the input/output arrays:
pl.BlockSpec(
    block_shape=(None, bq, head_dim),  # None = squeezed size-1 dim (one head per grid step), bq = tiled
    index_map=lambda h, i, j, *_: (h, i, 0)  # grid coords → block indices
)
For Q, the index map is simple — grid point (h, i, j) reads Q block (h, i, :). But K and V need indirection for sparse attention:
def k_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref):
    # Look up which KV block to actually load (enables block skipping)
    next_j, *_ = _next_nonzero(h, i, j, data_next_ref, block_mask_ref, mask_next_ref)
    return (h // q_heads_per_kv_head, next_j, 0)
The data_next_ref indirection is how Splash skips masked blocks — if block j is fully masked, data_next[h,i,j] points to the next valid block instead.
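A toy version of that lookup table makes the indirection concrete. This is a sketch under stated assumptions, not the library's mask-info construction or `_next_nonzero`: I assume block-mask values 0 = fully masked, 1 = partially masked, 2 = fully visible, a single head, and a fallback of block 0 when no later block has work. `causal_block_mask` and `build_data_next` are hypothetical helper names.

```python
def causal_block_mask(seq_len, block):
    """Classify each (q_block, kv_block) tile of a causal mask."""
    n = seq_len // block
    mask = []
    for i in range(n):
        row = []
        for j in range(n):
            q_lo, q_hi = i * block, (i + 1) * block - 1
            kv_lo, kv_hi = j * block, (j + 1) * block - 1
            if kv_lo > q_hi:
                row.append(0)   # every key is in the future: skip the tile
            elif kv_hi <= q_lo:
                row.append(2)   # every key is visible: no element mask needed
            else:
                row.append(1)   # tile straddles the diagonal: partial mask
        mask.append(row)
    return mask

def build_data_next(block_mask):
    """For each tile in row-major order, the KV block index of the next
    tile (itself included) that is not fully masked."""
    num_q, num_kv = len(block_mask), len(block_mask[0])
    data_next = [[0] * num_kv for _ in range(num_q)]
    next_kv = 0  # assumed fallback when no later tile has work
    for i in reversed(range(num_q)):
        for j in reversed(range(num_kv)):
            if block_mask[i][j] != 0:
                next_kv = j
            data_next[i][j] = next_kv
    return data_next

block_mask = causal_block_mask(seq_len=2048, block=1024)
data_next = build_data_next(block_mask)
print(block_mask)  # [[1, 0], [2, 1]]
print(data_next)   # [[0, 0], [0, 1]]
```

For seq_len=2048 and block=1024 the two tables have the same values as the two s8[1,2,2] constants in the HLO for this kernel. That is at least consistent with this reading, though the sketch is not guaranteed to match the library for other masks.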
Grid and Dimension Semantics
grid = (num_q_heads, q_seq_len // block_q, grid_width)
compiler_params = pltpu.CompilerParams(
    dimension_semantics=("parallel", "arbitrary", "arbitrary")
)
Dimension 0 (heads): "parallel". No cross-iteration dependencies, so iterations can execute in any order.
Dimension 1 (Q blocks): "arbitrary". Independent, because scratch is per-(head, q_block).
Dimension 2 (KV blocks): "arbitrary". Dependent, because scratch accumulates across KV tiles.
The "parallel" hint tells the compiler that iterations along that dimension are independent, so it is free to reorder them or split them across cores where the hardware has more than one. "arbitrary" means “don’t assume anything”: safe but conservative.
PrefetchScalarGridSpec
Splash uses PrefetchScalarGridSpec to overlap data loading with compute:
pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=3,  # First 3 inputs go to SMEM, prefetched ahead
    in_specs=[...],
    out_specs=[...],
    grid=grid,
)
The first 3 arguments (data_next, block_mask, mask_next) are small index arrays. Putting them in SMEM with prefetch means the kernel can look up the next block address while the current block is still computing.
Memory Spaces
Pallas exposes the TPU memory hierarchy:
# VMEM (on-chip SRAM) — default for BlockSpec refs
pl.BlockSpec((bq, head_dim), index_map)
# SMEM (scalar memory) — for small indexing data
pl.BlockSpec((num_heads,), lambda *_: (0,), memory_space=pltpu.SMEM)
# Scratch (persists across grid iterations)
jax.ShapeDtypeStruct((bq, NUM_LANES), jnp.float32) # m_scratch, l_scratch
The scratch refs (m_scratch_ref, l_scratch_ref, o_scratch_ref) are the key to online softmax: they accumulate across the KV dimension without spilling to HBM.
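To see why persistent scratch matters, it helps to model how the grid drives the kernel. The following is a plain-Python model, not Pallas code: `kernel`, the `scratch` dict, and the scalar "tiles" are hypothetical stand-ins for the real Refs. The kernel is called once per (q_block, kv_block) grid point, resets the accumulators on the first KV tile, and writes the output on the last one.

```python
import math

def kernel(i, j, num_kv_blocks, score_tile, value_tile, scratch, out):
    # First KV tile for this Q block: reset the running statistics.
    if j == 0:
        scratch["m"], scratch["l"], scratch["o"] = -math.inf, 0.0, 0.0
    # Every KV tile: fold this tile into the running max / sum / output.
    m_new = max(scratch["m"], max(score_tile))
    alpha = math.exp(scratch["m"] - m_new)
    p = [math.exp(s - m_new) for s in score_tile]
    scratch["l"] = alpha * scratch["l"] + sum(p)
    scratch["o"] = alpha * scratch["o"] + sum(pj * vj for pj, vj in zip(p, value_tile))
    scratch["m"] = m_new
    # Last KV tile: normalize and write the result for this Q block.
    if j == num_kv_blocks - 1:
        out[i] = scratch["o"] / scratch["l"]

scores = [[0.5, -1.0, 2.0, 0.0], [3.0, 1.0, -2.0, 0.25]]  # one score row per Q block
values = [1.0, 2.0, 3.0, 4.0]
block_kv = 2
num_kv_blocks = len(values) // block_kv

scratch, out = {}, [None] * len(scores)
for i in range(len(scores)):          # Q-block grid dimension
    for j in range(num_kv_blocks):    # KV-block grid dimension, innermost
        lo = j * block_kv
        kernel(i, j, num_kv_blocks, scores[i][lo:lo + block_kv],
               values[lo:lo + block_kv], scratch, out)

def direct(row):
    """Ordinary softmax-weighted average, for comparison."""
    m = max(row)
    p = [math.exp(s - m) for s in row]
    return sum(pj * vj for pj, vj in zip(p, values)) / sum(p)

expected = [direct(row) for row in scores]
print(out, expected)
```

The result is only correct because the KV loop runs in order and `scratch` survives between calls. That is the dependency the "arbitrary" annotation on the KV dimension has to respect.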
Pallas Compiler Pipeline
When building a DSL compiler, you want to reuse as much of the existing compiler infrastructure as possible rather than writing everything from scratch. Pallas shows up in HLO as an opaque tpu_custom_call. XLA still does its usual whole-graph work (layout, scheduling, surrounding fusions), but it can’t see or rewrite the kernel body. The payload is serialized MLIR that goes straight to Mosaic → LLO → VLIW. The key insight is that both XLA's fusion region codegen and Pallas share the same LLO backend. Whether the TPU is running a compiler-generated fusion or a hand-written Pallas kernel, the final stages—MXU scheduling, DMA overlap, VLIW packing—are identical. Pallas just enters the pipeline at a different point.
What XLA Sees
The whole Pallas kernel becomes a single opaque HLO op:
%custom-call = bf16[8,2048,128] custom-call(
bf16[8,2048,128] %q, bf16[8,2048,128] %k, bf16[8,2048,128] %v, ...
), custom_call_target="tpu_custom_call", backend_config="..."
XLA’s fusion passes can’t look inside. The kernel goes directly to Mosaic, which compiles it to LLO and then to VLIW bundles. The backend_config encodes the full Pallas program as serialized MLIR.
HLO
Here’s what the Pallas kernel looks like in HLO (the full module dump is reproduced at the top of this page):
The key insight: XLA sees the entire Splash Attention kernel as one opaque custom-call. The backend_config contains the entire kernel body as base64-encoded MLIR (that "TUzvUg..." blob). XLA can’t optimize inside it — it just hands the blob to Mosaic for TPU-specific compilation.
Notice the scratch buffers in the return tuple: f32[1024,128] appears three times. These are the online softmax accumulators (m_scratch, l_scratch, o_scratch). The fourth element — bf16[8,2048,128] — is the actual output.
The s8[1,2,2] block masks encode which blocks to process. With seq_len=2048 and block_size=1024, we have 2048/1024 = 2 blocks per dimension, hence the 2×2 shape.
Mosaic Compilation Pipeline
When the HLO custom-call with custom_call_target="tpu_custom_call" reaches the TPU backend, the base64-encoded kernel body gets deserialized and fed to Mosaic. The compilation proceeds through 12 passes:
Stage 1: Original MLIR
The kernel arrives as high-level MLIR with explicit memory spaces and TPU-specific operations:
func.func @main(
%arg0: i32, // grid index: head
%arg1: i32, // grid index: q_block
%arg2: i32, // grid index: kv_block
%arg3: memref<1x2x2xi8, #tpu.memory_space<smem>>, // data_next
%arg4: memref<1x2x2xi8, #tpu.memory_space<smem>>, // block_mask
%arg5: memref<1x1024x128xbf16, #tpu.memory_space<vmem>>, // Q block
%arg6: memref<1x1024x128xbf16, #tpu.memory_space<vmem>>, // K block
%arg7: memref<1x1024x128xbf16, #tpu.memory_space<vmem>>, // V block
%arg8: memref<1024x128xi32, #tpu.memory_space<vmem>>, // q_sequence (for causal)
%arg9: memref<1024x128xf32, #tpu.memory_space<vmem>>, // m_scratch
%arg10: memref<1024x128xf32, #tpu.memory_space<vmem>>, // l_scratch
%arg11: memref<1024x128xf32, #tpu.memory_space<vmem>>, // o_scratch
%arg12: memref<1x1024x128xbf16, #tpu.memory_space<vmem>> // output
) attributes {
dimension_semantics = [#tpu.dimension_semantics<parallel>, // heads
#tpu.dimension_semantics<arbitrary>, // q_blocks
#tpu.dimension_semantics<arbitrary>], // kv_blocks
iteration_bounds = array<i64: 8, 2, 2>, // 8 heads, 2 q_blocks, 2 kv_blocks
scalar_prefetch = 2 : i64, // prefetch first 2 args (mask info)
window_params = [...]
}
The function attributes encode Pallas’s grid specification. iteration_bounds = [8, 2, 2] means the kernel runs 8 × 2 × 2 = 32 iterations. The first dimension is parallel (heads can run independently), while the Q-block and KV-block dimensions are arbitrary: the compiler assumes nothing about their order, and the KV dimension carries real data dependencies through the scratch buffers.
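The iteration bounds follow directly from the shapes. A small sketch, assuming a dense grid in which the third grid dimension equals the number of KV blocks (this matches the dumps in this post; `iteration_bounds` is a hypothetical helper, not a Pallas API):

```python
def iteration_bounds(num_heads, seq_len, block_q, block_kv):
    """Grid shape (heads, q_blocks, kv_blocks) for a dense block grid."""
    return (num_heads, seq_len // block_q, seq_len // block_kv)

for block in (1024, 512):
    bounds = iteration_bounds(num_heads=8, seq_len=2048, block_q=block, block_kv=block)
    total = bounds[0] * bounds[1] * bounds[2]
    print(f"block_size={block}: bounds={bounds}, kernel invocations={total}")
```

With block_size=1024 this gives the (8, 2, 2) grid shown in the attributes above. Halving the block size to 512 quadruples the number of invocations, to 128.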
The kernel body shows the attention computation:
// Load Q block, K block
%q = vector.load %arg5[%c0, %c0, %c0] : vector<1x1024x128xbf16>
%k = vector.load %arg6[%c0, %kv_idx, %c0] : vector<1x1024x128xbf16>
// QK^T matmul -> [1024, 1024] attention scores
%qk = tpu.matmul %21, %24, %cst {
dimension_numbers = #tpu.dot_dimension_numbers<[1], [1], [0], [0], [0, 0, 1, 0], [], []>
} : vector<1024x128xbf16>, vector<1024x128xbf16>, vector<1024x1024xf32>
// Apply causal mask using iota comparison
%29 = tpu.iota {dimension = 1 : i32} : vector<1024x1024xi32>
%34 = arith.cmpi sge, %33, %31 : vector<1024x1024xi32>
%36 = arith.select %34, %25, %35 : vector<1024x1024xi1>, vector<1024x1024xf32>
// Online softmax: max reduction
%37 = vector.multi_reduction <maximumf>, %36, %cst_20 [1] : vector<1024x1024xf32> to vector<1024xf32>
%40 = arith.maximumf %18, %39 : vector<1024x128xf32> // update running max
// exp(scores - max) and sum reduction
%42 = arith.subf %36, %41 : vector<1024x1024xf32>
%43 = math.exp %42 : vector<1024x1024xf32>
%44 = vector.multi_reduction <add>, %43, %cst_21 [1] : vector<1024x1024xf32> to vector<1024xf32>
// Rescale previous accumulator: alpha = exp(m_prev - m_next)
%47 = arith.subf %18, %40 : vector<1024x128xf32>
%48 = math.exp %47 : vector<1024x128xf32>
%49 = arith.mulf %48, %19 : vector<1024x128xf32> // alpha * l_prev
%50 = arith.addf %46, %49 : vector<1024x128xf32> // l_next = l_curr + alpha * l_prev
// S @ V matmul for current output contribution
%57 = tpu.matmul %43, %56, %cst_28 {
dimension_numbers = #tpu.dot_dimension_numbers<[1], [0], [0], [1], [0, 0, 1, 1], [], []>
} : vector<1024x1024xf32>, vector<1024x128xf32>, vector<1024x128xf32>
// Update output accumulator: o_next = alpha * o_prev + o_curr
%60 = arith.mulf %58, %59 : vector<1024x128xf32> // alpha * o_prev
%61 = arith.addf %60, %57 : vector<1024x128xf32> // o_next
// Store updated accumulators
tpu.vector_store %arg9[...], %40 // m_scratch
tpu.vector_store %arg10[...], %50 // l_scratch
tpu.vector_store %arg11[...], %61 // o_scratch
Stage 2-3: Layout and Tiling
The infer-vector-layout and tiling passes map abstract vectors onto the VPU’s physical structure (8 sublanes × 128 lanes):
// Before: abstract vector and memref
%qk = tpu.matmul %q, %k, %zeros : vector<1024x1024xf32>
%arg5: memref<1x1024x128xbf16, #tpu.memory_space<vmem>>
// After: explicit layouts matching hardware
%qk = tpu.matmul %q, %k, %zeros {
in_layout = [#tpu.vpad<"16,{0,0},(16,128)">, // bf16: 16 sublanes
#tpu.vpad<"16,{0,0},(16,128)">],
out_layout = [#tpu.vpad<"32,{0,0},(8,128)">] // f32: 8 sublanes
} : ...
%arg5: memref<1x1024x128xbf16,
#tpu.tiled<(8,128)(2,1)>, // 8×128 tiles, 2 bf16 per 32-bit word
#tpu.memory_space<vmem>>
The (8,128) tile shape matches the VPU dimensions. When layouts don’t match between operations, the compiler inserts explicit relayouts.
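The sublane counts in the layout attributes follow from element width. A small sketch, assuming a native tile is 8 sublanes × 128 lanes of 32-bit words and that narrower types pack along the sublane axis (this is how I read the (16,128) and (8,128) annotations; the helper names are hypothetical):

```python
import math

SUBLANES, LANES, WORD_BITS = 8, 128, 32

def native_tile_shape(dtype_bits):
    """Rows x columns of elements covered by one 8x128 tile of 32-bit words."""
    packing = WORD_BITS // dtype_bits   # e.g. 2 bf16 values per 32-bit word
    return (SUBLANES * packing, LANES)

def tiles_needed(rows, cols, dtype_bits):
    """Number of native tiles required to cover a rows x cols array."""
    tile_rows, tile_cols = native_tile_shape(dtype_bits)
    return math.ceil(rows / tile_rows) * math.ceil(cols / tile_cols)

q_block_tiles = tiles_needed(1024, 128, dtype_bits=16)    # bf16 Q block
score_tiles = tiles_needed(1024, 1024, dtype_bits=32)     # f32 score tile
print(native_tile_shape(16), q_block_tiles)
print(native_tile_shape(32), score_tiles)
```

Under these assumptions the f32 score tile is the largest object in the kernel by a wide margin, which is one reason block size is such a sensitive tuning knob.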
Stage 4: Lower to LLO
The lower-to-llo pass transforms everything to Low-Level Operations — the final MLIR dialect before machine code:
// High-level vector store
tpu.vector_store %arg11[%c0, %c0], %result : memref<1024x128xf32>, vector<1024x128xf32>
// Becomes LLO with explicit addressing
%addr = "llo.saddr_scaled"(%arg11, %offset) <{multiplier_in_bytes = 512}>
llo.vector_store %result into %addr : vector<8x128xf32> into i32
llo.vector_store %result into %addr + %8 : vector<8x128xf32> into i32
// ... (128 stores for 1024 rows at 8 rows per tile)
The 1024×128 vector gets broken into 128 tiles of 8×128, each stored with computed VMEM addresses.
Pallas LLO Compilation: From Mosaic to Machine Code
Next, the Mosaic IR is converted to native TPU LLO IR. We will now trace the computation through the LLO compiler.
For this LLO walkthrough, we use block_size=512 (grid 8×4×4) to show more loop iterations and better demonstrate software pipelining. The HLO and Mosaic sections above use block_size=1024.
Kernel Structure (Pass 02: Original)
The initial LLO reveals the kernel's memory layout and loop structure:
$region0: #{splash_mha_fwd_e9701070.1}
// Memory allocations
#allocation1 [shape = 'u32[144,128]', space=vmem, size = 0x12000, tag = 'internal scratch']
#allocation3 [shape = 'u8[512]', space=smem, size = 0x200, tag = 'prefetched SMEM operand 0']
#allocation4 [shape = 'u8[512]', space=smem, size = 0x200, tag = 'prefetched SMEM operand 1']
// Kernel inputs/outputs
%s0 = inlined_call_operand.hbm [shape: s8[1,4,4], index: 0] // data_next
%s1 = inlined_call_operand.hbm [shape: s8[1,4,4], index: 1] // block_mask
%s2 = inlined_call_operand.hbm [shape: bf16[8,2048,128], index: 2] // Q
%s3 = inlined_call_operand.hbm [shape: bf16[8,2048,128], index: 3] // K
%s4 = inlined_call_operand.hbm [shape: bf16[8,2048,128], index: 4] // V
%s5 = inlined_call_operand.vmem [shape: s32[2048,128], index: 5] // q_sequence (causal)
// Scratch buffers for online softmax
%s6 = inlined_call_operand.hbm [shape: f32[512,128], index: 6] // m_scratch output
%s7 = inlined_call_operand.hbm [shape: f32[512,128], index: 7] // l_scratch output
%s8 = inlined_call_operand.hbm [shape: f32[512,128], index: 8] // o_scratch output
%s9 = inlined_call_operand.hbm [shape: bf16[8,2048,128], index: 9] // final output
// Prefetch mask data to SMEM before main loop
%16 = dma.hbm_to_smem /*hbm=*/%s0, /*size=*/16, /*smem=*/[#allocation3]
%18 = dma.hbm_to_smem /*hbm=*/%s1, /*size=*/16, /*smem=*/[#allocation4]
%19 = dma.done [#allocation2], 32
The kernel operand layout reflects Pallas’s PrefetchScalarGridSpec:
Operands 0-1: Block mask indices prefetched to SMEM
Operands 2-4: Q, K, V matrices streaming from HBM
Operand 5: Causal sequence indices (pre-loaded to VMEM)
Operands 6-8: Online softmax scratch buffers
Loop Structure with Phi Nodes
With block_size=512 and seq_len=2048, the grid is (8 heads, 4 q_blocks, 4 kv_blocks). Software pipelining adds 2 iterations for warmup/cooldown:
loop: start=0, step=1, limit=130 // 8×4×4 = 128 + 2 for pipeline
$region107: #{splash_mha_fwd_e9701070.1} parent=1 // loop_body
// Iteration tracking across 3 dimensions
%s6473 = sphi 0, %s37 /* iteration index, stage = 0 */
%s6475 = sphi 0, %s59 /* iter bound = 0 (head index) */
%s6477 = sphi 0, %s55 /* iter bound = 1 (q_block index) */
%s6479 = sphi 0, %s51 /* iter bound = 2 (kv_block index) */
// Pipelined stage tracking (stage 1 lags stage 0)
%s6481 = sphi 0, %s6475 /* stage = 1 iter bound = 0 */
%s6483 = sphi 0, %s6477 /* stage = 1 iter bound = 1 */
%s6485 = sphi 0, %s6479 /* stage = 1 iter bound = 2 */
// Online softmax state (carries across KV blocks)
%s6487 = sphi 0, %s66 /* running max state */
%s6493 = sphi 0, %s128 /* running sum state */
%s6499 = sphi 0, %s190 /* running output state */
// Dimension wrap-around logic
%s49 = sadd.s32 1, %s6479 /* kv_block + 1 */
%p50 = scmp.ge.s32.totalorder %s49, 4 /* kv_block >= 4? */
%s51 = scalar_select %p50, 0, %s49 /* wrap to 0 */
%s52 = sadd.s32 1, %s6477 /* q_block + 1 (if kv wrapped) */
%s53 = scalar_select %p50, %s52, %s6477 /* conditional increment */
%p54 = scmp.ge.s32.totalorder %s53, 4 /* q_block >= 4? */
%s55 = scalar_select %p54, 0, %s53 /* wrap to 0 */
%s56 = sadd.s32 1, %s6475 /* head + 1 (if q wrapped) */
%p58 = scmp.ge.s32.totalorder %s57, 8 /* head >= 8? */
%s59 = scalar_select %p58, 0, %s57 /* wrap to 0 */
The phi nodes track iteration state across the flattened 3D grid. The stage = 1 variables implement software pipelining—while stage 0 loads the next tile, stage 1 processes the current tile.
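The wrap-around chain behaves like an odometer over (head, q_block, kv_block). Here is a minimal Python sketch of that counter, assuming the (8, 4, 4) bounds quoted above. The function name `advance` is hypothetical, and the stage-1 lag is modeled simply as "the indices stage 0 held one iteration earlier":

```python
from itertools import product

BOUNDS = (8, 4, 4)  # (heads, q_blocks, kv_blocks), taken from the grid above


def advance(head, q_block, kv_block, bounds=BOUNDS):
    """Mirror the scalar_select chain: bump kv_block, carry into q_block, then head."""
    n_head, n_q, n_kv = bounds
    kv_next = kv_block + 1
    kv_wrapped = kv_next >= n_kv
    kv_next = 0 if kv_wrapped else kv_next
    q_next = q_block + 1 if kv_wrapped else q_block
    q_wrapped = q_next >= n_q
    q_next = 0 if q_wrapped else q_next
    head_next = head + 1 if q_wrapped else head
    head_next = 0 if head_next >= n_head else head_next
    return head_next, q_next, kv_next


# Walk the flattened loop and record the stage-0 indices at every step.
state = (0, 0, 0)
visited = []
for _ in range(8 * 4 * 4):
    visited.append(state)
    state = advance(*state)

# The counters enumerate the grid in row-major order, kv_block fastest.
assert visited == list(product(range(8), range(4), range(4)))

# Stage 1 starts at zero and then trails stage 0 by one iteration.
stage1 = [(0, 0, 0)] + visited[:-1]
```

The lag between `visited` and `stage1` is what lets a load for step i+1 be issued while step i is still being computed.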
DMA Operations with Double Buffering
The kernel uses double buffering for async DMA overlap:
// Compute buffer slot for double buffering
%s14824_s0 = sand.u32 1, %s8466_s28 /* slot = iteration % 2 */
%s344_s11 = scalar_lea.sflag [#allocation6], %s15337_s19 /* sync flag */
// Compute HBM source address
%s6544_s15 = sshll.u32 %s353_s16, 6 /* address offset */
%s355_s3 = scalar_lea.hbm %s15333_s2, %s6544_s15
// Compute VMEM destination
%s347_s24 = scalar_lea.vmem [#allocation5], %s6541_s30
%s358_s1 = int_to_ptr.vmem %s357_s1
// Issue async DMA: HBM → VMEM
%7188 = dma.hbm_to_vmem [thread:$0] (!%p8805_p8),
/*hbm=*/%s355_s3,
/*size_in_granules=*/4096, /* 128 KiB for 512×128 bf16 (512 × 128 × 2 bytes) */
/*vmem=*/%s358_s1,
/*dst_syncflagno=*/%s344_s11
/* metadata:
window_bounds: (1, 64, 1) // tile shape
iteration_bounds: (8, 4, 4) // grid dimensions
element_size: 2048 bytes */
// Wait for previous DMA before consuming buffer
%8381 = dma.done.wait (%p15353_p5), %s454_s2, 4096
The (!%p8805_p8) predicate enables block skipping—if the current tile is fully masked, the DMA is skipped entirely.
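To picture how the two buffer slots and the skip predicate interact, here is a toy Python model of the schedule. It is a sketch under simplifying assumptions, not the kernel's logic: it only records the order of issue, wait, and compute events, the name `run_pipeline` is hypothetical, and "fully masked" tiles are given as a plain set of indices.

```python
def run_pipeline(num_tiles, skip=frozenset()):
    """Return the event log of a two-slot prefetch pipeline.

    `skip` holds tile indices whose DMA is predicated off (fully masked tiles).
    Each event is (kind, tile_index, buffer_slot).
    """
    events = []
    # Warmup: fetch tile 0 before any compute happens.
    if 0 not in skip:
        events.append(("issue_dma", 0, 0 % 2))
    for i in range(num_tiles):
        nxt = i + 1
        # Prefetch the next tile into the other slot while tile i is in use.
        if nxt < num_tiles and nxt not in skip:
            events.append(("issue_dma", nxt, nxt % 2))
        if i in skip:
            continue  # no DMA was issued, so there is nothing to wait for
        events.append(("wait_dma", i, i % 2))
        events.append(("compute", i, i % 2))
    return events


log = run_pipeline(4, skip={2})
for event in log:
    print(event)
```

In the printed log, the DMA for tile 1 is issued before tile 0 is computed, and tile 2 produces no events at all, which is the effect the predicate has on a skipped block.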
Matmul Operations (Pass 13: Post-MXU-Assigner)
After MXU assignment, all matmuls are initially on mxu0:
// Q @ K^T for attention scores (line 749 in splash_attention_kernel.py)
%1237 = vmatprep.subr.mxu0 0
%1238 = vmatpush1.bf16.xpose.msra.mxu0 %v1164
%1269 = vmatprep.mubr.bf16.mxu0 0
%1270 = vmatmul.mubr.bf16.gmra.mxu0 %v1047
// S @ V for output accumulation (line 801)
%1279 = vmatprep.mubr.bf16.mxu0 0
%1280 = vmatmul.mubr.bf16.gmra.mxu0 %v1050
The operations decode as:
vmatprep.subr: Prepare RHS with subtraction (accumulator init)
vmatpush1.bf16.xpose: Push bf16 tile to MXU, transposed
vmatmul.mubr.bf16.gmra: Multiply-accumulate with bf16 inputs
Online Softmax in LLO
The online softmax pattern appears in three stages:
1. Max Reduction Tree:
// Tree reduction: max across 4 tiles
%v2530 = vmax.f32 %v9322, %v9326 /* max(tile0, tile1) */
%v2531 = vmax.f32 %v2530, %v9324 /* max(result, tile2) */
%v2532 = vmax.f32 %v2531, %v9343 /* max(result, tile3) */
// Cross-lane reduction using XLU
%2533 = vmax.xlane.f32.xlu0 %v2532 /* reduce across 128 lanes */
%v2534 = vpop.xlane.xlu0 %2533 /* pop result */
2. Exponential via Base Conversion:
// exp(x) = 2^(x × log₂(e)) where log₂(e) = 1.442695
%v3170 = vmul.f32 1.442695, %v2914 /* x × log₂(e) */
%7348 = vpow2.f32 %v3170 /* 2^result = exp(x) */
%v10295 = vpop.eup %7348 /* pop from EUP */
3. Running Max Update (the key online softmax insight):
// m_next = max(m_prev, m_curr)
%v10156 = vmax.f32 %v10130, %v2539 /* running max update */
// scores - max (for numerical stability)
%v2918 = vsub.f32 %v9345, %v10156
%v2919 = vsub.f32 %v9347, %v10156
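Taken together, the three stages form the standard online softmax recurrence. The NumPy sketch below illustrates that recurrence under stated assumptions: a single head, no mask, small sizes, and hypothetical function names (`exp_via_pow2`, `reference_attention`, `online_attention`). It is not the kernel's code, but it uses the same 2^(x × log₂(e)) form for the exponential and checks the blocked result against a version that materializes the full score matrix.

```python
import numpy as np

LOG2_E = 1.4426950408889634  # log2(e), the constant multiplied in before vpow2


def exp_via_pow2(x):
    """exp(x) computed as 2**(x * log2(e)), mirroring the vmul + vpow2 pair."""
    return np.exp2(x * LOG2_E)


def reference_attention(q, k, v):
    """Materialize the full [seq, seq] score matrix (the baseline)."""
    logits = q @ k.T
    m = logits.max(axis=-1, keepdims=True)
    s = np.exp(logits - m)
    return (s / s.sum(axis=-1, keepdims=True)) @ v


def online_attention(q, k, v, block_kv):
    """Process K/V in blocks, carrying running max m, sum l, and output o."""
    seq_q, head_dim = q.shape
    m = np.full((seq_q, 1), -np.inf)
    l = np.zeros((seq_q, 1))
    o = np.zeros((seq_q, head_dim))
    for start in range(0, k.shape[0], block_kv):
        k_blk = k[start:start + block_kv]
        v_blk = v[start:start + block_kv]
        scores = q @ k_blk.T
        m_curr = scores.max(axis=-1, keepdims=True)
        m_next = np.maximum(m, m_curr)      # running max update
        alpha = exp_via_pow2(m - m_next)    # rescale factor for the old state
        p = exp_via_pow2(scores - m_next)   # scores - max, then exp
        l = alpha * l + p.sum(axis=-1, keepdims=True)
        o = alpha * o + p @ v_blk
        m = m_next
    return o / l


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 16)) for _ in range(3))
assert np.allclose(online_attention(q, k, v, block_kv=16), reference_attention(q, k, v))
```

The `alpha` factor is the piece that makes the blocked version exact: whenever the running max grows, the previously accumulated sum and output are scaled down to the new reference point before the current block is added.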
Final VLIW Bundles (Pass 79)
The final bundles show aggressive instruction-level parallelism. Here’s the entry bundle:
0x0 : {
%s8468_s30 = smov [#allocation3] ;; // Load SMEM constant addresses
%s8469_s12 = smov [#allocation4] ;;
%s14794_s0 = inlined_call_operand.hbm [shape: s8[1,4,4], index: 0] ;; // block_mask
%s14795_s2 = inlined_call_operand.hbm [shape: bf16[8,2048,128], index: 2] ;; // Q
%s14796_s3 = inlined_call_operand.hbm [shape: bf16[8,2048,128], index: 3] ;; // K
%s14797_s4 = inlined_call_operand.hbm [shape: bf16[8,2048,128], index: 4] ;; // V
...
} /* entry bundle */
The Online Softmax Interleaving
The most impressive bundles show matmul, softmax, and mask operations executing in parallel:
0x20b : > {
%v2535_v40 = vmax.f32 %v9345, %v9347 ;; // Tree reduction: max
%v1281_v41 = vpop.f32.mrf.mxu0 ;; // Pop Q@K^T result from MXU0
%7068 = vmatmul.mubr.bf16.gmra.mxu0 %v6597 ;; // Start next matmul on MXU0
%v2532_v42 = vmax.f32 %v2531, %v9343 ;; // Continue max reduction
%v9365_v48 = vsel %vm2025, %v1630, -2.38e+38 ;; // Apply causal mask
%v1960_v34 = vld [vmem:[%s8897_s21 + $0x30]] ;; // Load next Q sequence index
%vm2041 = vcmp.ge.s32 %v1959, %v9314 ;; // Compute next mask predicate
}
In a single VLIW bundle:
Pop matmul result from MXU
Start next matmul
Tree reduction for running max
Apply causal mask via vsel
Load next tile
Compute next mask predicate
This is the power of Pallas—the online softmax operations interleave with matmul, eliminating the HBM round-trips that plague reference attention.
Dual-MXU Distribution (Final Bundles)
The VLIW scheduler distributes work across both MXUs:
0x20c : > {
%v1634_v45 = vpop.f32.mrf.mxu1 ;; // Pop from MXU1
%7148 = vmatmul.mubr.bf16.gmra.mxu1 %v6597 ;; // Matmul on MXU1
%7069 = vmatprep.mubr.bf16.mxu0 %v9334 ;; // Prepare MXU0
%v9367_v49 = vsel %vm2026, %v1281, -2.38e+38 ;; // Mask application
%vm2042 = vcmp.ge.s32 %v1960, %v9305 ;;
}
What Makes This Possible
The key insight from the LLO is that everything happens in one loop:
DMA prefetch loads next Q/K/V tiles while current tiles compute
Matmul computes attention scores (Q @ K^T)
Online softmax updates running max/sum/output in the same iteration
S @ V accumulates output using rescaled softmax weights
Scratch buffers in VMEM carry state across KV blocks
Reference attention requires three separate fusions with HBM materialization between them. Pallas keeps everything in VMEM—the 128MB attention matrix never exists.
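The 128MB figure follows directly from the shapes. Here is a short worked calculation, assuming f32 scores, the 8 × 2048 × 128 problem size used throughout, and, for the tiled case, that one full 512 × 512 score tile is live at a time (the kernel may subdivide further):

```python
heads, seq_len, head_dim = 8, 2048, 128
block_q = block_kv = 512
F32, BF16 = 4, 2  # bytes per element

# Reference path: full f32 logits tensor of shape [heads, seq, seq].
full_scores_bytes = heads * seq_len * seq_len * F32
print(full_scores_bytes / 2**20, "MiB")  # 128.0 MiB

# Tiled path: one f32 score tile of shape [block_q, block_kv] at a time.
tile_scores_bytes = block_q * block_kv * F32
print(tile_scores_bytes / 2**20, "MiB")  # 1.0 MiB

# Q, K, and V themselves are small by comparison.
qkv_bytes = 3 * heads * seq_len * head_dim * BF16
print(qkv_bytes / 2**20, "MiB")  # 12.0 MiB
```

The intermediate score tensor is roughly ten times larger than all three inputs combined, which is why avoiding its round-trip through HBM matters more than any saving on the Q/K/V loads.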
The Compiler’s Role
The TPU Pallas compiler handled:
MXU assignment (pass 13): Initially all on mxu0
VLIW bundle scheduling (pass 29-33): Distributes across both MXUs
DMA scheduling (pass 35-37): Overlaps memory access with compute
Register allocation (pass 45-55): Manages VMEM pressure
Final bundle packing (pass 73-79): Maximizes ILP
The Pallas programmer writes the algorithm; the compiler handles the hardware mapping. But the algorithm itself—online softmax with tiled accumulation—is something no compiler can currently invent. That’s why Pallas exists.
Reference Attention in JAX
The reference implementation lives in JAX’s Splash Attention library. It’s the naive O(n²) memory baseline that materializes the full attention matrix:
def _attention_reference_default(mask, q, k, v, segment_ids, mask_value, ...):
    # Q @ K^T -> full [seq, seq] attention matrix
    logits = jnp.einsum("sd,td->st", q.astype(jnp.float32), k.astype(jnp.float32))
    # Apply causal mask - masked positions get a large negative value
    logits = jnp.where(mask, logits, mask_value)  # mask_value = -2.38e+38
    # Numerically stable softmax
    m = logits.max(axis=-1)             # reduce_max over key dimension
    s = jnp.exp(logits - m[..., None])  # subtract max for stability
    l = s.sum(axis=-1)                  # reduce_sum for normalization
    s = s / l[..., None]                # normalize to probabilities
    # Output projection: S @ V
    o = jnp.einsum("st,td->sd", s, v.astype(jnp.float32))
    return o
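To try these steps without a TPU, here is a single-head NumPy port with a causal mask. It is a sketch rather than the library function: `attention_reference_numpy` is a hypothetical name, `segment_ids` and the elided parameters are left out, and the mask is built with `np.tril` as an assumption about what "causal" means here (query s attends to keys t ≤ s).

```python
import numpy as np

MASK_VALUE = -2.38e38  # large finite negative, as in the comment above


def attention_reference_numpy(mask, q, k, v, mask_value=MASK_VALUE):
    """Single-head NumPy port of the reference steps (segment_ids omitted)."""
    logits = np.einsum("sd,td->st", q.astype(np.float32), k.astype(np.float32))
    logits = np.where(mask, logits, np.float32(mask_value))
    m = logits.max(axis=-1)
    s = np.exp(logits - m[..., None])
    l = s.sum(axis=-1)
    s = s / l[..., None]
    return np.einsum("st,td->sd", s, v.astype(np.float32))


seq, dim = 8, 4
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((seq, dim)).astype(np.float32) for _ in range(3))
causal = np.tril(np.ones((seq, seq), dtype=bool))  # query s attends to keys t <= s

out = attention_reference_numpy(causal, q, k, v)
# Row 0 can only see key 0, so its output is exactly v[0].
assert np.allclose(out[0], v[0])
```

Every intermediate here (`logits`, `s`) has shape [seq, seq], which is the quadratic memory cost the tiled kernel avoids.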
Reference Attention HLO
Initial HLO (Before Optimization)
XLA traces the reference implementation into standard HLO ops. Here’s the complete structure:
HloModule jit_reference_attention, entry_computation_layout={(bf16[8,2048,128]{2,1,0:T(8,128)(2,1)}, bf16[8,2048,128]{2,1,0:T(8,128)(2,1)}, bf16[8,2048,128]{2,1,0:T(8,128)(2,1)})->f32[8,2048,128]{2,1,0:T(8,128)}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}
%region_0.25 (Arg_0.22: f32[], Arg_1.23: f32[]) -> f32[] {
%Arg_0.22 = f32[] parameter(0), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_max"}
%Arg_1.23 = f32[] parameter(1), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_max"}
ROOT %maximum.24 = f32[] maximum(%Arg_0.22, %Arg_1.23), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_max" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=196}
}
%region_1.36 (Arg_0.33: f32[], Arg_1.34: f32[]) -> f32[] {
%Arg_0.33 = f32[] parameter(0), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_sum"}
%Arg_1.34 = f32[] parameter(1), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_sum"}
ROOT %add.35 = f32[] add(%Arg_0.33, %Arg_1.34), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_sum" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=198}
}
ENTRY %main.47 (Arg_0.1: bf16[8,2048,128], Arg_1.2: bf16[8,2048,128], Arg_2.3: bf16[8,2048,128]) -> f32[8,2048,128] {
%constant.4 = pred[8,2048,2048]{2,1,0} constant({...})
%Arg_0.1 = bf16[8,2048,128]{2,1,0} parameter(0), metadata={op_name="q"}
%convert.0 = f32[8,2048,128]{2,1,0} convert(%Arg_0.1), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/convert_element_type" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=184}
%Arg_1.2 = bf16[8,2048,128]{2,1,0} parameter(1), metadata={op_name="k"}
%convert.1 = f32[8,2048,128]{2,1,0} convert(%Arg_1.2), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/convert_element_type" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=184}
%dot.0 = f32[8,2048,2048]{2,1,0} dot(%convert.0, %convert.1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/vmap(sd,td->st)/dot_general" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=184}
%constant.0 = f32[] constant(-2.38197633e+38)
%broadcast.1 = f32[8,2048,2048]{2,1,0} broadcast(%constant.0), dimensions={}, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/vmap(jit(_where))/broadcast_in_dim" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=195}
%select.1 = f32[8,2048,2048]{2,1,0} select(%constant.4, %dot.0, %broadcast.1), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/vmap(jit(_where))/select_n" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=195}
%constant.1 = f32[] constant(-inf)
%reduce.0 = f32[8,2048]{1,0} reduce(%select.1, %constant.1), dimensions={2}, to_apply=%region_0.25, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_max" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=196}
%reshape.0 = f32[8,2048,1]{2,1,0} reshape(%reduce.0), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/broadcast_in_dim" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=197}
%broadcast.2 = f32[8,2048,1]{2,1,0} broadcast(%reshape.0), dimensions={0,1,2}, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/sub" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=197}
%reshape.1 = f32[8,2048]{1,0} reshape(%broadcast.2), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/sub" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=197}
%broadcast.3 = f32[8,2048,2048]{2,1,0} broadcast(%reshape.1), dimensions={0,1}, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/sub" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=197}
%subtract.0 = f32[8,2048,2048]{2,1,0} subtract(%select.1, %broadcast.3), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/sub" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=197}
%exponential.0 = f32[8,2048,2048]{2,1,0} exponential(%subtract.0), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/exp" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=197}
%constant.2 = f32[] constant(0)
%reduce.1 = f32[8,2048]{1,0} reduce(%exponential.0, %constant.2), dimensions={2}, to_apply=%region_1.36, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_sum" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=198}
%reshape.2 = f32[8,2048,1]{2,1,0} reshape(%reduce.1), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/broadcast_in_dim" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=199}
%broadcast.4 = f32[8,2048,1]{2,1,0} broadcast(%reshape.2), dimensions={0,1,2}, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/div" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=199}
%reshape.3 = f32[8,2048]{1,0} reshape(%broadcast.4), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/div" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=199}
%broadcast.5 = f32[8,2048,2048]{2,1,0} broadcast(%reshape.3), dimensions={0,1}, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/div" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=199}
%divide.0 = f32[8,2048,2048]{2,1,0} divide(%exponential.0, %broadcast.5), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/div" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=199}
%Arg_2.3 = bf16[8,2048,128]{2,1,0} parameter(2), metadata={op_name="v"}
%convert.2 = f32[8,2048,128]{2,1,0} convert(%Arg_2.3), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/convert_element_type" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=201}
ROOT %dot.1 = f32[8,2048,128]{2,1,0} dot(%divide.0, %convert.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/vmap(st,td->sd)/dot_general" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=201}
}
Notice: XLA sees every operation—two dots, a select, two reductions, elementwise ops. This is fundamentally different from Pallas, where XLA sees one opaque custom-call.
Optimized HLO (After Fusion)
After optimization passes, XLA fuses operations into three main kernels plus helper fusions. Here’s the complete optimized HLO:
HloModule jit_reference_attention, is_scheduled=true, entry_computation_layout={(bf16[8,2048,128]{2,1,0:T(8,128)(2,1)}, bf16[8,2048,128]{2,1,0:T(8,128)(2,1)}, bf16[8,2048,128]{2,1,0:T(8,128)(2,1)})->f32[8,2048,128]{2,1,0:T(8,128)}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}
%copy_fusion.2 (input.2: pred[8,2048,2048]) -> pred[8,2048,2048] {
%input.2 = pred[8,2048,2048]{1,2,0:T(32,128)(4,1)S(1)} parameter(0)
ROOT %copy.4 = pred[8,2048,2048]{1,2,0:T(8,128)(4,1)} copy(%input.2)
}
%fused_computation.1 (param_0.19: f32[8,2048], param_1.20: f32[8,2048], param_2.14: pred[8,2048,2048], param_3.4: f32[8,2048,2048]) -> f32[8,2048,2048] {
%param_2.14 = pred[8,2048,2048]{1,2,0:T(32,128)(4,1)S(1)} parameter(2)
%fusion.12 = pred[8,2048,2048]{1,2,0:T(8,128)(4,1)} fusion(%param_2.14), kind=kLoop, output_to_operand_aliasing={{}: (0, {})}, calls=%copy_fusion.2
%param_3.4 = f32[8,2048,2048]{1,2,0:T(8,128)} parameter(3)
%constant.21 = f32[]{:T(128)} constant(-2.38197633e+38)
%broadcast.21 = f32[8,2048,2048]{1,2,0:T(8,128)} broadcast(%constant.21), dimensions={}, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/vmap(jit(_where))/broadcast_in_dim" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=195}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"output_chunk_bound_config":{"output_chunk_bound":[]}}
%select.5 = f32[8,2048,2048]{1,2,0:T(8,128)} select(%fusion.12, %param_3.4, %broadcast.21), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/vmap(jit(_where))/select_n" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=195}
%param_1.20 = f32[8,2048]{1,0:T(8,128)S(1)} parameter(1)
%broadcast.15 = f32[8,2048,2048]{1,2,0:T(8,128)} broadcast(%param_1.20), dimensions={0,1}, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/sub" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=197}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"output_chunk_bound_config":{"output_chunk_bound":["1","128"]}}
%subtract.4 = f32[8,2048,2048]{1,2,0:T(8,128)} subtract(%select.5, %broadcast.15), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/sub" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=197}
%exponential.4 = f32[8,2048,2048]{1,2,0:T(8,128)} exponential(%subtract.4), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/exp" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=197}
%param_0.19 = f32[8,2048]{1,0:T(8,128)S(1)} parameter(0)
%broadcast.11 = f32[8,2048,2048]{1,2,0:T(8,128)} broadcast(%param_0.19), dimensions={0,1}, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/div" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=199}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"output_chunk_bound_config":{"output_chunk_bound":["1","128"]}}
ROOT %divide.2 = f32[8,2048,2048]{1,2,0:T(8,128)} divide(%exponential.4, %broadcast.11), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/div" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=199}
}
%bitcast_fusion.1 (bitcast_input.1: bf16[8,2048,128]) -> bf16[8,2048,128] {
%bitcast_input.1 = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(0)
ROOT %bitcast.1 = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)} bitcast(%bitcast_input.1)
}
%fused_computation (param_0.1: bf16[8,2048,128], param_1.18: f32[8,2048], param_2.12: f32[8,2048], param_3.3: pred[8,2048,2048], param_4: f32[8,2048,2048]) -> f32[8,2048,128] {
%param_1.18 = f32[8,2048]{1,0:T(8,128)S(1)} parameter(1)
%param_2.12 = f32[8,2048]{1,0:T(8,128)S(1)} parameter(2)
%param_3.3 = pred[8,2048,2048]{1,2,0:T(32,128)(4,1)S(1)} parameter(3)
%param_4 = f32[8,2048,2048]{1,2,0:T(8,128)} parameter(4)
%fusion.1 = f32[8,2048,2048]{1,2,0:T(8,128)} fusion(%param_1.18, %param_2.12, %param_3.3, %param_4), kind=kLoop, calls=%fused_computation.1, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/div" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=199}
%param_0.1 = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(0)
%fusion.8 = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)} fusion(%param_0.1), kind=kLoop, calls=%bitcast_fusion.1
ROOT %convolution-base-dilated.2 = f32[8,2048,128]{2,1,0:T(8,128)} convolution(%fusion.1, %fusion.8), window={size=8 stride=7 lhs_dilate=8}, dim_labels=0bf_0io->0bf, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/vmap(st,td->sd)/dot_general" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=201}
}
%region_1.36 (Arg_0.33: f32[], Arg_1.34: f32[]) -> f32[] {
%Arg_1.34 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_sum"}
%Arg_0.33 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_sum"}
ROOT %add.35 = f32[]{:T(128)} add(%Arg_0.33, %Arg_1.34), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_sum" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=198}
}
%copy_fusion.1 (input.1: pred[8,2048,2048]) -> pred[8,2048,2048] {
%input.1 = pred[8,2048,2048]{1,2,0:T(32,128)(4,1)S(1)} parameter(0)
ROOT %copy.3 = pred[8,2048,2048]{1,2,0:T(8,128)(4,1)} copy(%input.1)
}
%fused_computation.3 (param_0.24: f32[8,2048], param_1.26: pred[8,2048,2048], param_2.19: f32[8,2048,2048]) -> f32[8,2048] {
%param_1.26 = pred[8,2048,2048]{1,2,0:T(32,128)(4,1)S(1)} parameter(1)
%fusion.11 = pred[8,2048,2048]{1,2,0:T(8,128)(4,1)} fusion(%param_1.26), kind=kLoop, output_to_operand_aliasing={{}: (0, {})}, calls=%copy_fusion.1
%param_2.19 = f32[8,2048,2048]{1,2,0:T(8,128)} parameter(2)
%constant.16 = f32[]{:T(128)} constant(-2.38197633e+38)
%broadcast.23 = f32[8,2048,2048]{1,2,0:T(8,128)} broadcast(%constant.16), dimensions={}, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/vmap(jit(_where))/broadcast_in_dim" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=195}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"output_chunk_bound_config":{"output_chunk_bound":[]}}
%select.7 = f32[8,2048,2048]{1,2,0:T(8,128)} select(%fusion.11, %param_2.19, %broadcast.23), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/vmap(jit(_where))/select_n" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=195}
%param_0.24 = f32[8,2048]{1,0:T(8,128)S(1)} parameter(0)
%broadcast.16 = f32[8,2048,2048]{1,2,0:T(8,128)} broadcast(%param_0.24), dimensions={0,1}, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/sub" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=197}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"output_chunk_bound_config":{"output_chunk_bound":["1","128"]}}
%subtract.6 = f32[8,2048,2048]{1,2,0:T(8,128)} subtract(%select.7, %broadcast.16), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/sub" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=197}
%exponential.6 = f32[8,2048,2048]{1,2,0:T(8,128)} exponential(%subtract.6), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/exp" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=197}
%constant.14 = f32[]{:T(128)} constant(0)
ROOT %reduce.2 = f32[8,2048]{1,0:T(8,128)S(1)} reduce(%exponential.6, %constant.14), dimensions={2}, to_apply=%region_1.36, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_sum" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=198}
}
%region_0.25 (Arg_0.22: f32[], Arg_1.23: f32[]) -> f32[] {
%Arg_1.23 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_max"}
%Arg_0.22 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_max"}
ROOT %maximum.24 = f32[]{:T(128)} maximum(%Arg_0.22, %Arg_1.23), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_max" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=196}
}
%bitcast_fusion (bitcast_input: bf16[8,2048,128]) -> bf16[8,2048,128] {
%bitcast_input = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(0)
ROOT %bitcast = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)} bitcast(%bitcast_input)
}
%bitcast_fusion.2 (bitcast_input.2: bf16[8,2048,128]) -> bf16[8,2048,128] {
%bitcast_input.2 = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)} parameter(0)
ROOT %bitcast.2 = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)} bitcast(%bitcast_input.2)
}
%copy_fusion (input: pred[8,2048,2048]) -> pred[8,2048,2048] {
%input = pred[8,2048,2048]{1,2,0:T(32,128)(4,1)} parameter(0)
ROOT %copy.2 = pred[8,2048,2048]{1,2,0:T(8,128)(4,1)} copy(%input)
}
%fused_computation.7 (param_0.25: pred[8,2048,2048], param_1.28: bf16[8,2048,128], param_2.21: bf16[8,2048,128]) -> (f32[8,2048], f32[8,2048,2048]) {
%param_0.25 = pred[8,2048,2048]{1,2,0:T(32,128)(4,1)} parameter(0)
%fusion.10 = pred[8,2048,2048]{1,2,0:T(8,128)(4,1)} fusion(%param_0.25), kind=kLoop, output_to_operand_aliasing={{}: (0, {})}, calls=%copy_fusion
%param_1.28 = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(1)
%fusion.7 = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)} fusion(%param_1.28), kind=kLoop, calls=%bitcast_fusion
%param_2.21 = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)} parameter(2)
%fusion.9 = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)} fusion(%param_2.21), kind=kLoop, calls=%bitcast_fusion.2
%convolution-base-dilated.3 = f32[8,2048,2048]{1,2,0:T(8,128)} convolution(%fusion.7, %fusion.9), window={size=8 stride=7 lhs_dilate=8}, dim_labels=0bf_0oi->0bf, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/vmap(sd,td->st)/dot_general" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=184}
%constant.22 = f32[]{:T(128)} constant(-2.38197633e+38)
%broadcast.25 = f32[8,2048,2048]{1,2,0:T(8,128)} broadcast(%constant.22), dimensions={}, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/vmap(jit(_where))/broadcast_in_dim" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=195}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"output_chunk_bound_config":{"output_chunk_bound":[]}}
%select.9 = f32[8,2048,2048]{1,2,0:T(8,128)} select(%fusion.10, %convolution-base-dilated.3, %broadcast.25), metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/vmap(jit(_where))/select_n" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=195}
%constant.15 = f32[]{:T(128)} constant(-inf)
%reduce.3 = f32[8,2048]{1,0:T(8,128)S(1)} reduce(%select.9, %constant.15), dimensions={2}, to_apply=%region_0.25, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_max" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=196}
ROOT %tuple = (f32[8,2048]{1,0:T(8,128)S(1)}, f32[8,2048,2048]{1,2,0:T(8,128)}) tuple(%reduce.3, %convolution-base-dilated.3)
}
ENTRY %main.47 (Arg_0.1: bf16[8,2048,128], Arg_1.2: bf16[8,2048,128], Arg_2.3: bf16[8,2048,128]) -> f32[8,2048,128] {
%Arg_0.1 = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)} parameter(0), metadata={op_name="q"}
%copy-start = (bf16[8,2048,128]{2,1,0:T(8,128)(2,1)S(1)}, bf16[8,2048,128]{2,1,0:T(8,128)(2,1)}, u32[]{:S(2)}) copy-start(%Arg_0.1), cross_program_prefetch_index=0
%constant.4 = pred[8,2048,2048]{1,2,0:T(32,128)(4,1)} constant({...})
%Arg_2.3 = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)} parameter(2), metadata={op_name="v"}
%Arg_1.2 = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)} parameter(1), metadata={op_name="k"}
%copy-start.1 = (pred[8,2048,2048]{1,2,0:T(32,128)(4,1)S(1)}, pred[8,2048,2048]{1,2,0:T(32,128)(4,1)}, u32[]{:S(2)}) copy-start(%constant.4)
%copy-done = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)S(1)} copy-done(%copy-start)
%fusion.5 = (f32[8,2048]{1,0:T(8,128)S(1)}, f32[8,2048,2048]{1,2,0:T(8,128)}) fusion(%constant.4, %copy-done, %Arg_1.2), kind=kOutput, calls=%fused_computation.7, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/vmap(sd,td->st)/dot_general" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=184}
%get-tuple-element.1 = f32[8,2048,2048]{1,2,0:T(8,128)} get-tuple-element(%fusion.5), index=1, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_max" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=196}
%get-tuple-element = f32[8,2048]{1,0:T(8,128)S(1)} get-tuple-element(%fusion.5), index=0, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_max" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=196}
%copy-start.2 = (bf16[8,2048,128]{2,1,0:T(8,128)(2,1)S(1)}, bf16[8,2048,128]{2,1,0:T(8,128)(2,1)}, u32[]{:S(2)}) copy-start(%Arg_2.3)
%copy-done.1 = pred[8,2048,2048]{1,2,0:T(32,128)(4,1)S(1)} copy-done(%copy-start.1)
%fusion.2 = f32[8,2048]{1,0:T(8,128)S(1)} fusion(%get-tuple-element, %copy-done.1, %get-tuple-element.1), kind=kLoop, calls=%fused_computation.3, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/reduce_sum" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=198}
%copy-done.2 = bf16[8,2048,128]{2,1,0:T(8,128)(2,1)S(1)} copy-done(%copy-start.2)
ROOT %fusion = f32[8,2048,128]{2,1,0:T(8,128)} fusion(%copy-done.2, %fusion.2, %get-tuple-element, %copy-done.1, %get-tuple-element.1), kind=kOutput, calls=%fused_computation, metadata={op_name="jit(reference_attention)/jit(main)/reference_attention/reference_attention_h8_s2048/jit(_wrapped)/vmap(st,td->sd)/dot_general" source_file="/home/ptoulme/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py" source_line=201}
}
Key Optimization Patterns
1. Multi-Output Fusion (fusion.5)
The first matmul fuses with mask application and reduce_max. It returns a tuple with both outputs:
f32[8,2048]: the row-wise maximum values
f32[8,2048,2048]: the full attention scores
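To make the two-output shape concrete, here is a minimal NumPy sketch of what fusion.5 computes. The function name and the use of NumPy are my own assumptions for illustration; this is not the code XLA generates. Note that, as in the HLO ROOT tuple, the second output is the unmasked score matrix, while the maximum is taken over the masked scores.

```python
import numpy as np

MASK_VALUE = -2.38197633e38  # same value as %constant.22 in the HLO above


def scores_and_row_max(q, k, mask):
    """Hypothetical stand-in for fusion.5: returns (row_max, scores)."""
    # Q @ K^T per head, accumulated in f32 like the HLO convolution.
    scores = np.einsum("hsd,htd->hst", q.astype(np.float32), k.astype(np.float32))
    # The mask is applied only on the way into the max reduction.
    row_max = np.where(mask, scores, MASK_VALUE).max(axis=-1)
    return row_max, scores
```

Calling it with q and k of shape [8, 2048, 128] and a boolean mask of shape [8, 2048, 2048] yields outputs with the same shapes as the fusion.5 tuple.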
2. Recomputation vs Storage Trade-off
Notice that exp(scores - max) is recomputed rather than stored:
In fusion.2, to compute the sum
In fused_computation.1 (inside fusion), to compute the normalized weights
The mask is also re-applied multiple times
This is intentional: recomputing is cheaper than storing and reloading the full [8,2048,2048] tensor.
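A rough NumPy sketch of that recomputation pattern follows. The helper names are hypothetical and the split into two functions only mirrors the fusion boundaries described above; the point is that both consumers start from the raw scores and redo the mask and the exp.

```python
import numpy as np

MASK_VALUE = -2.38197633e38  # same value as %constant.22 in the HLO above


def masked_exp(scores, mask, row_max):
    # Re-apply the mask and recompute exp(scores - max) from the raw scores.
    return np.exp(np.where(mask, scores, MASK_VALUE) - row_max[..., None])


def row_sum(scores, mask, row_max):
    # Plays the role of fusion.2: only the [heads, seq] sum is produced.
    return masked_exp(scores, mask, row_max).sum(axis=-1)


def weighted_values(scores, mask, row_max, denom, v):
    # Plays the role of the final fusion: exp is recomputed, then normalized.
    weights = masked_exp(scores, mask, row_max) / denom[..., None]
    return np.einsum("hst,htd->hsd", weights, v)
```

Neither function ever returns the [heads, seq, seq] exp tensor, which is the trade the compiler is making.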
The O(n²) Problem is Visible:
%get-tuple-element.1 = f32[8,2048,2048]{1,2,0:T(8,128)} get-tuple-element(%fusion.5), index=1
This 128MB tensor (8 × 2048 × 2048 × 4 bytes) lives in HBM between fusion.5 and downstream fusions. Pallas avoids this entirely with online softmax.
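The size arithmetic is easy to check, and it shows how quickly the quadratic term grows. The helper below is a hypothetical convenience function; the longer sequence lengths are illustrative values, not configurations measured in this post.

```python
def scores_mib(heads, seq_len, bytes_per_element=4):
    """Size in MiB of a dense [heads, seq_len, seq_len] score tensor."""
    return heads * seq_len * seq_len * bytes_per_element / 2**20


# Doubling the sequence length quadruples the intermediate.
for seq_len in (2048, 4096, 8192):
    print(seq_len, scores_mib(8, seq_len))
```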
Reference Attention LLO
After HLO optimization, each fusion compiles through the LLO (Low-Level Operations) pipeline. The fusion.5 kernel alone goes through 79 optimization passes before producing final VLIW bundles.
Original LLO: fusion.5 (Pass 02)
The initial LLO reveals the kernel’s structure before optimization:
$region0: #{fusion.5}
// VMEM allocations for intermediate results
#allocation0 [shape = 'f32[262144]{0}', space=vmem, size = 0x100000] // 1MB scratch
#allocation7 [shape = 'f32[131072]{0}', space=vmem, size = 0x80000] // 512KB scratch
#allocation8 [shape = 's32[1]{0}', space=sflag, size = 0x4] // sync flag
// Kernel inputs/outputs (from HLO fusion interface)
%s0 = inlined_call_operand.hbm [shape: pred[8,2048,2048], index: 0] // mask from HBM
%s1 = inlined_call_operand.vmem [shape: bf16[8,2048,128], index: 1] // Q in VMEM
%s2 = inlined_call_operand.hbm [shape: bf16[8,2048,128], index: 2] // K from HBM
%s3 = inlined_call_operand.vmem [shape: f32[8,2048], index: 3] // max output (VMEM)
%s4 = inlined_call_operand.hbm [shape: f32[8,2048,2048], index: 4] // scores output (HBM!)
// Initialize max output to -inf (16 vector stores for 8×2048 shape)
%7 = vst [vmem:[%s3] sm:$0xff] /*vst_source=*/-inf
%s8 = scalar_lea.vmem %s3, 8
%9 = vst [vmem:[%s8] sm:$0xff] /*vst_source=*/-inf
// ... 14 more stores to initialize max buffer
Key observations:
Q lives in VMEM but K streams from HBM via DMA
The full attention scores write to HBM — 128MB output!
The causal mask also streams from HBM — 32MB of predicates
Loop Structure (Pass 02)
$region2: #{fusion.5} parent=0
// Allocations for double-buffered DMA
#allocation1 [shape = 'u8[524288]{0}', space=vmem, tag = 'operand span for K'] // 512KB
#allocation2 [shape = 's32[2]{0}', space=sflag] // sync flags
#allocation4 [shape = 'u8[4194304]{0}', space=vmem, tag = 'operand span for mask'] // 4MB
#allocation6 [shape = 'u8[16777216]{0}', space=vmem, tag = 'operand span for output'] // 16MB
// Initialize sync flags for double buffering
%38 = vsyncpa [#allocation2], 0
%40 = vsyncpa [#allocation2 + $0x1], 0
loop: start=0, step=1, limit=18
$region4: #{fusion.5} parent=2 // loop_header
// Phi nodes track iteration state across loop iterations
%s48 = sphi 0, %s52 /* iteration index, stage = 0 */
%p49 = scmp.ge.s32.totalorder %s48, 18 /* loop exit test */
// Multi-dimensional iteration bounds (8 heads × 2 K-blocks)
%s55 = sphi 0, %s88 /* iter bound = 0 (head index) */
%s56 = sphi 0, %s84 /* iter bound = 1 (K block index) */
// ... more phi nodes for software pipelining stages
Loop dimensions: The kernel tiles K into 2 blocks of 1024 rows each:
8 heads × 2 K-blocks = 16 main iterations
Plus 2 for software pipelining = 18 total iterations
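The trip count can be reproduced with a couple of lines. This is only a sketch of the arithmetic: the function names are made up, and the mapping from the flat iteration index to (head, K block) assumes the K block is the inner dimension, which the LLO excerpt suggests but does not prove.

```python
def trip_count(heads, seq_len, k_block, extra_pipeline_iters=2):
    """Loop limit: one iteration per (head, K block) plus pipeline fill/drain."""
    return heads * (seq_len // k_block) + extra_pipeline_iters


def unflatten(step, k_blocks):
    """Assumed mapping of a main-loop step to (head index, K block index)."""
    return divmod(step, k_blocks)


print(trip_count(heads=8, seq_len=2048, k_block=1024))
print([unflatten(step, 2) for step in range(4)])
```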
Matmul Operations (Pass 02 - Original)
Before optimization, each matmul tile uses explicit load-store patterns:
// Load Q slice, unpack bf16, push to MXU
%v322 = vld [vmem:[%s321] sm:$0xf] // Load Q tile (bf16)
%v323 = vunpack.c.l.bf16 %v322 // Unpack low bf16 to f32
%325 = vst [vmem:[%s320] sm:$0xff] /*vst_source=*/%v323 // Store unpacked
%v326 = vld [vmem:[%s320] sm:$0xff] // Reload (wasteful!)
%327 = vmatpush1.xpose.msra.mxu0 %v326 // Push to MXU systolic array
// Prepare RHS with zeros (accumulator init)
%319 = vmatprep.subr.mxu0 0.0 // Prepare zero for subtraction
// Repeat for all 16 tiles of the 128-dim K dimension...
%328 = vmatprep.subr.mxu0 0.0
%v331 = vld [vmem:[%s330] sm:$0xf]
%v332 = vunpack.c.l.bf16 %v331
%334 = vst [vmem:[%s329] sm:$0xff] /*vst_source=*/%v332
%v335 = vld [vmem:[%s329] sm:$0xff]
%336 = vmatpush1.xpose.msra.mxu0 %v335
// ... continues for all K tiles
This is verbose and inefficient — later passes eliminate redundant loads/stores.
VLIW Bundle Packing (Pass 29)
The bundle packer groups independent operations into VLIW bundles:
0x1 : { %7 = vst [vmem:[%s3] sm:$0xff] /*vst_source=*/%v57179 ;;
%51239 = vst [vmem:[%s3 + $0x8] sm:$0xff] /*vst_source=*/%v57179 ;;
%51240 = vst [vmem:[%s3 + $0x10] sm:$0xff] /*vst_source=*/%v57179 ;;
// ... 16 parallel stores in one bundle!
}
// Later: DMA + address computation in parallel
0x1e : > { %57143 = dma.hbm_to_vmem [thread:$0] (!%p57141), /*hbm=*/%s192,
/*size_in_granules=*/8192, /*vmem=*/%s195, /*dst_syncflagno=*/%s180 }
0x1f : > { %p222 = pnand %p51265, %p221 ;;
%s202 = scalar_lea.vmem [#allocation4], %s51260 }
Final VLIW Bundles (Pass 79) - Matmul Section
After all optimizations, both MXUs execute in parallel:
// MXU operations now distributed across both units
0xa : > { %55856 = vmatprep.subr.mxu0 %v465 ;;
%56016 = vmatprep.subr.mxu1 %v7203 }
0xb : > { %55857 = vmatpush3.xpose.msra.mxu0 %v323 ;;
%56017 = vmatpush3.xpose.msra.mxu1 %v7059 ;;
%v332 = vunpack.c.h.bf16 %v51273 ;;
%v7068 = vunpack.c.h.bf16 %v52005 }
0xc : > { %55858 = vmatprep.subr.mxu0 %v474 ;;
%56018 = vmatprep.subr.mxu1 %v7212 ;;
%v483 = vunpack.c.l.bf16 %v51288 ;;
%v7221 = vunpack.c.h.bf16 %v52022 }
Key optimization: The VLIW scheduler distributes work across both MXUs:
mxu0: 1,396 operations (Q @ K^T lower tiles)
mxu1: 1,383 operations (Q @ K^T upper tiles)
With the work split almost evenly, MXU throughput is close to double that of single-MXU execution.
Final VLIW Bundles - Reduce_Max + Mask Application
The most complex bundles overlap matmul results with mask lookup and max reduction:
0xf4 : > { %v618_v36 = vpop.f32.mrf.mxu0 ;; // Pop matmul result
%55937 = vmatmul.mubr.bf16.gmra.mxu0 %v57834_v53 } // Start next matmul
0xf5 : > { %52037 = vst [vmem:[%s57903_s21 + $0x10] sm:$0xff] /*vst_source=*/%v7347_v37 ;;
%56098 = vmatprep.mubr.bf16.mxu1 %v57863_v16 ;; // Prepare mxu1
%v631_v54 = vpop.f32.mrf.mxu0 ;; // Pop another result
%v7358_v6 = vsel %vm, %v7347, -2.38e+38 } // Apply mask!
0xf9 : > { %v7392_v33 = vsel %vm, %v7381, -2.38e+38 ;; // Mask application
%52041 = vst [vmem:[%s57903_s21 + $0x90] sm:$0xff] ;; // Store to HBM buffer
%v665_v37 = vmax.f32 %v627_v12, %v659_v20 ;; // Tree reduction max
%v669_v44 = vpop.f32.mrf.mxu0 ;; // Pop matmul result
%vm58016_vm2 = vcmp.ne.s32.totalorder %v854, 0 } // Prepare next mask
0xfa : > { %v7399_v43 = vmax.f32 %v7358_v6, %v7392_v33 ;; // Continue max tree
%v7403_v53 = vpop.f32.mrf.mxu1 ;; // Pop from mxu1
%v680_v59 = vsel %vm, %v669, -2.38e+38 ;; // Mask more tiles
%51310 = vst ... } // Store result
In a single VLIW bundle (0xf9):
Pop matmul result from MXU
Apply mask via vsel
Store to HBM buffer
Compute tree reduction vmax.f32
Prepare next mask predicate
This is 5 independent operations in parallel — the power of VLIW scheduling.
Comparison: Pallas vs Reference JAX
Now we can directly compare the two approaches using our test configuration: (8 heads, 2048 seq_len, 128 head_dim).
Algorithm: Online vs Naive Softmax
The fundamental difference lies in the softmax computation strategy:
The reference implementation computes the full [8, 2048, 2048] attention matrix (128MB in f32), stores it to HBM, then reads it back multiple times for softmax operations. Pallas processes small tiles (e.g., [512, 512] or [1024, 1024]) that fit in VMEM, computing and consuming each tile before moving to the next.
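The difference is easiest to see side by side. Below is a minimal single-head NumPy sketch, without masking and with tiling only along K/V, so it is a simplification of what the Splash kernel does rather than a port of it. The function names and the block size are assumptions for illustration. The online version keeps a running max, a running sum, and an unnormalized accumulator, and rescales them whenever a new tile raises the max.

```python
import numpy as np


def naive_attention(q, k, v):
    s = q @ k.T  # full [seq, seq] score matrix is materialized
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def online_attention(q, k, v, block=512):
    seq = q.shape[0]
    m = np.full((seq, 1), -np.inf)       # running row max
    l = np.zeros((seq, 1))               # running row sum of exp
    acc = np.zeros((seq, v.shape[1]))    # unnormalized output
    for start in range(0, k.shape[0], block):
        s = q @ k[start:start + block].T           # one [seq, block] tile
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        scale = np.exp(m - m_new)                  # rescale earlier partials
        p = np.exp(s - m_new)
        l = l * scale + p.sum(axis=-1, keepdims=True)
        acc = acc * scale + p @ v[start:start + block]
        m = m_new
    return acc / l
```

Both functions produce the same result up to floating-point rounding, but the online version never holds more than one [seq, block] tile of scores at a time.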
Kernel Structure
Reference: 3 Separate Fusions
fusion.5: Q @ K^T + mask + reduce_max → 2,559 bundles
fusion.2: exp + reduce_sum → 1,925 bundles
fusion: normalize + S @ V → 5,761 bundles
─────────────────────────────────────────────────────────
Total: ~10,245 bundles
Each fusion requires:
Reading inputs from HBM
Computing results
Writing outputs back to HBM
Synchronization before next fusion
Pallas: 1 Custom Kernel (varies by block size)
splash_mha_fwd (block_size=512) → 1,651 bundles
splash_mha_fwd (block_size=1024) → 4,302 bundles
splash_mha_fwd (block_size=2048) → 15,153 bundles
Everything happens in a single kernel with explicit VMEM management. No intermediate HBM round-trips.
VLIW Bundle Comparison
The block size tradeoff is fascinating:
block_size=512: More iterations (4×4=16 per head), but each iteration is simple. Fewest total bundles.
block_size=1024: Balanced (2×2=4 per head). Still significantly fewer bundles than reference.
block_size=2048: One iteration covers entire sequence (1×1=1 per head). More bundles than reference because the large tile operations can’t amortize overhead.
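The numbers in this comparison can be tabulated directly from the bundle counts listed above. The script below is just that arithmetic; the variable and function names are my own, and the tile count assumes square blocks with no tiles skipped by the mask.

```python
SEQ_LEN = 2048
reference_bundles = 2559 + 1925 + 5761  # fusion.5 + fusion.2 + fusion
splash_bundles = {512: 1651, 1024: 4302, 2048: 15153}


def tiles_per_head(seq_len, block):
    """Number of (Q block, K block) tiles, assuming a dense square grid."""
    return (seq_len // block) ** 2


for block, bundles in splash_bundles.items():
    ratio = reference_bundles / bundles
    print(f"block={block}: {tiles_per_head(SEQ_LEN, block)} tiles/head, "
          f"reference/splash bundle ratio = {ratio:.2f}")
```

A ratio above 1 means the Splash kernel has fewer bundles than the three reference fusions combined.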
HBM Bandwidth Analysis
For our test case (8 heads, 2048 seq_len, 128 head_dim):
Reference HBM Traffic:
Inputs:
Q, K, V: 3 × 8 × 2048 × 128 × 2 bytes = 6 MB
Intermediates (written then read):
S matrix after fusion.5: 8 × 2048 × 2048 × 4 = 128 MB (write)
S matrix for fusion.2: 8 × 2048 × 2048 × 4 = 128 MB (read)
exp(S) after fusion.2: 8 × 2048 × 2048 × 4 = 128 MB (write)
exp(S) for fusion: 8 × 2048 × 2048 × 4 = 128 MB (read)
Output:
O: 8 × 2048 × 128 × 4 = 8 MB
Total: ~6 + 512 + 8 = ~526 MB HBM traffic
Pallas HBM Traffic:
Inputs:
Q, K, V: 3 × 8 × 2048 × 128 × 2 bytes = 6 MB
Intermediates:
None! Everything stays in VMEM
Output:
O: 8 × 2048 × 128 × 4 = 8 MB
Total: ~14 MB HBM traffic
That’s a 37× reduction in HBM bandwidth. Since TPU performance is often memory-bound, this directly translates to speedup.
The Block Size Sweet Spot
Why does block_size=512 produce fewer bundles than block_size=1024, and why does block_size=2048 produce more bundles than reference attention?
At block_size=2048 (equal to seq_len), there’s no tiling—you compute the full [2048, 2048] attention matrix in one iteration. You’ve lost the memory benefit of online softmax but kept all its bookkeeping. It’s reference attention with extra overhead.
At block_size=512 vs 1024, the total math is the same—only the schedule changes. Smaller tiles reduce VMEM footprint, which improves double-buffering and DMA/compute overlap. Larger tiles reduce loop count but increase register pressure, making it harder for the backend to keep the MXUs busy while hiding memory latency.
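As a rough feel for the footprint argument, here is the size of a single f32 score tile at each block size. This is only a sketch under the assumption that the score tile dominates the per-iteration working set; the real kernel also holds Q, K, V tiles and accumulators, which this hypothetical helper ignores.

```python
def score_tile_mib(block_q, block_kv, bytes_per_element=4):
    """Size in MiB of one [block_q, block_kv] f32 score tile."""
    return block_q * block_kv * bytes_per_element / 2**20


for block in (512, 1024, 2048):
    print(block, score_tile_mib(block, block))
```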
Why Pallas Matters
Last post, the takeaway was: trust the compiler. This post’s takeaway is: know when not to.
XLA is doing real work on reference attention — it fuses the matmul with the mask and max reduction, it overlaps DMA with compute, it balances both MXUs. The 10,245 VLIW bundles represent a well-optimized implementation of the algorithm you wrote. The problem is the algorithm itself.
Pallas doesn’t replace XLA. For most operations, the automatic path is the right path. But when you’re hitting a fundamental algorithmic limit — when you know there’s a streaming formulation the compiler can’t discover — Pallas is how you escape. You write the tiled, memory-conscious kernel; Mosaic and LLO handle the rest.
The practical insights:
Block size matters, but depends on your shapes. For our test case (8 heads, seq_len=2048, head_dim=128), block_size=512 produced 6× fewer bundles than reference attention. block_size=2048 produced more bundles. Smaller tiles pipeline better and reduce register pressure — but the optimal block size depends on your sequence length, head count, and head dimension. Profile your actual workload.
HBM bandwidth is often the bottleneck. Reference attention moved 526MB through HBM; Splash moved 14MB. The 37× reduction in memory traffic is why online softmax wins, not the bundle count.
The compiler still does heavy lifting. Pallas gives you control over algorithm-level tiling; Mosaic handles vector layouts, MXU scheduling, and VLIW packing automatically. You’re not writing assembly.
The dump flags for tracing Pallas kernels:
import os

# HLO_PATH, LLO_PATH, and MOSAIC_PATH are dump directories you choose.
# Set both variables before importing jax so the backend picks them up.
os.environ["XLA_FLAGS"] = (
    f"--xla_dump_hlo_as_text "
    f"--xla_dump_to={HLO_PATH} "
    f"--xla_dump_hlo_pass_re=.* "
)
os.environ["LIBTPU_INIT_ARGS"] = (
    f"--xla_jf_dump_to={LLO_PATH} "
    f"--xla_jf_dump_hlo_text=true "
    f"--xla_jf_dump_llo_text=true "
    f"--xla_jf_dump_llo_html=false "
    f"--xla_jf_dump_llo_static_gaps=true "
    f"--xla_jf_emit_annotations=true "
    f"--xla_jf_debug_level=2 "
    f"--xla_mosaic_dump_to={MOSAIC_PATH} "
    f"--xla_mosaic_enable_dump_debug_info=true "
    f"--xla_mosaic_enable_llo_source_annotations=true"
)
The HLO shows an opaque custom-call; the interesting stuff is in the Mosaic MLIR dumps (0001-original.txt through 0012-post-finalize-llo.txt) and the final LLO bundles.
In principle, a compiler could learn online softmax. Someone could teach XLA to recognize the attention pattern, prove the streaming reformulation is valid, and generate the tiled kernel automatically. Until then, we write Pallas.
Questions? Message me on LinkedIn: https://www.linkedin.com/in/patrick-toulme-150b041a5/ or follow me on X: https://x.com/PatrickToulme






