mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[BlockScale GEMM] FP8 Blockscale GEMM optimization and ckProfiler (#1913)
* Added two kernel for M=32 problem * Comment the first one * Enable multiply_multiply for Scale_Block_M = 1 for deepseek * Modify the a_thread offset since the A data load is different from B. * edit fp8 ab scale for Scale_Block_M=1 * edit GemmSpec to MNKPadding * enable blockwise pipelie v1 and v2. v1 is work for small K. * add instance for gemm_ab_scale * fix cmakelist of ckProfiler * optimize blockscale gemm. todo: reduce vgpr usage * fix a correctness bug * sanity checked * revert ckprofiler cmake changes * clang format * revert unnecessary changes. * remove commented codes. --------- Co-authored-by: mtgu0705 <mtgu@amd.com> Co-authored-by: chenjun <junchen2@amd.com>
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -177,14 +178,57 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
auto size_a_buffer =
|
||||
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
|
||||
auto size_b_buffer =
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
|
||||
|
||||
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
@@ -195,7 +239,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always 1
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
@@ -208,127 +252,13 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
// Tail number could be One to Seven
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
{
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::One>;
|
||||
Run(kernel);
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Full)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
Run(kernel);
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Two>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Three)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Three>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Four)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Four>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Five)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Five>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Six>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Seven)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Seven>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
@@ -337,6 +267,16 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
return ave_time;
|
||||
@@ -363,10 +303,11 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK != KPerBlock)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK !=
|
||||
// KPerBlock)
|
||||
// {
|
||||
// return false;
|
||||
// }
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
|
||||
Reference in New Issue
Block a user