diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py
index 99d53dc8c4..e2cc10226a 100644
--- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py
+++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py
@@ -37,6 +37,12 @@ K0_MAX_SUBMAX_MAP = {
     256: 256
 }
 
+SEQLENQ_MAP = {
+    "16" : "16",
+    "32" : "32",
+    # "64" : "64"
+}
+
 FMHA_FWD_DECODE_PIPELINE_MAP = {
     "decode_qr" : "ck_tile::BlockFmhaFwdDecodePipelineQRKSVS",
 }
@@ -288,7 +294,7 @@ float fmha_fwd_decode(fmha_fwd_decode_traits t, fmha_fwd_decode_args a, const ck
 """
 FMHA_FWD_DECODE_API_INNER_DISPATCH="""
 {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
-            ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
+            ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && (a.seqlen_q <= {F_bm0}) && ({F_dvcheck})) {{
         using traits_ = fmha_fwd_decode_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
 
         // get combine kernel tile sizes
@@ -346,6 +352,7 @@ class FmhaFwdDecodeApiTrait:
             f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
             f'{self.dvpad}-{self.pagedkv}'
 
+    # sequence length as non-fast-changing dimension, we can always rely on instruction level OOB guard
     @property
     def scheck(self) -> str:
         if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
@@ -362,12 +369,15 @@ class FmhaFwdDecodeApiTrait:
             else :                  return 'true'
         else:   assert False
 
+    # head dimension as fast-changing dimension, we assume it is a multiple of 8
     @property
     def dcheck(self) -> str:
         if self.pipeline_tag in ['decode_qr', 'qr_nwarp_sshuffle']:
             bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
             if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
             else :               return f'a.hdim_q % {bk0submax} == 0'
+            # if self.skpad == 't' : return 'true'
+            # else : return 'true'
         else:   assert False
 
     @property
@@ -376,6 +386,8 @@ class FmhaFwdDecodeApiTrait:
             bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
             if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
             else :                return f'a.hdim_v % {bk0submax} == 0'
+            # if self.skpad == 't' : return 'true'
+            # else : return 'true'
         else:   assert False
 
 @dataclass
@@ -637,19 +649,17 @@ class FmhaFwdSplitKVCombineKernel:
 def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
     if dtype == 'fp16' or dtype == 'bf16':
         return {
-            # '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 32, -1),
-            '64' : FmhaFwdTileSize(16, 64, 64, 64, 64, 64, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
-            # '64' : FmhaFwdTileSize(32, 64, 64, 64, 64, 64, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
-            # '64' : FmhaFwdTileSize(64, 64, 64, 64, 64, 64, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
-            # '64' : FmhaFwdTileSize(128, 64, 64, 64, 64, 64, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
-            # '64' : FmhaFwdTileSize(256, 64, 64, 64, 64, 64, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
-            ### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
-            '128' : FmhaFwdTileSize(16, 64, 64, 128, 64, 128, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
-            # '128' : FmhaFwdTileSize(32, 64, 64, 128, 64, 128, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
-            # '128' : FmhaFwdTileSize(64, 64, 64, 128, 64, 128, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
-            # '128' : FmhaFwdTileSize(128, 64, 64, 128, 64, 128, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
-            # '128' : FmhaFwdTileSize(256, 64, 64, 128, 64, 128, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
-            # '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
+            '64': {
+                # Specialize for different SeqQ
+                '16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
+                '32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
+                # '64': FmhaFwdTileSize(64, 64, 64, 64, 64, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
+            },
+            '128': {
+                '16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
+                '32': FmhaFwdTileSize(32, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
+                # '64': FmhaFwdTileSize(64, 64, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
+            },
         }
     else:
         return None
@@ -684,6 +694,7 @@ def get_fwd_decode_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> T
                 for lse in ['t', 'f']:
                     if hdim in [64, 128]: ### [32, 64, 96, 128]:
                         pipelines.append(Pipeline('decode_qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, squant, pagedkv, mask))
+                        pipelines.append(Pipeline('decode_qr', 'row', 'f', 'f', 't', 't', logits, bias, lse, squant, pagedkv, mask))
                     else:
                         assert False
         else:
@@ -698,8 +709,8 @@ def get_fwd_decode_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> T
         if d == None:
             continue
         #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
-        for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
-            tile = d[hdim_str]
+        for hdim_str, mode, seqlenq in itertools.product(d.keys(), MODE_MAP.keys(), SEQLENQ_MAP.keys()):
+            tile = d[hdim_str][seqlenq]
             hdim = int(hdim_str)
             for pipeline in get_pipelines(dtype, hdim):
                 if mode == "group":
@@ -762,7 +773,7 @@ def get_fwd_decode_combine_blobs(kernel_filter : Optional[str], receipt) -> List
     pipelines = []
     if dtype in ['fp16', 'bf16']:
         # for spad, dvpad, lse in itertools.product(["t", "f"], ["t", "f"], ["t", "f"]):
-        for spad, dvpad, lse in itertools.product(["f"], ["f"], ["t", "f"]):
+        for spad, dvpad, lse in itertools.product(["f"], ["t", "f"], ["t", "f"]):
            pipelines.append(Pipeline('unused', spad, dvpad, lse, squant))
     elif dtype in ['fp8', 'bf8']:
         # no need lse kernels
diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp
index 8b07a78d39..0c084564cb 100644
--- a/example/ck_tile/01_fmha/fmha_fwd.cpp
+++ b/example/ck_tile/01_fmha/fmha_fwd.cpp
@@ -696,6 +696,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
         ck_tile::FillUniformDistribution{0.f, 1.f, seed}(vnew_host);
         ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host);
     }
+    else if(init_method == "exp" || init_method == "99")
+    {
+        ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host);
+        ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host);
+        ck_tile::FillUniformDistribution{0.f, 1.f, seed}(knew_host);
+        ck_tile::FillUniformDistribution{1.f, 1.f, seed}(v_host);
+        ck_tile::FillUniformDistribution{1.f, 1.f, seed}(vnew_host);
+        ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host);
+    }
     else if(init_method == "nf")
     {
         ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q_host);
@@ -1136,7 +1145,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, "
                   << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
-                  << " GB/s" << std::flush;
+                  << " GB/s" << std::flush << std::endl;
 
     if(do_validation == 0)
     {
diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp
index 7111eed596..ad23617590 100644
--- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp
+++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp
@@ -1315,6 +1315,17 @@ enum struct amd_buffer_coherence_enum
     glc     = 1,
     slc     = 2,
    glc_slc = 3,
+    // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
+    // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
+    // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
+    WAVE_NT0   = 0,
+    WAVE_NT1   = 2,
+    GROUP_NT0  = 1,
+    GROUP_NT1  = 3,
+    DEVICE_NT0 = 8,
+    DEVICE_NT1 = 10,
+    SYSTEM_NT0 = 9,
+    SYSTEM_NT1 = 11,
 };
 
 template
(nhead_ratio_qk * seqlen_q, // hdim_q)
+            // We expect Q data reuse among different KV splits in the decode case.
             const auto view = make_naive_tensor_view(
                 q_ptr,
                 make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q),
@@ -755,7 +756,8 @@ struct FmhaFwdDecodeKernel
         }();
 
         const auto make_k_dram = [&](const KDataType* data, index_t height) {
-            const auto k_dram_naive = make_naive_tensor_view(
+            // We don't expect K data reuse among different blocks in the decode case.
+            const auto k_dram_naive = make_naive_tensor_view(
                 data, // will update this pointer if using paged-kvcache
                 make_tuple(height, kargs.hdim_q),
                 make_tuple(kargs.stride_k, 1),
@@ -781,7 +783,8 @@ struct FmhaFwdDecodeKernel
         const auto make_v_dram = [&](const VDataType* data, index_t length) {
             if constexpr(std::is_same_v)
             {
-                const auto v_dram_naive = make_naive_tensor_view(
+                // We don't expect V data reuse among different blocks in the decode case.
+                const auto v_dram_naive = make_naive_tensor_view(
                     data, // will update this pointer if using paged-kvcache
                     make_tuple(length, kargs.hdim_v),
                     make_tuple(kargs.stride_v, 1),
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp
index 112dd48a01..a95277f620 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp
@@ -44,6 +44,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
     static constexpr index_t kK1           = BlockFmhaShape::kK1;
     static constexpr index_t kQKHeaddim    = BlockFmhaShape::kQKHeaddim;
     static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
+    static constexpr index_t kNWarp = BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
+    static constexpr index_t kNXdl  = BlockFmhaShape::Gemm0WarpTile::at(number<1>{});
 
     static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
@@ -546,13 +548,21 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
             }
             __builtin_amdgcn_sched_barrier(0);
+
+            // When kNWarp == 1 and kNXdl == 32, the GEMM0 output naturally fits the input of GEMM1.
+            // Otherwise shuffle through LDS so that the tile layout is consistent with that required by Gemm1.
+            auto s_new = [&](){
+                if constexpr ( !((kNWarp==1) && (kNXdl == 32)) ){
+                    auto s = cast_tile(s_acc); // S{j}
 
-            const auto s = cast_tile(s_acc); // S{j}
-
-            // shuffle through LDS so that the tile layout is consistent with required by Gemm1
-            store_tile(s_write_lds_window, s);
-            block_sync_lds();
-            auto s_new = load_tile(s_read_lds_window);
+                    store_tile(s_write_lds_window, s);
+                    block_sync_lds();
+                    return load_tile(s_read_lds_window);
+                }
+                else{
+                    return cast_tile(s_acc); // S{j}
+                }
+            }();
 
             auto m_local = block_tile_reduce(
                 s_new,
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp
index 6521ade6c8..ea499c4e9d 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp
@@ -157,7 +157,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
         constexpr index_t MWarp = config.template at<1>();
         constexpr index_t NWarp = config.template at<2>();
 
-        static_assert(MWarp == 1, "Check failed!");
+        // static_assert(MWarp == 1, "Check failed!");
 
         constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
         constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;