mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Port hw independent changes from internal repo to develop branch (#3301)
* [CK_TILE] Port hw independent changes from internal repo to develop branch
It includes PR#96, #114, #120, #121.
* correct rebase error
[ROCm/composable_kernel commit: fc7bf0ab1c]
This commit is contained in:
@@ -459,7 +459,7 @@ struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
|
||||
};
|
||||
|
||||
auto create_args()
|
||||
inline auto create_args()
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
|
||||
@@ -197,8 +197,8 @@ bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
return pass;
|
||||
}
|
||||
|
||||
std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t>
|
||||
parse_gemm_size(ck_tile::ArgParser& arg_parser)
|
||||
std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t> inline parse_gemm_size(
|
||||
ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
|
||||
@@ -986,6 +986,8 @@ struct MoeSortingKernel
|
||||
p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
smem_cumdup(num_experts) = smem_cumsum(num_experts);
|
||||
|
||||
// fill the p_sorted_token_ids/p_sorted_weights
|
||||
|
||||
@@ -561,6 +561,7 @@ struct GroupedGemmKernel
|
||||
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
|
||||
0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
|
||||
Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
|
||||
block_sync_lds();
|
||||
block_id = block_id + grid_size; // advance to next block
|
||||
// NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
|
||||
if(block_id >= cum_grid_size)
|
||||
|
||||
@@ -631,6 +631,7 @@ struct StreamKKernel
|
||||
tile_idx += kargs.tile_partitioner.get_grid())
|
||||
{
|
||||
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// Stream-K section
|
||||
@@ -679,8 +680,8 @@ struct StreamKKernel
|
||||
{
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
hip_check_error(hipGetDevice(&dev));
|
||||
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
ck_tile::hip_check_error(hipGetDevice(&dev));
|
||||
ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
int num_cu = dev_prop.multiProcessorCount;
|
||||
|
||||
return num_cu;
|
||||
@@ -700,7 +701,7 @@ struct StreamKKernel
|
||||
constexpr int min_block_per_cu = 1;
|
||||
const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
|
||||
|
||||
hip_check_error(
|
||||
ck_tile::hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
|
||||
|
||||
return max(occupancy, 1);
|
||||
|
||||
@@ -280,7 +280,7 @@ struct UniversalGemmKernel
|
||||
using Kernel = UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
const auto kernel = kentry<1, Kernel, KernelArgs>;
|
||||
int occupancy;
|
||||
hip_check_error(
|
||||
ck_tile::hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0));
|
||||
|
||||
const int grid_size = get_available_compute_units(s) * occupancy;
|
||||
|
||||
@@ -9,11 +9,35 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr bool UsePersistentKernel = false;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
|
||||
{
|
||||
return TailNumber::Empty;
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV1
|
||||
struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
|
||||
{
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
@@ -48,14 +72,14 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Problem::VectorSizeA;
|
||||
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Problem::VectorSizeB;
|
||||
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
@@ -9,11 +9,34 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAGmemBGmemCRegV2
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
|
||||
{
|
||||
return TailNumber::Empty;
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV2DefaultPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV2
|
||||
struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2<Problem>
|
||||
{
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
|
||||
@@ -43,13 +43,14 @@ template <bool kPadM_,
|
||||
bool UseStructuredSparsity_ = false,
|
||||
bool UsePersistentKernel_ = false,
|
||||
index_t NumWaveGroups_ = 1,
|
||||
bool Preshuffle_ = false>
|
||||
bool Preshuffle_ = false,
|
||||
int VectorSize_ = 16>
|
||||
struct TileGemmUniversalTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
static constexpr int _VectorSize = 16;
|
||||
static constexpr int _VectorSize = VectorSize_;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
|
||||
using AsLayout = AsLayout_;
|
||||
|
||||
@@ -425,7 +425,7 @@ struct BlockReduce2dCrossWarpSync
|
||||
|
||||
if constexpr(num_reduce_warps == 1)
|
||||
return;
|
||||
|
||||
block_sync_lds();
|
||||
// Each warp's lane 0 writes its partial results to shared memory
|
||||
const index_t smem_offset = warp_id;
|
||||
if(lane_id == 0)
|
||||
|
||||
@@ -160,23 +160,23 @@ void dump_gemm_json_results(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_batched_gemm_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
int batch_stride_A,
|
||||
int batch_stride_B,
|
||||
int batch_stride_C,
|
||||
int batch_count,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "batched_gemm_basic")
|
||||
inline void dump_batched_gemm_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
int batch_stride_A,
|
||||
int batch_stride_B,
|
||||
int batch_stride_C,
|
||||
int batch_count,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "batched_gemm_basic")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -218,20 +218,20 @@ void dump_grouped_gemm_json_results(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_flatmm_json_results(const std::string& json_filename,
|
||||
const std::string& datatype,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
int kbatch,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "flatmm_basic")
|
||||
inline void dump_flatmm_json_results(const std::string& json_filename,
|
||||
const std::string& datatype,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
int kbatch,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "flatmm_basic")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -248,21 +248,22 @@ void dump_flatmm_json_results(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideD0,
|
||||
int StrideD1,
|
||||
int StrideE,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "gemm_multi_d_fp16")
|
||||
inline void
|
||||
dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideD0,
|
||||
int StrideD1,
|
||||
int StrideE,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "gemm_multi_d_fp16")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -280,14 +281,14 @@ void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_elementwise_json_results(const std::string& json_filename,
|
||||
const std::string& prec,
|
||||
int grid_size,
|
||||
int block_size,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "elementwise")
|
||||
inline void dump_elementwise_json_results(const std::string& json_filename,
|
||||
const std::string& prec,
|
||||
int grid_size,
|
||||
int block_size,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "elementwise")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -298,22 +299,22 @@ void dump_elementwise_json_results(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_layernorm2d_fwd_json_results(const std::string& json_filename,
|
||||
const std::string& prec_i,
|
||||
const std::string& prec_o,
|
||||
const std::string& prec_sm,
|
||||
const std::string& prec_sy,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int xr_stride,
|
||||
int y_stride,
|
||||
int yr_stride,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "layernorm2d_fwd")
|
||||
inline void dump_layernorm2d_fwd_json_results(const std::string& json_filename,
|
||||
const std::string& prec_i,
|
||||
const std::string& prec_o,
|
||||
const std::string& prec_sm,
|
||||
const std::string& prec_sy,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int xr_stride,
|
||||
int y_stride,
|
||||
int yr_stride,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "layernorm2d_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -357,13 +358,13 @@ void dump_reduce_json_results(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_permute_json_results(const std::string& json_filename,
|
||||
const std::string& data_type,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflop,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "permute")
|
||||
inline void dump_permute_json_results(const std::string& json_filename,
|
||||
const std::string& data_type,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflop,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "permute")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -373,19 +374,19 @@ void dump_permute_json_results(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_topk_softmax_json(const std::string& json_filename,
|
||||
const std::string& input_prec,
|
||||
const std::string& weight_prec,
|
||||
int tokens,
|
||||
int experts,
|
||||
int topk,
|
||||
int stride_input,
|
||||
int stride_output,
|
||||
float ave_time,
|
||||
float tflop,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "topk_softmax")
|
||||
inline void dump_topk_softmax_json(const std::string& json_filename,
|
||||
const std::string& input_prec,
|
||||
const std::string& weight_prec,
|
||||
int tokens,
|
||||
int experts,
|
||||
int topk,
|
||||
int stride_input,
|
||||
int stride_output,
|
||||
float ave_time,
|
||||
float tflop,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "topk_softmax")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -401,20 +402,20 @@ void dump_topk_softmax_json(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
|
||||
const std::string& prec_str,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int xr_stride,
|
||||
int y_stride,
|
||||
int yr_stride,
|
||||
int use_model_sensitive_rmsnorm,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "rmsnorm2d_fwd")
|
||||
inline void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
|
||||
const std::string& prec_str,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int xr_stride,
|
||||
int y_stride,
|
||||
int yr_stride,
|
||||
int use_model_sensitive_rmsnorm,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "rmsnorm2d_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -431,19 +432,19 @@ void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_add_rmsnorm2d_rdquant_fwd_json(
|
||||
const std::string& json_filename,
|
||||
const std::string& input_data_type,
|
||||
const std::string& quantized_data_type,
|
||||
int m,
|
||||
int n,
|
||||
int stride,
|
||||
float epsilon,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd")
|
||||
inline void
|
||||
dump_add_rmsnorm2d_rdquant_fwd_json(const std::string& json_filename,
|
||||
const std::string& input_data_type,
|
||||
const std::string& quantized_data_type,
|
||||
int m,
|
||||
int n,
|
||||
int stride,
|
||||
float epsilon,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -458,17 +459,17 @@ void dump_add_rmsnorm2d_rdquant_fwd_json(
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_smoothquant_json(const std::string& json_filename,
|
||||
const std::string& prec_str,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int y_stride,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "smoothquant")
|
||||
inline void dump_smoothquant_json(const std::string& json_filename,
|
||||
const std::string& prec_str,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int y_stride,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "smoothquant")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -482,19 +483,19 @@ void dump_smoothquant_json(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_moe_sorting_json(const std::string& json_filename,
|
||||
const std::string& index_prec,
|
||||
const std::string& weight_prec,
|
||||
const std::string& workspace_size,
|
||||
int dispatch_policy,
|
||||
int tokens,
|
||||
int num_experts,
|
||||
int topk,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "moe_sorting")
|
||||
inline void dump_moe_sorting_json(const std::string& json_filename,
|
||||
const std::string& index_prec,
|
||||
const std::string& weight_prec,
|
||||
const std::string& workspace_size,
|
||||
int dispatch_policy,
|
||||
int tokens,
|
||||
int num_experts,
|
||||
int topk,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "moe_sorting")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -510,19 +511,19 @@ void dump_moe_sorting_json(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_batched_transpose_json(const std::string& json_filename,
|
||||
int N,
|
||||
int C,
|
||||
int H,
|
||||
int W,
|
||||
const std::string& layout_in,
|
||||
const std::string& layout_out,
|
||||
const std::string& prec,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "batched_transpose")
|
||||
inline void dump_batched_transpose_json(const std::string& json_filename,
|
||||
int N,
|
||||
int C,
|
||||
int H,
|
||||
int W,
|
||||
const std::string& layout_in,
|
||||
const std::string& layout_out,
|
||||
const std::string& prec,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "batched_transpose")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -538,19 +539,19 @@ void dump_batched_transpose_json(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_moe_smoothquant_json(const std::string& json_filename,
|
||||
const std::string& prec_i,
|
||||
const std::string& prec_o,
|
||||
int tokens,
|
||||
int hidden_size,
|
||||
int stride,
|
||||
int experts,
|
||||
int topk,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "moe_smoothquant")
|
||||
inline void dump_moe_smoothquant_json(const std::string& json_filename,
|
||||
const std::string& prec_i,
|
||||
const std::string& prec_o,
|
||||
int tokens,
|
||||
int hidden_size,
|
||||
int stride,
|
||||
int experts,
|
||||
int topk,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "moe_smoothquant")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -566,26 +567,26 @@ void dump_moe_smoothquant_json(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_fused_moe_json(const std::string& json_filename,
|
||||
const std::string& api_str,
|
||||
const std::string& prec_str,
|
||||
int tokens,
|
||||
bool is_local_token,
|
||||
int local_tokens,
|
||||
int experts,
|
||||
int topk,
|
||||
int hidden_size,
|
||||
int intermediate_size,
|
||||
int stride,
|
||||
int block_m,
|
||||
int activation,
|
||||
bool gate_only,
|
||||
bool fused_quant,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float tb_per_sec,
|
||||
const std::string& kernel_name = "fused_moe")
|
||||
inline void dump_fused_moe_json(const std::string& json_filename,
|
||||
const std::string& api_str,
|
||||
const std::string& prec_str,
|
||||
int tokens,
|
||||
bool is_local_token,
|
||||
int local_tokens,
|
||||
int experts,
|
||||
int topk,
|
||||
int hidden_size,
|
||||
int intermediate_size,
|
||||
int stride,
|
||||
int block_m,
|
||||
int activation,
|
||||
bool gate_only,
|
||||
bool fused_quant,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float tb_per_sec,
|
||||
const std::string& kernel_name = "fused_moe")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -610,29 +611,29 @@ void dump_fused_moe_json(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_fmha_fwd_json_results(const std::string& json_filename,
|
||||
const std::string& prec,
|
||||
const std::string& mode,
|
||||
const std::string& io_layout,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_qs,
|
||||
int seqlen_ks,
|
||||
int seqlen_kpads,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
float scale_s,
|
||||
float p_drop,
|
||||
bool lse,
|
||||
const std::string& qscale,
|
||||
const std::string& bias,
|
||||
const std::string& vlayout,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "fmha_fwd")
|
||||
inline void dump_fmha_fwd_json_results(const std::string& json_filename,
|
||||
const std::string& prec,
|
||||
const std::string& mode,
|
||||
const std::string& io_layout,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_qs,
|
||||
int seqlen_ks,
|
||||
int seqlen_kpads,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
float scale_s,
|
||||
float p_drop,
|
||||
bool lse,
|
||||
const std::string& qscale,
|
||||
const std::string& bias,
|
||||
const std::string& vlayout,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "fmha_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -658,33 +659,33 @@ void dump_fmha_fwd_json_results(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_fmha_bwd_json_results(const std::string& json_filename,
|
||||
const std::string& data_type,
|
||||
const std::string& mode,
|
||||
const std::string& i_perm,
|
||||
const std::string& o_perm,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
float scale,
|
||||
const std::string& bias,
|
||||
bool use_dbias,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
bool deterministic,
|
||||
const std::string& mask,
|
||||
int mask_left,
|
||||
int mask_right,
|
||||
int workspace_size,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "fmha_bwd")
|
||||
inline void dump_fmha_bwd_json_results(const std::string& json_filename,
|
||||
const std::string& data_type,
|
||||
const std::string& mode,
|
||||
const std::string& i_perm,
|
||||
const std::string& o_perm,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
float scale,
|
||||
const std::string& bias,
|
||||
bool use_dbias,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
bool deterministic,
|
||||
const std::string& mask,
|
||||
int mask_left,
|
||||
int mask_right,
|
||||
int workspace_size,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "fmha_bwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
|
||||
@@ -130,7 +130,7 @@ auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None)
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::kMPerBlock;
|
||||
constexpr index_t kNPerBlock = Problem::kNPerBlock;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
index_t kBlockSize = ck_tile::is_wave32() ? Problem::kBlockSize / 2 : Problem::kBlockSize;
|
||||
|
||||
std::cout << "Running CShuffleEpilogue test with M=" << M << ", N=" << N
|
||||
<< ", MPerBlock=" << kMPerBlock << ", NPerBlock=" << kNPerBlock
|
||||
|
||||
@@ -7,7 +7,7 @@ if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12")
|
||||
add_gtest_executable(test_ck_tile_gemm_multi_abd_cshuffle test_gemm_multi_abd_cshuffle.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_multi_abd_default2d test_gemm_multi_abd_default2d.cpp)
|
||||
target_compile_definitions(test_ck_tile_gemm_multi_abd_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
@@ -20,20 +20,21 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using KernelTypes = ::testing::Types<
|
||||
// Has cshuffle epilogue enabled
|
||||
// A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog
|
||||
#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
#endif
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>
|
||||
>;
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmMultiABD, KernelTypes);
|
||||
|
||||
@@ -20,17 +20,19 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using KernelTypes = ::testing::Types<
|
||||
// Has cshuffle epilogue disabled
|
||||
// A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog
|
||||
#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
#endif
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -23,6 +23,28 @@ static constexpr inline auto is_row_major(Layout layout_)
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return 16;
|
||||
#else
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename A0DataType,
|
||||
typename B0DataType,
|
||||
typename AccDataType,
|
||||
@@ -103,17 +125,25 @@ class TestCkTileGemmMultiABD : public ::testing::Test
|
||||
DsDataType::size()>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 128;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
using ADataType =
|
||||
ck_tile::remove_cvref_t<std::tuple_element_t<ck_tile::number<0>{}, AsDataType>>;
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<ADataType, N_Warp_Tile>();
|
||||
#else
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
#endif
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
|
||||
@@ -13,6 +13,28 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return 16;
|
||||
#else
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
@@ -80,7 +102,7 @@ struct config_wmma
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<Datatype, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
|
||||
@@ -9,7 +9,7 @@ endif()
|
||||
# Use standard asm for rtn bf16 conversion instead of turncate
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12")
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_multi_d test_grouped_gemm_multi_d.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
@@ -29,9 +29,6 @@ template <typename ALayout_,
|
||||
int M_Warp_val_,
|
||||
int N_Warp_val_,
|
||||
int K_Warp_val_,
|
||||
int M_Warp_Tile_val_,
|
||||
int N_Warp_Tile_val_,
|
||||
int K_Warp_Tile_val_,
|
||||
bool DoubleSmemBuffer_val_,
|
||||
ck_tile::GemmPipelineScheduler Scheduler_val_,
|
||||
PipelineType Pipeline_val_,
|
||||
@@ -50,15 +47,21 @@ struct KernelConfig
|
||||
using EDataType = EDataType_;
|
||||
using DsDataType = ck_tile::tuple<D0DataType_, D1DataType_>;
|
||||
|
||||
static constexpr int M_Tile_ = M_Tile_val_;
|
||||
static constexpr int N_Tile_ = N_Tile_val_;
|
||||
static constexpr int K_Tile_ = K_Tile_val_;
|
||||
static constexpr int M_Warp_ = M_Warp_val_;
|
||||
static constexpr int N_Warp_ = N_Warp_val_;
|
||||
static constexpr int K_Warp_ = K_Warp_val_;
|
||||
static constexpr int M_Warp_Tile_ = M_Warp_Tile_val_;
|
||||
static constexpr int N_Warp_Tile_ = N_Warp_Tile_val_;
|
||||
static constexpr int K_Warp_Tile_ = K_Warp_Tile_val_;
|
||||
static constexpr int M_Tile_ = M_Tile_val_;
|
||||
static constexpr int N_Tile_ = N_Tile_val_;
|
||||
static constexpr int K_Tile_ = K_Tile_val_;
|
||||
static constexpr int M_Warp_ = M_Warp_val_;
|
||||
static constexpr int N_Warp_ = N_Warp_val_;
|
||||
static constexpr int K_Warp_ = K_Warp_val_;
|
||||
#if CK_TILE_USE_WMMA
|
||||
static constexpr int M_Warp_Tile_ = 16;
|
||||
static constexpr int N_Warp_Tile_ = 16;
|
||||
static constexpr int K_Warp_Tile_ = 16;
|
||||
#else
|
||||
static constexpr int M_Warp_Tile_ = 32;
|
||||
static constexpr int N_Warp_Tile_ = 32;
|
||||
static constexpr int K_Warp_Tile_ = (M_Warp_val_ == 2) ? 16 : 8;
|
||||
#endif
|
||||
static constexpr bool DoubleSmemBuffer_ = DoubleSmemBuffer_val_;
|
||||
static constexpr auto Scheduler_ = Scheduler_val_;
|
||||
static constexpr PipelineType Pipeline_ = Pipeline_val_;
|
||||
@@ -68,21 +71,21 @@ struct KernelConfig
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, ELayout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline, Persistent
|
||||
// ALayout, BLayout, ELayout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, DoubleSmemBuffer, Scheduler, Pipeline, Persistent
|
||||
// FP16 A/B/D/E
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true>, // v4
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true>, // v4
|
||||
// BF16 A/B/D/E
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_preshuffle test_grouped_gemm_preshuffle.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
@@ -50,16 +50,16 @@ struct KernelConfig
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, Persistent ,M_Tile, N_Tile, K_Tile, BlockPerCu
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>,
|
||||
#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>,
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 128, 128, 128, 2>,
|
||||
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>,
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 128, 128, 128, 2>,
|
||||
|
||||
#endif
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>,
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>,
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 128, 128, 128, 2>,
|
||||
|
||||
@@ -14,6 +14,9 @@
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return 16;
|
||||
#else
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
@@ -25,6 +28,7 @@ constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Tuple>
|
||||
@@ -101,13 +105,40 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view(
|
||||
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / N_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
k_ / K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
||||
}
|
||||
else
|
||||
{
|
||||
int divisor = 1;
|
||||
if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
divisor = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view(
|
||||
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
@@ -115,6 +146,11 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t WaveSize = 32;
|
||||
constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile);
|
||||
constexpr bool SupportVectorSize16 =
|
||||
(M_Warp_Tile * K_Warp_Tile * sizeof(ADataType) * MIterPerWarp / WaveSize) % 16 == 0;
|
||||
constexpr int VectorSize = SupportVectorSize16 ? 16 : 8;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
@@ -137,7 +173,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
/*UseStructuredSparsity*/ false,
|
||||
/*Persistent*/ false,
|
||||
/*NumWaveGroups*/ 1,
|
||||
/*Preshuffle*/ true>;
|
||||
/*Preshuffle*/ true,
|
||||
VectorSize>;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
@@ -210,6 +247,12 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t WaveSize = 32;
|
||||
constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile);
|
||||
constexpr bool SupportVectorSize16 =
|
||||
(M_Warp_Tile * K_Warp_Tile * sizeof(ADataType) * MIterPerWarp / WaveSize) % 16 == 0;
|
||||
constexpr int VectorSize = SupportVectorSize16 ? 16 : 8;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
@@ -230,7 +273,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
/*UseStructuredSparsity*/ false,
|
||||
/*Persistent*/ true, // Enable persistent mode
|
||||
/*NumWaveGroups*/ 1,
|
||||
/*Preshuffle*/ true>;
|
||||
/*Preshuffle*/ true,
|
||||
VectorSize>;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
|
||||
Reference in New Issue
Block a user