diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 8f1ddf84d6..8a75410f7e 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -224,10 +224,18 @@ struct UnifiedAttentionPipeline // ... together with tensor distribution. tensor dist should able to overwrite this static constexpr ck_tile::index_t kAlignmentQ = kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + // The DRAM-view vector length must match the K/V load distribution's + // KVector, which the FA4 decoupling widens to the per-warp-group load count + // (GetK/VLoadNumWarps). Passing the same warp count here keeps the global + // buffer_load width in lock-step with the async-copy descriptors. static constexpr ck_tile::index_t kAlignmentK = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + kPadHeadDimQ ? 1 + : Policy::template GetAlignmentK()>(); static constexpr ck_tile::index_t kAlignmentV = - kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + kPadHeadDimV ? 1 + : Policy::template GetAlignmentV()>(); static constexpr ck_tile::index_t kAlignmentO = kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp index beeac94c82..a8b81f69fb 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp @@ -63,7 +63,9 @@ struct UnifiedAttentionPipelineDefaultPolicy // // The selector below picks dwordx4 whenever it tiles cleanly and falls // back to dword (matches the historical FP8 path) on the prefill tier. 
- template + template CK_TILE_DEVICE static constexpr index_t GetKVAlignmentBytes() { #if defined(__gfx950__) @@ -71,7 +73,13 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t tile_elems = Problem::UnifiedAttentionShape::kPageBlockSize * Problem::UnifiedAttentionShape::kHeadDim; - constexpr index_t block_size = Problem::kBlockSize; + // Threads that actually cooperate on this tile's load. Default is the + // whole block (kBlockSize), but the FA4 per-warp-group decoupling has a + // single 4-warp group fill the tile by itself, so the width budget is + // that group's thread count -- which lets the small FP8 prefill tile + // (4 KB / 256 thr = 16 B/thr) finally tile cleanly at dwordx4 instead of + // falling back to 4x as many dword loads. + constexpr index_t block_size = NumLoadThreads; // KVector_elems for 16 B/lane = 16 / ElementSizeInBytes. // NumIssues * KVector_bytes * kBlockSize == tile_bytes, // so the divisibility check is tile_elems * ElementSizeInBytes @@ -91,23 +99,30 @@ struct UnifiedAttentionPipelineDefaultPolicy #endif } - template + // NumWarps = the waves that cooperate on the K/V load (default = the full + // block). The FA4 decoupling loads K/V with a single 4-warp group, so the + // load-path callers pass that group's warp count to widen the load. 
+ template CK_TILE_DEVICE static constexpr auto GetAlignmentK() { using namespace ck_tile; using KDataType = remove_cvref_t; + constexpr index_t NumLoadThreads = NumWarps * get_warp_size(); constexpr index_t MaxReadSizeInBytes = - GetKVAlignmentBytes(); + GetKVAlignmentBytes(); return MaxReadSizeInBytes / sizeof(KDataType); } - template + template CK_TILE_DEVICE static constexpr auto GetAlignmentV() { using namespace ck_tile; using VDataType = remove_cvref_t; + constexpr index_t NumLoadThreads = NumWarps * get_warp_size(); constexpr index_t MaxReadSizeInBytes = - GetKVAlignmentBytes(); + GetKVAlignmentBytes(); return MaxReadSizeInBytes / sizeof(VDataType); } @@ -156,7 +171,7 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t NumWarps = NumWarpsOverride; constexpr index_t WarpSize = ck_tile::get_warp_size(); - constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t KVector = GetAlignmentK(); // this is for global load static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave @@ -197,7 +212,7 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t NumWarps = NumWarpsOverride; constexpr index_t WarpSize = ck_tile::get_warp_size(); // 64 - constexpr index_t KVector = GetAlignmentV(); // this is for global load + constexpr index_t KVector = GetAlignmentV(); // this is for global load // 4 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); @@ -418,7 +433,7 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t WarpSize = ck_tile::get_warp_size(); [[maybe_unused]] constexpr index_t KPack = GetSmemKPackK(); // this is for lds - constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t KVector = GetAlignmentK(); // this is for global load constexpr index_t kPad = kKLdsPadInBytes / sizeof(typename Problem::KDataType); // for 
async-copy, this pad is between warps. @@ -478,7 +493,7 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t WarpSize = ck_tile::get_warp_size(); constexpr index_t KPack = GetSmemKPackK(); // this is for lds - constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t KVector = GetAlignmentK(); // this is for global load constexpr index_t kPad = kKLdsPadInBytes / sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps @@ -602,7 +617,7 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t WarpSize = ck_tile::get_warp_size(); [[maybe_unused]] constexpr index_t KPack = GetSmemVPackK(); // this is for lds - constexpr index_t KVector = GetAlignmentV(); // this is for global load + constexpr index_t KVector = GetAlignmentV(); // this is for global load constexpr index_t kPad = kVLdsPadInBytes / sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps. @@ -661,7 +676,7 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t WarpSize = ck_tile::get_warp_size(); constexpr index_t KPack = GetSmemVPackK(); // this is for lds - constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t KVector = GetAlignmentV(); // this is for global load constexpr index_t kPad = kVLdsPadInBytes / sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps