This commit is contained in:
coderfeli
2025-02-19 10:39:12 +00:00
parent 1d46f8f63d
commit a41f76cdea
2 changed files with 15 additions and 29 deletions

View File

@@ -1,9 +1,16 @@
list(APPEND TILE_EXAPMLE_BLOCKSCALE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1)
list(APPEND TILE_EXAPMLE_BLOCKSCALE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
add_example_executable(example_moe_gemm1 moe_gemm1.cpp)
add_example_executable(example_moe_gemm2 moe_gemm2.cpp)
target_compile_options(example_moe_gemm1 PRIVATE ${TILE_EXAPMLE_BLOCKSCALE_COMPILE_OPTIONS})
target_compile_options(example_moe_gemm2 PRIVATE ${TILE_EXAPMLE_BLOCKSCALE_COMPILE_OPTIONS})
add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp)
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
add_example_executable(example_moe_gemm1 moe_gemm1.cpp)
add_example_executable(example_moe_gemm2 moe_gemm2.cpp)
add_example_executable(example_moe_pk_i4_gemm1 moe_pk_i4_gemm1.cpp)
add_example_executable(example_moe_pk_i4_gemm2 moe_pk_i4_gemm2.cpp)

View File

@@ -747,40 +747,19 @@ struct GridwiseMoeGemm
// in some cases.
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) /APackedSize < 1
? 1
: 32 * 4 / KPerBlock / sizeof(LDSTypeA);
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock * MLdsLayer>{}, I1));
constexpr auto a_lds_block_desc =
make_naive_tensor_descriptor(make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(make_xor_with_modulo_transform(make_tuple(
Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
make_tuple(make_xor_with_modulo_transform(
make_tuple(Number<MPerBlock>{}, Number<AK0Number>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_ak0_mldslayer_m_ak1,
make_tuple(make_pass_through_transform(AK0Number),
make_merge_transform_v3_division_mod(
make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_lds_block_desc_ak0_m_ak1;
return a_lds_block_desc_permuted;
}
else // ColumnMajor A
{