diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 35e8c1be49..22b12eb430 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -36,7 +36,7 @@ DTYPE_BITS = { K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} -SUPPORTED_PAGE_SIZE = [1, 16, 1024] +SUPPORTED_PAGE_SIZE = [1, 16, 32, 1024] SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"] SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"] KV_MEMORY_LAYOUT_ENUM_MAP = { diff --git a/example/ck_tile/09_topk_softmax/CMakeLists.txt b/example/ck_tile/09_topk_softmax/CMakeLists.txt index 669e794e33..cce2c53ba4 100644 --- a/example/ck_tile/09_topk_softmax/CMakeLists.txt +++ b/example/ck_tile/09_topk_softmax/CMakeLists.txt @@ -1,24 +1,11 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT +add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp) +target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) + set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) # list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) - -add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp) -target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) target_compile_options(tile_example_topk_softmax PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS}) - -add_executable(tile_example_topk_softmax_decode - topk_softmax_decode.cpp - topk_softmax_decode_api.cpp - topk_softmax_api.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/../13_moe_sorting/moe_sorting_api.cpp) -target_include_directories(tile_example_topk_softmax_decode PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/ - ${CMAKE_CURRENT_SOURCE_DIR}/../13_moe_sorting/) -target_compile_definitions(tile_example_topk_softmax_decode PRIVATE - CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID=1 - MOE_SORTING_FMOE_2D_BUF=1) -target_compile_options(tile_example_topk_softmax_decode PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS}) diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_decode.cpp b/example/ck_tile/09_topk_softmax/topk_softmax_decode.cpp deleted file mode 100644 index 74d86b13e0..0000000000 --- a/example/ck_tile/09_topk_softmax/topk_softmax_decode.cpp +++ /dev/null @@ -1,314 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" -#include "ck_tile/ops/reduce.hpp" -#include "topk_softmax_decode_api.hpp" -#include "topk_softmax_api.hpp" -#include "moe_sorting_api.hpp" - -#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID -#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1 -#endif -#include "ck_tile/host/reference/reference_moe_sorting.hpp" - -// CPU reference: softmax -> topk -> moe_sorting -template -bool reference_fused(const ck_tile::HostTensor& x_host, - ck_tile::index_t topk, - ck_tile::index_t num_experts, - ck_tile::index_t unit_size, - ck_tile::HostTensor& ref_sorted_ids, - ck_tile::HostTensor& ref_sorted_weights, - ck_tile::HostTensor& ref_sorted_expert_ids, - ck_tile::index_t& ref_unit_cnt) -{ - auto probs = ck_tile::reference_softmax(x_host); - - ck_tile::HostTensor topk_vals({1, topk}); - ck_tile::HostTensor topk_idxs({1, topk}); - ck_tile::reference_topk(probs, topk_vals, topk_idxs, topk); - - ck_tile::HostTensor local_expert_mask({num_experts}); - ref_unit_cnt = 0; - ck_tile::reference_moe_sorting( - topk_idxs, - topk_vals, - local_expert_mask, - ref_sorted_ids, - ref_sorted_weights, - ref_sorted_expert_ids, - ref_unit_cnt, - num_experts, - unit_size, - 1, - false, - true); - - return true; -} - -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("v", "1", "do CPU validation") - .insert("pr_i", "bf16", "input data type: fp16/bf16") - .insert("e", "128", "number of experts") - .insert("k", "8", "topk") - .insert("unit", "32", "unit_size (block_size_M)") - .insert("model_dim", "7168", "model dimension for moe_buf zeroing") - .insert("seed", "-1", "random seed, -1 = random") - .insert("warmup", "5", "warmup iterations") - .insert("repeat", "20", "benchmark iterations"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - -template -bool run_test(ck_tile::ArgParser args) -{ - int validate = args.get_int("v"); - std::string pr = args.get_str("pr_i"); - int experts = args.get_int("e"); - int topk = args.get_int("k"); - int unit_size = args.get_int("unit"); - int model_dim = args.get_int("model_dim"); - int seed = args.get_int("seed"); - int warmup = args.get_int("warmup"); - int repeat = args.get_int("repeat"); - - if(seed < 0) - seed = std::time(nullptr); - - if(topk > experts) - { - printf("topk %d > experts %d, skip\n", topk, experts); - return false; - } - - int tokens = 1; - int max_num_tokens_padded = topk + experts * unit_size - topk; - int max_num_m_blocks = (max_num_tokens_padded + unit_size - 1) / unit_size; - - // Host tensors - ck_tile::HostTensor x_host({1, experts}); - { - auto rng = ck_tile::FillUniformDistribution_Unique{ - -5.f, 5.f, static_cast(seed)}; - ck_tile::HostTensor row({experts}); - rng(row); - std::copy(row.begin(), row.end(), x_host.begin()); - } - - // ---------- Device buffers (shared) ---------- - ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes()); - x_dev.ToDevice(x_host.data()); - - // ---------- Fused kernel buffers ---------- - ck_tile::DeviceMem fused_sorted_ids(max_num_tokens_padded * sizeof(IndexType)); - ck_tile::DeviceMem fused_sorted_weights(max_num_tokens_padded * sizeof(WeightType)); - ck_tile::DeviceMem fused_sorted_expert_ids(max_num_m_blocks * sizeof(IndexType)); - ck_tile::DeviceMem fused_num_valid(2 * sizeof(IndexType)); - ck_tile::DeviceMem fused_moe_buf(model_dim * sizeof(WeightType)); - { - std::vector ones(model_dim, 1.0f); - fused_moe_buf.ToDevice(ones.data()); - } - - // ---------- Two-kernel baseline buffers ---------- - ck_tile::DeviceMem topk_w_dev(topk * sizeof(WeightType)); - ck_tile::DeviceMem topk_i_dev(topk * sizeof(IndexType)); - ck_tile::DeviceMem base_sorted_ids(max_num_tokens_padded * sizeof(IndexType)); - ck_tile::DeviceMem base_sorted_weights(max_num_tokens_padded * sizeof(WeightType)); - ck_tile::DeviceMem base_sorted_expert_ids(max_num_m_blocks * sizeof(IndexType)); - ck_tile::DeviceMem base_num_valid(2 * sizeof(IndexType)); - ck_tile::DeviceMem base_moe_buf(model_dim * sizeof(WeightType)); - int ws_size = moe_sorting_get_workspace_size(tokens, experts, topk, 0); - ck_tile::DeviceMem base_ws(ws_size > 0 ? ws_size : 1); - if(ws_size > 0) - base_ws.SetZero(); - - ck_tile::stream_config sc{nullptr, true, 0, warmup, repeat}; - - // ====================== Fused kernel ====================== - topk_softmax_decode_trait fused_trait{pr, "fp32", experts, "softmax"}; - topk_softmax_decode_kargs fused_karg{ - x_dev.GetDeviceBuffer(), - experts, - topk, - experts, - true, - fused_sorted_ids.GetDeviceBuffer(), - fused_sorted_weights.GetDeviceBuffer(), - fused_sorted_expert_ids.GetDeviceBuffer(), - fused_num_valid.GetDeviceBuffer(), - fused_moe_buf.GetDeviceBuffer(), - unit_size, - model_dim, - static_cast(sizeof(WeightType))}; - - float ms_fused = topk_softmax_decode(fused_trait, fused_karg, sc); - - // ============= Two-kernel baseline: topk + moe_sorting ============= - topk_softmax_trait ts_trait{pr, "fp32", experts, "softmax"}; - topk_softmax_kargs ts_karg{ - x_dev.GetDeviceBuffer(), - topk_w_dev.GetDeviceBuffer(), - topk_i_dev.GetDeviceBuffer(), - tokens, - experts, - topk, - experts, - topk}; - - moe_sorting_trait ms_trait{"int32", "fp32", false, true, 0}; - moe_sorting_args ms_arg{ - topk_i_dev.GetDeviceBuffer(), - topk_w_dev.GetDeviceBuffer(), - nullptr, - nullptr, - base_sorted_ids.GetDeviceBuffer(), - base_sorted_weights.GetDeviceBuffer(), - base_sorted_expert_ids.GetDeviceBuffer(), - base_num_valid.GetDeviceBuffer(), - base_moe_buf.GetDeviceBuffer(), - ws_size > 0 ? base_ws.GetDeviceBuffer() : nullptr, - tokens, - unit_size, - experts, - topk, - model_dim, - static_cast(sizeof(WeightType))}; - - // Time the two kernels together using launch_kernel with two lambdas - auto sc_sub = ck_tile::stream_config{nullptr, false, 0, 0, 1}; - float ms_baseline = ck_tile::launch_kernel( - sc, - [&](const ck_tile::stream_config&) { topk_softmax(ts_trait, ts_karg, sc_sub); }, - [&](const ck_tile::stream_config&) { moe_sorting(ms_trait, ms_arg, sc_sub); }); - - float speedup = (ms_baseline > 0 && ms_fused > 0) ? ms_baseline / ms_fused : 0; - printf("[%s] E:%d, k:%d, unit:%d | fused:%.4fms baseline(topk+sort):%.4fms speedup:%.2fx", - pr.c_str(), - experts, - topk, - unit_size, - ms_fused, - ms_baseline, - speedup); - - if(ms_fused < 0 || ms_baseline < 0) - { - printf(" (not supported)\n"); - return false; - } - - // ====================== Validation ====================== - bool pass = true; - if(validate) - { - ck_tile::HostTensor sorted_ids_host({max_num_tokens_padded}); - ck_tile::HostTensor sorted_weights_host({max_num_tokens_padded}); - ck_tile::HostTensor sorted_expert_ids_host({max_num_m_blocks}); - ck_tile::HostTensor num_valid_host({2}); - std::vector moe_buf_host(model_dim); - - fused_sorted_ids.FromDevice(sorted_ids_host.data()); - fused_sorted_weights.FromDevice(sorted_weights_host.data()); - fused_sorted_expert_ids.FromDevice(sorted_expert_ids_host.data()); - fused_num_valid.FromDevice(num_valid_host.data()); - fused_moe_buf.FromDevice(moe_buf_host.data()); - - ck_tile::HostTensor ref_sorted_ids({max_num_tokens_padded}); - ck_tile::HostTensor ref_sorted_weights({max_num_tokens_padded}); - ck_tile::HostTensor ref_sorted_expert_ids({max_num_m_blocks}); - IndexType sentinel = static_cast((1 & 0x00ffffff) | ((topk & 0xff) << 24)); - std::fill(ref_sorted_ids.begin(), ref_sorted_ids.end(), sentinel); - std::fill(ref_sorted_weights.begin(), ref_sorted_weights.end(), WeightType(0)); - std::fill(ref_sorted_expert_ids.begin(), ref_sorted_expert_ids.end(), -1); - - ck_tile::index_t ref_unit_cnt = 0; - reference_fused( - x_host, topk, experts, unit_size, - ref_sorted_ids, ref_sorted_weights, ref_sorted_expert_ids, ref_unit_cnt); - - int num_valid_padded = num_valid_host(0); - int num_valid_tokens = num_valid_host(1); - - if(num_valid_padded != ref_unit_cnt) - { - printf(" FAIL:num_valid[0] got %d ref %d;", num_valid_padded, ref_unit_cnt); - pass = false; - } - if(num_valid_tokens != 1) - { - printf(" FAIL:num_valid[1] got %d;", num_valid_tokens); - pass = false; - } - - int n_tiles = num_valid_padded / unit_size; - for(int i = 1; i < n_tiles; i++) - { - if(sorted_expert_ids_host(i) < sorted_expert_ids_host(i - 1)) - { - printf(" FAIL:expert_ids not ascending;"); - pass = false; - break; - } - } - - WeightType wsum = 0; - for(int i = 0; i < topk; i++) - wsum += sorted_weights_host(i * unit_size); - if(std::abs(wsum - 1.0f) > 1e-3f) - { - printf(" FAIL:wsum=%.6f;", static_cast(wsum)); - pass = false; - } - - bool buf_zeroed = std::all_of( - moe_buf_host.begin(), moe_buf_host.end(), [](float v) { return v == 0.0f; }); - if(!buf_zeroed) - { - printf(" FAIL:moe_buf not zeroed;"); - pass = false; - } - } - - printf(" valid:%s\n", pass ? "y" : "n"); - fflush(stdout); - return pass; -} - -int main(int argc, char** argv) -{ - auto [result, args] = create_args(argc, argv); - if(!result) - return -1; - - std::string pr = args.get_str("pr_i"); - bool r = true; - - if(pr == "fp16") - r &= run_test(args); - else if(pr == "bf16") - r &= run_test(args); - else - { - printf("unsupported pr_i: %s\n", pr.c_str()); - return -1; - } - - return r ? 0 : -1; -} diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_decode_api.cpp b/example/ck_tile/09_topk_softmax/topk_softmax_decode_api.cpp deleted file mode 100644 index e593556099..0000000000 --- a/example/ck_tile/09_topk_softmax/topk_softmax_decode_api.cpp +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "topk_softmax_decode_api.hpp" - -#define TOPK_SOFTMAX_DECODE_DISPATCH(experts_, use_softmax_) \ - constexpr ck_tile::index_t ts_experts = experts_; \ - constexpr bool ts_use_softmax = use_softmax_; \ - using ts_problem = ck_tile::TopkSoftmaxWarpPerRowProblem; \ - using ts_pipeline = ck_tile::TopkSoftmaxDecodePipeline; \ - \ - using kernel = ck_tile::TopkSoftmaxDecodeKernel; \ - \ - auto kargs = kernel::MakeKargs(a); \ - \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(); \ - \ - float ave_time = \ - ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \ - \ - return ave_time; - -#define TOPK_SOFTMAX_DECODE_EXPERT_LADDER(use_softmax_) \ - if(t.experts <= 8) \ - { \ - TOPK_SOFTMAX_DECODE_DISPATCH(8, use_softmax_) \ - } \ - else if(t.experts <= 16) \ - { \ - TOPK_SOFTMAX_DECODE_DISPATCH(16, use_softmax_) \ - } \ - else if(t.experts <= 32) \ - { \ - TOPK_SOFTMAX_DECODE_DISPATCH(32, use_softmax_) \ - } \ - else if(t.experts <= 64) \ - { \ - TOPK_SOFTMAX_DECODE_DISPATCH(64, use_softmax_) \ - } \ - else if(t.experts <= 128) \ - { \ - TOPK_SOFTMAX_DECODE_DISPATCH(128, use_softmax_) \ - } \ - else if(t.experts <= 192) \ - { \ - TOPK_SOFTMAX_DECODE_DISPATCH(192, use_softmax_) \ - } \ - else if(t.experts <= 256) \ - { \ - TOPK_SOFTMAX_DECODE_DISPATCH(256, use_softmax_) \ - } \ - else if(t.experts <= 512) \ - { \ - TOPK_SOFTMAX_DECODE_DISPATCH(512, use_softmax_) \ - } \ - else if(t.experts <= 1024) \ - { \ - TOPK_SOFTMAX_DECODE_DISPATCH(1024, use_softmax_)\ - } - -float topk_softmax_decode(topk_softmax_decode_trait t, - topk_softmax_decode_kargs a, - ck_tile::stream_config s) -{ - if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "softmax") - { - using ts_input_type = ck_tile::fp16_t; - using ts_weight_type = float; - using ts_index_type = ck_tile::index_t; - TOPK_SOFTMAX_DECODE_EXPERT_LADDER(true) - } - else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "softmax") - { - using ts_input_type = ck_tile::bf16_t; - using ts_weight_type = float; - using ts_index_type = ck_tile::index_t; - TOPK_SOFTMAX_DECODE_EXPERT_LADDER(true) - } - else if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid") - { - using ts_input_type = ck_tile::fp16_t; - using ts_weight_type = float; - using ts_index_type = ck_tile::index_t; - TOPK_SOFTMAX_DECODE_EXPERT_LADDER(false) - } - else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "sigmoid") - { - using ts_input_type = ck_tile::bf16_t; - using ts_weight_type = float; - using ts_index_type = ck_tile::index_t; - TOPK_SOFTMAX_DECODE_EXPERT_LADDER(false) - } - return -1; -} diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_decode_api.hpp b/example/ck_tile/09_topk_softmax/topk_softmax_decode_api.hpp deleted file mode 100644 index 73df8af421..0000000000 --- a/example/ck_tile/09_topk_softmax/topk_softmax_decode_api.hpp +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once -#include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" -#include "ck_tile/ops/topk_softmax.hpp" -#include - -struct topk_softmax_decode_trait -{ - std::string input_type; - std::string weight_type; // currently always float - int experts; - std::string activation; // "softmax" or "sigmoid" -}; - -struct topk_softmax_decode_kargs : public ck_tile::TopkSoftmaxDecodeHostArgs -{ -}; - -float topk_softmax_decode(topk_softmax_decode_trait t, - topk_softmax_decode_kargs a, - ck_tile::stream_config s); diff --git a/include/ck_tile/ops/unified_attention.hpp b/include/ck_tile/ops/unified_attention.hpp new file mode 100644 index 0000000000..6cfe814876 --- /dev/null +++ b/include/ck_tile/ops/unified_attention.hpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/unified_attention/block/block_masking.hpp" +#include "ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp" +#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp" +#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 937bfdf5f2..551dab5242 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -615,7 +615,7 @@ struct UnifiedAttentionPipeline } else { - auto casted = detail::cvt_pk_bf16_f32(x, y); + auto casted = cvt_pk_bf16_f32(x, y); sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; }