[CK_TILE] Port hw independent changes from internal repo to develop branch (#3301)

* [CK_TILE] Port hw independent changes from internal repo to develop branch

It includes PR#96, #114, #120, #121.

* correct rebase error

[ROCm/composable_kernel commit: fc7bf0ab1c]
This commit is contained in:
linqunAMD
2025-12-13 01:28:37 +08:00
committed by GitHub
parent 1f2421c944
commit c6ab08a491
22 changed files with 465 additions and 310 deletions

View File

@@ -986,6 +986,8 @@ struct MoeSortingKernel
p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
}
}
__syncthreads();
smem_cumdup(num_experts) = smem_cumsum(num_experts);
// fill the p_sorted_token_ids/p_sorted_weights

View File

@@ -561,6 +561,7 @@ struct GroupedGemmKernel
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
block_sync_lds();
block_id = block_id + grid_size; // advance to next block
// NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
if(block_id >= cum_grid_size)

View File

@@ -631,6 +631,7 @@ struct StreamKKernel
tile_idx += kargs.tile_partitioner.get_grid())
{
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
block_sync_lds();
}
// Stream-K section
@@ -679,8 +680,8 @@ struct StreamKKernel
{
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
ck_tile::hip_check_error(hipGetDevice(&dev));
ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
int num_cu = dev_prop.multiProcessorCount;
return num_cu;
@@ -700,7 +701,7 @@ struct StreamKKernel
constexpr int min_block_per_cu = 1;
const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
hip_check_error(
ck_tile::hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
return max(occupancy, 1);

View File

@@ -280,7 +280,7 @@ struct UniversalGemmKernel
using Kernel = UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
const auto kernel = kentry<1, Kernel, KernelArgs>;
int occupancy;
hip_check_error(
ck_tile::hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0));
const int grid_size = get_available_compute_units(s) * occupancy;

View File

@@ -9,11 +9,35 @@
namespace ck_tile {
template <typename Problem>
struct BaseGemmPipelineAGmemBGmemCRegV1
{
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr bool UsePersistentKernel = false;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
{
return TailNumber::Empty;
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
{
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
}
};
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAGmemBGmemCRegV1
struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
{
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
@@ -48,14 +72,14 @@ struct GemmPipelineAGmemBGmemCRegV1
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Problem::VectorSizeA;
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Problem::VectorSizeB;
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }

View File

@@ -9,11 +9,34 @@
namespace ck_tile {
template <typename Problem>
struct BaseGemmPipelineAGmemBGmemCRegV2
{
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
{
return TailNumber::Empty;
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
{
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
}
};
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV2DefaultPolicy>
struct GemmPipelineAGmemBGmemCRegV2
struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2<Problem>
{
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;

View File

@@ -43,13 +43,14 @@ template <bool kPadM_,
bool UseStructuredSparsity_ = false,
bool UsePersistentKernel_ = false,
index_t NumWaveGroups_ = 1,
bool Preshuffle_ = false>
bool Preshuffle_ = false,
int VectorSize_ = 16>
struct TileGemmUniversalTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
static constexpr int _VectorSize = 16;
static constexpr int _VectorSize = VectorSize_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
using AsLayout = AsLayout_;

View File

@@ -425,7 +425,7 @@ struct BlockReduce2dCrossWarpSync
if constexpr(num_reduce_warps == 1)
return;
block_sync_lds();
// Each warp's lane 0 writes its partial results to shared memory
const index_t smem_offset = warp_id;
if(lane_id == 0)

View File

@@ -160,23 +160,23 @@ void dump_gemm_json_results(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_batched_gemm_json_results(const std::string& json_filename,
const std::string& op_name,
int M,
int N,
int K,
int stride_A,
int stride_B,
int stride_C,
int batch_stride_A,
int batch_stride_B,
int batch_stride_C,
int batch_count,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "batched_gemm_basic")
inline void dump_batched_gemm_json_results(const std::string& json_filename,
const std::string& op_name,
int M,
int N,
int K,
int stride_A,
int stride_B,
int stride_C,
int batch_stride_A,
int batch_stride_B,
int batch_stride_C,
int batch_count,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "batched_gemm_basic")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -218,20 +218,20 @@ void dump_grouped_gemm_json_results(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_flatmm_json_results(const std::string& json_filename,
const std::string& datatype,
int M,
int N,
int K,
int stride_A,
int stride_B,
int stride_C,
int kbatch,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "flatmm_basic")
inline void dump_flatmm_json_results(const std::string& json_filename,
const std::string& datatype,
int M,
int N,
int K,
int stride_A,
int stride_B,
int stride_C,
int kbatch,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "flatmm_basic")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -248,21 +248,22 @@ void dump_flatmm_json_results(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
const std::string& op_name,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideD0,
int StrideD1,
int StrideE,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "gemm_multi_d_fp16")
inline void
dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
const std::string& op_name,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideD0,
int StrideD1,
int StrideE,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "gemm_multi_d_fp16")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -280,14 +281,14 @@ void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_elementwise_json_results(const std::string& json_filename,
const std::string& prec,
int grid_size,
int block_size,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "elementwise")
inline void dump_elementwise_json_results(const std::string& json_filename,
const std::string& prec,
int grid_size,
int block_size,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "elementwise")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -298,22 +299,22 @@ void dump_elementwise_json_results(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_layernorm2d_fwd_json_results(const std::string& json_filename,
const std::string& prec_i,
const std::string& prec_o,
const std::string& prec_sm,
const std::string& prec_sy,
int m,
int n,
int x_stride,
int xr_stride,
int y_stride,
int yr_stride,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "layernorm2d_fwd")
inline void dump_layernorm2d_fwd_json_results(const std::string& json_filename,
const std::string& prec_i,
const std::string& prec_o,
const std::string& prec_sm,
const std::string& prec_sy,
int m,
int n,
int x_stride,
int xr_stride,
int y_stride,
int yr_stride,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "layernorm2d_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -357,13 +358,13 @@ void dump_reduce_json_results(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_permute_json_results(const std::string& json_filename,
const std::string& data_type,
bool pass,
float ave_time,
float tflop,
float gb_per_sec,
const std::string& kernel_name = "permute")
inline void dump_permute_json_results(const std::string& json_filename,
const std::string& data_type,
bool pass,
float ave_time,
float tflop,
float gb_per_sec,
const std::string& kernel_name = "permute")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -373,19 +374,19 @@ void dump_permute_json_results(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_topk_softmax_json(const std::string& json_filename,
const std::string& input_prec,
const std::string& weight_prec,
int tokens,
int experts,
int topk,
int stride_input,
int stride_output,
float ave_time,
float tflop,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "topk_softmax")
inline void dump_topk_softmax_json(const std::string& json_filename,
const std::string& input_prec,
const std::string& weight_prec,
int tokens,
int experts,
int topk,
int stride_input,
int stride_output,
float ave_time,
float tflop,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "topk_softmax")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -401,20 +402,20 @@ void dump_topk_softmax_json(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
const std::string& prec_str,
int m,
int n,
int x_stride,
int xr_stride,
int y_stride,
int yr_stride,
int use_model_sensitive_rmsnorm,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "rmsnorm2d_fwd")
inline void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
const std::string& prec_str,
int m,
int n,
int x_stride,
int xr_stride,
int y_stride,
int yr_stride,
int use_model_sensitive_rmsnorm,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "rmsnorm2d_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -431,19 +432,19 @@ void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_add_rmsnorm2d_rdquant_fwd_json(
const std::string& json_filename,
const std::string& input_data_type,
const std::string& quantized_data_type,
int m,
int n,
int stride,
float epsilon,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd")
inline void
dump_add_rmsnorm2d_rdquant_fwd_json(const std::string& json_filename,
const std::string& input_data_type,
const std::string& quantized_data_type,
int m,
int n,
int stride,
float epsilon,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -458,17 +459,17 @@ void dump_add_rmsnorm2d_rdquant_fwd_json(
END_JSON_DUMP_FILE();
}
void dump_smoothquant_json(const std::string& json_filename,
const std::string& prec_str,
int m,
int n,
int x_stride,
int y_stride,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "smoothquant")
inline void dump_smoothquant_json(const std::string& json_filename,
const std::string& prec_str,
int m,
int n,
int x_stride,
int y_stride,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "smoothquant")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -482,19 +483,19 @@ void dump_smoothquant_json(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_moe_sorting_json(const std::string& json_filename,
const std::string& index_prec,
const std::string& weight_prec,
const std::string& workspace_size,
int dispatch_policy,
int tokens,
int num_experts,
int topk,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "moe_sorting")
inline void dump_moe_sorting_json(const std::string& json_filename,
const std::string& index_prec,
const std::string& weight_prec,
const std::string& workspace_size,
int dispatch_policy,
int tokens,
int num_experts,
int topk,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "moe_sorting")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -510,19 +511,19 @@ void dump_moe_sorting_json(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_batched_transpose_json(const std::string& json_filename,
int N,
int C,
int H,
int W,
const std::string& layout_in,
const std::string& layout_out,
const std::string& prec,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "batched_transpose")
inline void dump_batched_transpose_json(const std::string& json_filename,
int N,
int C,
int H,
int W,
const std::string& layout_in,
const std::string& layout_out,
const std::string& prec,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "batched_transpose")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -538,19 +539,19 @@ void dump_batched_transpose_json(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_moe_smoothquant_json(const std::string& json_filename,
const std::string& prec_i,
const std::string& prec_o,
int tokens,
int hidden_size,
int stride,
int experts,
int topk,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "moe_smoothquant")
inline void dump_moe_smoothquant_json(const std::string& json_filename,
const std::string& prec_i,
const std::string& prec_o,
int tokens,
int hidden_size,
int stride,
int experts,
int topk,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "moe_smoothquant")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -566,26 +567,26 @@ void dump_moe_smoothquant_json(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_fused_moe_json(const std::string& json_filename,
const std::string& api_str,
const std::string& prec_str,
int tokens,
bool is_local_token,
int local_tokens,
int experts,
int topk,
int hidden_size,
int intermediate_size,
int stride,
int block_m,
int activation,
bool gate_only,
bool fused_quant,
bool pass,
float ave_time,
float tflops,
float tb_per_sec,
const std::string& kernel_name = "fused_moe")
inline void dump_fused_moe_json(const std::string& json_filename,
const std::string& api_str,
const std::string& prec_str,
int tokens,
bool is_local_token,
int local_tokens,
int experts,
int topk,
int hidden_size,
int intermediate_size,
int stride,
int block_m,
int activation,
bool gate_only,
bool fused_quant,
bool pass,
float ave_time,
float tflops,
float tb_per_sec,
const std::string& kernel_name = "fused_moe")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -610,29 +611,29 @@ void dump_fused_moe_json(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_fmha_fwd_json_results(const std::string& json_filename,
const std::string& prec,
const std::string& mode,
const std::string& io_layout,
int batch,
int nhead,
int nhead_k,
int seqlen_qs,
int seqlen_ks,
int seqlen_kpads,
int hdim_q,
int hdim_v,
float scale_s,
float p_drop,
bool lse,
const std::string& qscale,
const std::string& bias,
const std::string& vlayout,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "fmha_fwd")
inline void dump_fmha_fwd_json_results(const std::string& json_filename,
const std::string& prec,
const std::string& mode,
const std::string& io_layout,
int batch,
int nhead,
int nhead_k,
int seqlen_qs,
int seqlen_ks,
int seqlen_kpads,
int hdim_q,
int hdim_v,
float scale_s,
float p_drop,
bool lse,
const std::string& qscale,
const std::string& bias,
const std::string& vlayout,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "fmha_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -658,33 +659,33 @@ void dump_fmha_fwd_json_results(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_fmha_bwd_json_results(const std::string& json_filename,
const std::string& data_type,
const std::string& mode,
const std::string& i_perm,
const std::string& o_perm,
int batch,
int nhead,
int nhead_k,
int seqlen_q,
int seqlen_k,
int hdim_q,
int hdim_v,
float scale,
const std::string& bias,
bool use_dbias,
float p_drop,
bool s_randval,
bool deterministic,
const std::string& mask,
int mask_left,
int mask_right,
int workspace_size,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "fmha_bwd")
inline void dump_fmha_bwd_json_results(const std::string& json_filename,
const std::string& data_type,
const std::string& mode,
const std::string& i_perm,
const std::string& o_perm,
int batch,
int nhead,
int nhead_k,
int seqlen_q,
int seqlen_k,
int hdim_q,
int hdim_v,
float scale,
const std::string& bias,
bool use_dbias,
float p_drop,
bool s_randval,
bool deterministic,
const std::string& mask,
int mask_left,
int mask_right,
int workspace_size,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "fmha_bwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);