mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
Dev/a8w4 and a8w8splitk (#3447)
* Ck moe bs splitk pr (#3440) * splitk kick-off. Compilation fail * splitk hack pass * fix scale offset calc. * clang-format for a8w8_moe_blk_gemm1 splitk change * fix testcase error --------- Co-authored-by: oscar <huaiguxu@amd.com> Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com> * Zan/moe a8w4 (#3441) * update * update * update ck moe a8w4 * update * update * update * compile pass * update * update * python3 op_tests/test_moe_2stage.py -t 16 -e 1 -k 1 -dim 256,256 ready * support new a8w4 kernel * update * update ck_tile * re format * update * update * fix conflict * fix build * update ck_tile moe * fix clang format * fix the problem * fix accuracy issue * fix --------- Co-authored-by: oscar <huaiguxu@amd.com> Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com> Co-authored-by: Zzz9990 <zanzhang@amd.com> Co-authored-by: felix <felix.li@amd.com>
This commit is contained in:
@@ -74,6 +74,7 @@ template <typename ALayout,
|
||||
index_t ActivationOP = 0,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
bool IsSplitK = false,
|
||||
bool MulRoutedWeight = false,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = CDataType,
|
||||
@@ -156,6 +157,7 @@ struct DeviceMoeGemmBlockScale
|
||||
ActivationOP,
|
||||
NSwizzle,
|
||||
IsInputGemm,
|
||||
IsSplitK,
|
||||
MulRoutedWeight,
|
||||
IndexType,
|
||||
ComputeTypeA,
|
||||
@@ -201,12 +203,12 @@ struct DeviceMoeGemmBlockScale
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
|
||||
arg.M, arg.N * (IsInputGemm && IsSplitK ? 2 : 1), arg.K, arg.KBatch);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
index_t K_split = arg.KBatch == 1 ? arg.K : arg.KBatch * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
const auto RunKernel = [&](const auto& kernel) {
|
||||
@@ -249,11 +251,12 @@ struct DeviceMoeGemmBlockScale
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
// if(arg_.KBatch > 1)
|
||||
// hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
// 0,
|
||||
// arg_.M * arg_.N * sizeof(CDataType)
|
||||
// * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
// stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
@@ -267,11 +270,12 @@ struct DeviceMoeGemmBlockScale
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
// if(arg.KBatch > 1)
|
||||
// hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
// 0,
|
||||
// arg.M * arg.N * sizeof(CDataType) *
|
||||
// (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
// stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
@@ -289,8 +293,9 @@ struct DeviceMoeGemmBlockScale
|
||||
|
||||
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
|
||||
|
||||
constexpr auto MemoryDataOp =
|
||||
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
|
||||
constexpr auto MemoryDataOp = (IsInputGemm && !IsSplitK)
|
||||
? InMemoryDataOperationEnum::Set
|
||||
: InMemoryDataOperationEnum::AtomicAdd;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
@@ -416,8 +421,8 @@ struct DeviceMoeGemmBlockScale
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// only impl kbatch 1 now
|
||||
if(arg.KBatch > 1)
|
||||
// KBatch > 1 (split-K) is only supported when CDataType is float
|
||||
if(arg.KBatch > 1 && !std::is_same_v<CDataType, float>)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -441,6 +446,11 @@ struct DeviceMoeGemmBlockScale
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(arg.KBatch > 1 && arg.K % (KPerBlock * arg.KBatch) != 0)
|
||||
{
|
||||
// K padding is not supported when KBatch > 1
|
||||
return false;
|
||||
}
|
||||
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user