mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
merge ck moe merge
This commit is contained in:
@@ -1,9 +1,19 @@
|
||||
list(APPEND TILE_EXAPMLE_BLOCKSCALE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1)
|
||||
list(APPEND TILE_EXAPMLE_BLOCKSCALE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
|
||||
add_example_executable(example_moe_gemm1 moe_gemm1.cpp)
|
||||
add_example_executable(example_moe_gemm2 moe_gemm2.cpp)
|
||||
target_compile_options(example_moe_gemm1 PRIVATE ${TILE_EXAPMLE_BLOCKSCALE_COMPILE_OPTIONS})
|
||||
target_compile_options(example_moe_gemm2 PRIVATE ${TILE_EXAPMLE_BLOCKSCALE_COMPILE_OPTIONS})
|
||||
|
||||
|
||||
add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp)
|
||||
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
|
||||
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp)
|
||||
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
|
||||
add_example_executable(example_moe_gemm1 moe_gemm1.cpp)
|
||||
add_example_executable(example_moe_gemm2 moe_gemm2.cpp)
|
||||
add_example_executable(example_moe_pk_i4_gemm1 moe_pk_i4_gemm1.cpp)
|
||||
set(EXAMPLE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker -g -fverbose-asm)
|
||||
target_compile_options(example_moe_pk_i4_gemm1 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
add_example_executable(example_moe_pk_i4_gemm2 moe_pk_i4_gemm2.cpp)
|
||||
|
||||
@@ -133,8 +133,8 @@ using BElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t MXDLPerWave = 2;
|
||||
static constexpr ck::index_t NXDLPerWave = 2;
|
||||
static constexpr ck::index_t MXDLPerWave = 4;
|
||||
static constexpr ck::index_t NXDLPerWave = 1;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 32;
|
||||
@@ -174,10 +174,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
MXDLPerWave, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
|
||||
4, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
|
||||
// kernel 2: 128->32x128x128
|
||||
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
|
||||
// DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>,
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -197,8 +198,6 @@ int main(int argc, char* argv[])
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 16;
|
||||
ck::index_t valid_tile_num = 13;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
ck::index_t tokens = 544;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
@@ -217,6 +216,17 @@ int main(int argc, char* argv[])
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else if(argc == 9) {
|
||||
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
sorted_tile_num = std::stoi(argv[7]);
|
||||
valid_tile_num = std::stoi(argv[8]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
@@ -227,6 +237,8 @@ int main(int argc, char* argv[])
|
||||
exit(0);
|
||||
}
|
||||
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
if (tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
@@ -246,8 +258,10 @@ int main(int argc, char* argv[])
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
|
||||
// max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0};
|
||||
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
// int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
int eids[] = {0, 0,1,2, 3,3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
for (int i = 0; i < sorted_tile_num; i++) {
|
||||
expert_ids.mData[i] = eids[i];
|
||||
}
|
||||
@@ -287,8 +301,8 @@ int main(int argc, char* argv[])
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{1, 2});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{1, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
@@ -296,12 +310,21 @@ int main(int argc, char* argv[])
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
|
||||
break;
|
||||
default:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
}
|
||||
// d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
|
||||
// d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
|
||||
// b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
|
||||
|
||||
@@ -64,7 +64,11 @@ struct MulABScale
|
||||
const float& d0,
|
||||
const float& d1) const
|
||||
{
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * d1 * d0 * 16);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c * d1 * d0);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -84,7 +88,11 @@ struct MulABScaleSilu
|
||||
{
|
||||
// act
|
||||
float x0 = 0;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0 * 16);
|
||||
#else
|
||||
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0);
|
||||
#endif
|
||||
e = ck::type_convert<EDataType>(x0);
|
||||
}
|
||||
};
|
||||
@@ -131,13 +139,13 @@ using BElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
#if 0
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t MXDLPerWave = 2;
|
||||
static constexpr ck::index_t MPerBlock = 64;
|
||||
static constexpr ck::index_t MXDLPerWave = 1;
|
||||
static constexpr ck::index_t NXDLPerWave = 2;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 32;
|
||||
static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t KPerBlock = 64 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType);
|
||||
@@ -154,8 +162,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
AK1, BK1,
|
||||
MNPerXDL, MNPerXDL,
|
||||
MXDLPerWave, NXDLPerWave,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
MXDLPerWave, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
|
||||
// clang-format on
|
||||
@@ -167,13 +175,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, 128, 128, 64,
|
||||
256, MPerBlock, 128, 128,
|
||||
16, 32,
|
||||
32, 32,
|
||||
4, 1,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
4, 1, S<1, 32, 1, 8>, S<4, 1, 1>,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
2, 1, S<1, 32, 1, 8>, S<4, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
|
||||
// clang-format on
|
||||
#endif
|
||||
@@ -359,6 +367,7 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
#endif
|
||||
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
// vector pk_i4x4 permute
|
||||
for(int e = 0; e < experts; e++)
|
||||
{
|
||||
@@ -410,6 +419,7 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
|
||||
|
||||
@@ -69,7 +69,12 @@ struct MulABScaleExpertWeight
|
||||
//for real kernel use
|
||||
//warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. tofix:felix
|
||||
(void) d0;
|
||||
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * d1 * d2 * 16);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c * d1 * d2);
|
||||
#endif
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
@@ -81,7 +86,11 @@ struct MulABScaleExpertWeight
|
||||
const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * d0 * d1 * d2 * 16);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -137,7 +146,7 @@ static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t CShuffleNLane = 32;
|
||||
static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 2;
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
@@ -151,7 +160,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
MNPerXDL, MNPerXDL,
|
||||
MXDLPerWave, NXDLPerWave,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
MXDLPerWave, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
|
||||
// clang-format on
|
||||
@@ -320,6 +329,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, device_op.GetPreShuffleParameters());
|
||||
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
// vector pk_i4x4 permute
|
||||
for(int e = 0; e < experts; e++)
|
||||
{
|
||||
@@ -371,6 +381,7 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
|
||||
@@ -443,8 +454,19 @@ int main(int argc, char* argv[])
|
||||
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(
|
||||
sorted_token_ids, expert_ids, max_token_id, MPerBlock, a0_t_k_k, b0_e_n_k, d0_t_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op);
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a0_t_k_k,
|
||||
b0_e_n_k,
|
||||
d0_t_n,
|
||||
d1_e_n,
|
||||
d2_e_n,
|
||||
c_t_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
|
||||
@@ -194,17 +194,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
|
||||
|
||||
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num;
|
||||
|
||||
constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
|
||||
constexpr auto staged_num_mfma = num_mfma / MRepeat;
|
||||
constexpr auto staged_num_ds_read_inst_a = ck::math::integer_divide_ceil(num_ds_read_inst_a,MRepeat);
|
||||
constexpr auto staged_num_mfma = ck::math::integer_divide_ceil(num_mfma , MRepeat);
|
||||
|
||||
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
|
||||
constexpr auto staged_num_mfma_per_ds_read_a = ck::math::integer_divide_ceil(staged_num_mfma , staged_num_ds_read_inst_a);
|
||||
|
||||
if constexpr(stage.value == 0)
|
||||
{
|
||||
constexpr auto staged_num_buffer_load_b_per_ds_read_a =
|
||||
num_buffer_load_inst_b / staged_num_ds_read_inst_a;
|
||||
constexpr auto staged_num_mfma_per_buffer_load_b =
|
||||
staged_num_mfma / num_buffer_load_inst_b;
|
||||
constexpr auto staged_num_buffer_load_b_per_ds_read_a = ck::math::integer_divide_ceil(
|
||||
num_buffer_load_inst_b , staged_num_ds_read_inst_a);
|
||||
constexpr auto staged_num_mfma_per_buffer_load_b =ck::math::integer_divide_ceil(
|
||||
staged_num_mfma , num_buffer_load_inst_b);
|
||||
// B global
|
||||
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
|
||||
ignore = i_inst;
|
||||
|
||||
@@ -190,7 +190,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
// B global
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
|
||||
@@ -65,16 +65,12 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
|
||||
const DstDescs& dst_descs,
|
||||
const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
|
||||
const ElementwiseOperation& element_op,
|
||||
const StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
|
||||
const StaticallyIndexedArray<float, scatter_num> &scatter_weights)
|
||||
const ElementwiseOperation& element_op)
|
||||
: threadwise_transfer_(src_descs,
|
||||
StaticallyIndexedArray<Index, nSrc>{},
|
||||
dst_descs,
|
||||
StaticallyIndexedArray<Index, nDst>{},
|
||||
element_op,
|
||||
scatter_offsets,
|
||||
scatter_weights)
|
||||
element_op)
|
||||
{
|
||||
static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
|
||||
nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
|
||||
@@ -129,12 +125,13 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
template <typename SrcBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
StaticallyIndexedArray<float, scatter_num> &scatter_weights,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
|
||||
threadwise_transfer_.RunRead(src_descs, src_bufs, scatter_weights, thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,15 +141,16 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
template <typename DstBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
|
||||
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
|
||||
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, scatter_offsets, thread_scratch_id);
|
||||
else
|
||||
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
|
||||
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), scatter_offsets, thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,10 +158,12 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
__device__ void Run(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs)
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num> &scatter_weights)
|
||||
{
|
||||
RunRead(src_descs, src_bufs);
|
||||
RunWrite(dst_descs, dst_bufs);
|
||||
RunRead(src_descs, src_bufs, scatter_weights);
|
||||
RunWrite(dst_descs, dst_bufs, scatter_offsets);
|
||||
}
|
||||
|
||||
template <index_t ISrc>
|
||||
|
||||
@@ -247,6 +247,7 @@ struct DeviceMoeGemm
|
||||
// static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
|
||||
// has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right
|
||||
// now");
|
||||
constexpr auto MemoryDataOp = IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
@@ -279,7 +280,6 @@ struct DeviceMoeGemm
|
||||
// }
|
||||
// else
|
||||
{
|
||||
constexpr auto MemoryDataOp = IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
|
||||
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
// {
|
||||
// const auto kernel = kernel_moe_gemm<
|
||||
@@ -304,8 +304,9 @@ struct DeviceMoeGemm
|
||||
}
|
||||
}
|
||||
}
|
||||
// else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
// {
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
// if(arg.KBatch > 1)
|
||||
// {
|
||||
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
@@ -332,31 +333,29 @@ struct DeviceMoeGemm
|
||||
// }
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
// {
|
||||
// const auto kernel =
|
||||
// kernel_moe_gemm_gather_2lds<
|
||||
// GridwiseGemm,
|
||||
// true,
|
||||
// IsInputGemm? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
|
||||
// minimum_occupancy,
|
||||
// TailNumber::Odd>;
|
||||
// RunKernel(kernel);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// const auto kernel =
|
||||
// kernel_moe_gemm_gather_2lds<
|
||||
// GridwiseGemm,
|
||||
// true,
|
||||
// IsInputGemm? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
|
||||
// minimum_occupancy,
|
||||
// TailNumber::Even>;
|
||||
// RunKernel(kernel);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("todo: only v1 & v2 support now");
|
||||
|
||||
@@ -79,13 +79,30 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
|
||||
return res.template AsType<half4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
__device__ inline f8x4_t i4_to_f8x4(int q)
|
||||
{
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
|
||||
int lo = amd_assembly_and_b32(q, LO);
|
||||
int hi = amd_assembly_and_b32(q, HI);
|
||||
|
||||
float f32_0 = amd_assemble_cvt_f32_i4(lo);
|
||||
float f32_1 = amd_assemble_cvt_f32_i4(lo >> 16);
|
||||
float f32_2 = amd_assemble_cvt_f32_i4(hi);
|
||||
float f32_3 = amd_assemble_cvt_f32_i4(hi >> 16);
|
||||
|
||||
// vector_type<f8_t, 4> res;
|
||||
// res.template AsType<f8x4_t>()(Number<0>{}) = amd_assemble_cvt_f8_f32(f32_1st, f32_2nd, f32_3rd, f32_4th);
|
||||
return amd_assembly_cvt_f8_to_f32(f32_0, f32_1, f32_2, f32_3);
|
||||
}
|
||||
|
||||
__device__ inline f8x8_t i4_to_fp8x8(int q)
|
||||
{
|
||||
vector_type<f8_t, 8> res;
|
||||
|
||||
res.template AsType<f8x8_t>()(Number<0>{}) = amd_assembly_i4_to_fp8x2(q);
|
||||
|
||||
return res.template AsType<f8x8_t>()[Number<0>{}];
|
||||
// f8x8_t res;
|
||||
// amd_assembly_i4_to_fp8x8(res, q);
|
||||
// return res;
|
||||
return amd_assembly_i4_to_fp8x8(q);
|
||||
}
|
||||
|
||||
__device__ inline bhalf4_t i4_to_bhalf4(int q)
|
||||
@@ -154,13 +171,55 @@ struct PassThroughPack8
|
||||
__host__ __device__ constexpr void operator()(ck::f8x8_t& y, const ck::pk_i4x4_t& x) const
|
||||
{
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
vector_type<f8_t, 8> result;
|
||||
y = i4_to_fp8x8(bit_cast<int>(x));
|
||||
|
||||
result.template AsType<f8x8_t>()(Number<0>{}) = i4_to_fp8x8(bit_cast<int>(x));
|
||||
// vector_type<f8_t, 8> result;
|
||||
|
||||
y = result.template AsType<f8x8_t>()[Number<0>{}];
|
||||
// result.template AsType<f8x4_t>()(Number<0>{}) = i4_to_f8x4(bit_cast<int>(x));
|
||||
// result.template AsType<f8x4_t>()(Number<1>{}) = i4_to_f8x4(bit_cast<int>(x) >> 8);
|
||||
|
||||
// y = result.template AsType<f8x8_t>()[Number<0>{}];
|
||||
#else
|
||||
// Added pk_i4_t to f8x2_fnuz_t conversion
|
||||
vector_type<f8_t, 8> dst;
|
||||
vector_type<float, 8> dst_tmp;
|
||||
vector_type<pk_i4_t, 4> src{x};
|
||||
|
||||
// pk_i4_t to float2_t conversion
|
||||
dst_tmp.template AsType<float2_t>()(Number<0>{}) =
|
||||
type_convert<float2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
|
||||
|
||||
dst_tmp.template AsType<float2_t>()(Number<1>{}) =
|
||||
type_convert<float2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
|
||||
|
||||
dst_tmp.template AsType<float2_t>()(Number<2>{}) =
|
||||
type_convert<float2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
|
||||
|
||||
dst_tmp.template AsType<float2_t>()(Number<3>{}) =
|
||||
type_convert<float2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
|
||||
|
||||
// float to f8_t conversion
|
||||
dst.template AsType<f8_t>()(Number<0>{}) =
|
||||
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<0>{}]);
|
||||
dst.template AsType<f8_t>()(Number<1>{}) =
|
||||
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<1>{}]);
|
||||
|
||||
dst.template AsType<f8_t>()(Number<2>{}) =
|
||||
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<2>{}]);
|
||||
dst.template AsType<f8_t>()(Number<3>{}) =
|
||||
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<3>{}]);
|
||||
|
||||
dst.template AsType<f8_t>()(Number<4>{}) =
|
||||
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<4>{}]);
|
||||
dst.template AsType<f8_t>()(Number<5>{}) =
|
||||
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<5>{}]);
|
||||
|
||||
dst.template AsType<f8_t>()(Number<6>{}) =
|
||||
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<6>{}]);
|
||||
dst.template AsType<f8_t>()(Number<7>{}) =
|
||||
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<7>{}]);
|
||||
|
||||
y = dst.template AsType<f8x8_t>()[Number<0>{}];
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -62,39 +62,43 @@ __global__ void
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
// template <typename GridwiseGemm,
|
||||
// bool HasMainKBlockLoop,
|
||||
// InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
// index_t MinimumOccupancy = 1,
|
||||
// TailNumber TailNum = TailNumber::Even>
|
||||
// __global__ void
|
||||
// #if CK_USE_LAUNCH_BOUNDS
|
||||
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
// #endif
|
||||
// // __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
// kernel_moe_gemm_gather_2lds(typename GridwiseGemm::Argument karg)
|
||||
// {
|
||||
// #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
// __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
bool IsInputGemm = false,
|
||||
TailNumber TailNum = TailNumber::Even>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
// __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
// GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
// karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
// karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
// karg.p_ds_grid,
|
||||
// karg.p_c_grid,
|
||||
// p_shared,
|
||||
// p_shared1,
|
||||
// karg,
|
||||
// karg.a_element_op,
|
||||
// karg.b_element_op,
|
||||
// karg.c_element_op);
|
||||
// #else
|
||||
// ignore = karg;
|
||||
// #endif // end of if (defined(__gfx9__))
|
||||
// }
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, IsInputGemm, TailNum>(
|
||||
karg.p_sorted_token_ids,
|
||||
karg.p_sorted_expert_ids,
|
||||
karg.p_max_token_id,
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
p_shared1,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -743,40 +747,19 @@ struct GridwiseMoeGemm
|
||||
// in some cases.
|
||||
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) /APackedSize < 1
|
||||
? 1
|
||||
: 32 * 4 / KPerBlock / sizeof(LDSTypeA);
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
|
||||
make_tuple(AK1Number, Number<KPerBlock * MLdsLayer>{}, I1));
|
||||
constexpr auto a_lds_block_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
|
||||
make_tuple(make_xor_with_modulo_transform(
|
||||
make_tuple(Number<MPerBlock>{}, Number<AK0Number>{})),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
|
||||
make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_ak0_mldslayer_m_ak1,
|
||||
make_tuple(make_pass_through_transform(AK0Number),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
return a_lds_block_desc_ak0_m_ak1;
|
||||
return a_lds_block_desc_permuted;
|
||||
}
|
||||
else // ColumnMajor A
|
||||
{
|
||||
@@ -1508,32 +1491,6 @@ struct GridwiseMoeGemm
|
||||
using CDEBlockTransferCluster =
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
|
||||
constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
|
||||
constexpr auto EMRepeats = MPerBlock / EMThreads;
|
||||
constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
|
||||
const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats;
|
||||
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
// too hack here, 2 specific for topk weights, fixme
|
||||
const float *p_sorted_weights_0 = p_ds_grid[I0];
|
||||
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
|
||||
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
|
||||
if constexpr (IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
} else {
|
||||
const float *p_sorted_weights_2 = p_ds_grid[I2];
|
||||
weight = weight * p_sorted_weights_2[c_token_pos + m0];
|
||||
}
|
||||
scatter_offsets(m0) = token_offset * problem.N;
|
||||
scatter_weights(m0) = weight;
|
||||
// if(threadIdx.x % 16 == 0)
|
||||
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
|
||||
});
|
||||
constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; //hack fix felix
|
||||
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
|
||||
ThisThreadBlock,
|
||||
@@ -1569,9 +1526,7 @@ struct GridwiseMoeGemm
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
|
||||
c_element_op,
|
||||
scatter_offsets,
|
||||
scatter_weights};
|
||||
c_element_op};
|
||||
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
@@ -1600,8 +1555,37 @@ struct GridwiseMoeGemm
|
||||
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
|
||||
|
||||
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
|
||||
constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
|
||||
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
|
||||
constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
|
||||
const float *p_sorted_weights_0 = p_ds_grid[I0];
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
// too hack here, 2 specific for topk weights, fixme
|
||||
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
|
||||
|
||||
auto dstidx = sfc_cde_block.GetIndex(access_id);
|
||||
const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
|
||||
if constexpr (IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
} else {
|
||||
const float *p_sorted_weights_2 = p_ds_grid[I2];
|
||||
weight = weight * p_sorted_weights_2[c_token_pos + m0];
|
||||
}
|
||||
|
||||
// if(threadIdx.x % 8 == 0 && blockIdx.x == 0)
|
||||
// printf("init off tid %d access %d tpos %d m %d off %d wei %f\n", threadIdx.x, dstidx(I1), c_token_pos, m0(), token_offset, weight);
|
||||
scatter_offsets(m0) = token_offset * problem.N;
|
||||
scatter_weights(m0) = weight;
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
@@ -1619,7 +1603,522 @@ struct GridwiseMoeGemm
|
||||
c_ds_desc_refs,
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(c_grid_buf));
|
||||
tie(c_grid_buf),
|
||||
scatter_offsets,
|
||||
scatter_weights
|
||||
);
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto cde_lds_and_global_step =
|
||||
sfc_cde_block.GetForwardStep(access_id);
|
||||
|
||||
// move on Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
cde_block_copy_lds_and_global.MoveSrcSliceWindow(
|
||||
c_ds_desc_refs, i + I1, cde_lds_and_global_step);
|
||||
});
|
||||
|
||||
// move on E
|
||||
cde_block_copy_lds_and_global.MoveDstSliceWindow(
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
I0,
|
||||
cde_lds_and_global_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
bool IsInputGemm = true,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
__device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
|
||||
const index_t* p_sorted_expert_ids,
|
||||
const index_t* p_max_token_id,
|
||||
const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared,
|
||||
void* p_shared1,
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
ignore = b_element_op;
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
IsInputGemm? problem.NumTokens : problem.NumTokens * problem.TopK, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bpreshuffled =
|
||||
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
|
||||
IsInputGemm? problem.NumTokens * problem.TopK : problem.NumTokens , problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
// printf("tido %d size %d %d MNBLOCK %d %d %d %d\n", threadIdx.x, problem.StrideC, c_grid_desc_m_n.GetElementSpaceSize(),
|
||||
// problem.MBlock, problem.NBlock, MPerBlock, NPerBlock);
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
|
||||
// constexpr int expert_tile_cnt[8] = {2, 1, 1, 2, 2, 2, 1, 2};
|
||||
// const index_t b_block_id = blockIdx.x % problem.NBlock;
|
||||
const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
|
||||
if (expert_block_id * MPerBlock >= max_token_id)
|
||||
return;
|
||||
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
|
||||
// Select the output tile this workgroup processes. The lambda is invoked immediately and
// returns the pair {N-tile index, M-tile index}.
//  - NSwizzle:  blockIdx.x is rebased against a per-expert prefix read from p_max_token_id
//               and then remapped (see the comments on nid/mid below).
//  - otherwise: blockIdx.x / blockIdx.y are used directly as the N / M tile indices.
const auto block_mn = [&]() -> std::pair<int, int> {
|
||||
if constexpr (NSwizzle)
|
||||
{
|
||||
// const index_t expert_block_id = blockIdx.x / problem.NBlock; //
|
||||
// const index_t es = __builtin_amdgcn_readfirstlane(p_max_token_id[expert_block_id + 1]);
|
||||
// const index_t expert_swizzle = es > 0 ? es : 1; //p_max_token_id[expert_id + 1];
|
||||
// const index_t expert_block_swizzle = expert_block_id / expert_swizzle;
|
||||
// const index_t b_block_id_swizzle = blockIdx.x % (problem.NBlock * expert_swizzle);
|
||||
// const index_t nid = __builtin_amdgcn_readfirstlane(b_block_id_swizzle % 8 + b_block_id_swizzle / (8 * expert_swizzle) * 8);
|
||||
// const index_t mid = __builtin_amdgcn_readfirstlane(expert_block_swizzle * expert_swizzle + b_block_id_swizzle / 8 % expert_swizzle);
|
||||
// if(threadIdx.x==0)
|
||||
// printf("block, %d, mid, %d, nid, %d, ecnt, %d, expert %d \n", blockIdx.x, mid, nid, es, p_sorted_expert_ids[expert_block_id]);
|
||||
|
||||
// NOTE(review): entries [1 + e] of p_max_token_id look like a running (prefix) count of
// M tiles per expert, since successive entries are subtracted below - confirm with the host
// code that fills this buffer.
const index_t ecnt_prefix = p_max_token_id[1+expert_id];
|
||||
// Number of workgroups (in blockIdx.x order) that precede this expert's range.
const index_t prefix_block = ecnt_prefix * problem.NBlock;
|
||||
const index_t ecnt = p_max_token_id[2+expert_id] - ecnt_prefix;
|
||||
// Clamp to at least 1: expert_swizzle is used as a divisor and as a modulus below.
const index_t expert_swizzle = ecnt > 0 ? ecnt : 1; //p_max_token_id[expert_id + 1]; // 2
|
||||
// Workgroup index relative to the start of this expert's range.
const index_t bid_new = blockIdx.x - prefix_block;
|
||||
// Consecutive bid_new values walk 8 N tiles for one M tile, then the same 8 N tiles for
// the next M tile (expert_swizzle M tiles in total), and only then advance to the next
// group of 8 N tiles.
const index_t nid = __builtin_amdgcn_readfirstlane(bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
|
||||
const index_t mid = __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
|
||||
// if(threadIdx.x==0)
|
||||
// printf("block, %d, mid, %d, nid, %d, ecnt, %d, expert %d \n", blockIdx.x, mid, nid, ecnt, expert_id);
|
||||
return {nid, mid};
|
||||
} else {
|
||||
return {blockIdx.x, blockIdx.y};
|
||||
}
|
||||
}();
|
||||
const index_t block_n_id = block_mn.first;
|
||||
const index_t block_m_id = block_mn.second;
|
||||
|
||||
// if (threadIdx.x==0) {
|
||||
// printf("bid %d, eid %d, es %d, esi %d, bsi %d, m %d, n %d\n", blockIdx.x, expert_id, expert_swizzle, expert_block_swizzle, b_block_id_swizzle, block_m_id, block_n_id);
|
||||
// }
|
||||
const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
|
||||
|
||||
// constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
|
||||
constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
|
||||
constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
|
||||
constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
|
||||
constexpr auto AKThreads = AK0Threads * AK1Threads;
|
||||
constexpr auto AMRepeats = MPerBlock / AMThreads;
|
||||
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
|
||||
|
||||
if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || token0 >= problem.NumTokens)
|
||||
return;
|
||||
// Per-thread gather offsets into A: one entry for each of the AMRepeats consecutive
// sorted-token slots starting at token_pos. The array is handed to the A blockwise copy.
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
|
||||
static_for<0, AMRepeats, 1>{}([&](auto m0) {
|
||||
// Each sorted-token entry is used as a packed value: the low 24 bits are taken as the
// token index and the bits from 24 upwards as a second field.
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
if constexpr (!IsInputGemm)
|
||||
{
|
||||
// When this is not the input GEMM, the A row is token_index * TopK + (upper bits).
// NOTE(review): if index_t is signed, `>> 24` is an arithmetic shift; this presumably
// relies on the upper field never setting the sign bit - confirm.
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
// Convert the row index into an element offset by scaling with problem.K.
gather_offsets(m0) = token_offset * problem.K;
|
||||
// printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
|
||||
});
|
||||
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
|
||||
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + expert_id * expert_stride / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
// if(threadIdx.x==0)
|
||||
// printf("tid %d eid %d expert_stride %d bufsize %d\n",
|
||||
// threadIdx.x, expert_id, expert_stride, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// dummy
|
||||
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1_mod8<ThisThreadBlock,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
LDSTypeA,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
1,
|
||||
2>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
gather_offsets);
|
||||
|
||||
// Thread-wise copy
|
||||
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
|
||||
auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
|
||||
|
||||
auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bpreshuffled),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
|
||||
Sequence<1, 2, 0, 3>,
|
||||
3,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % warpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
|
||||
|
||||
// Blockwise GEMM pipeline
|
||||
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
|
||||
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
|
||||
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_bufs,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bpreshuffled,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_bufs,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
|
||||
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
|
||||
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
|
||||
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
|
||||
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
|
||||
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
|
||||
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
|
||||
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
|
||||
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<CShuffleDataType*>(p_shared),
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_tuple(
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
|
||||
M1, // M1 = MWave
|
||||
M2, // M2 * M3 * M4 = MPerXdl
|
||||
M3,
|
||||
M4)),
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
|
||||
N1, // N1 = NWave
|
||||
N2))), // N2 = NPerXdl
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(
|
||||
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
|
||||
|
||||
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
|
||||
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
|
||||
|
||||
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(m_thread_data_on_block));
|
||||
|
||||
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_block_idx =
|
||||
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
using EDataType = CDataType;
|
||||
|
||||
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
|
||||
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
// Build one global-memory buffer per D tensor (NumDTensor in total), each sized from the
// matching entry of ds_grid_desc_m_n.
const auto ds_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
const DDataType *ptr_ = p_ds_grid[i];
|
||||
// hack logic here to support different kind of strides. todo fix it.
|
||||
// ascale t, 1; bscale E, N, 1, move ptr to E
|
||||
// Only D tensor 1 is rebased per expert: its base pointer is advanced by
// expert_id * (StrideDs[1] * N) when StrideDs[1] is non-zero, otherwise by expert_id.
if (i.value == 1)
|
||||
{
|
||||
ptr_ += expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1);
|
||||
// if ( threadIdx.x % 16 ==0)
|
||||
// printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1), ptr_[0]);
|
||||
}
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_desc_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
|
||||
generate_tie(
|
||||
[&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_buf_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_buf),
|
||||
generate_tie(
|
||||
[&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_buf[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of starting index of C/Ds blockwise copy
|
||||
const auto idx_c_ds_block_begin = container_concat(
|
||||
make_tuple(make_multi_index(0, 0, 0, 0)),
|
||||
generate_tuple(
|
||||
[&](auto) {
|
||||
return make_multi_index(block_m_id, 0, block_n_id, 0);
|
||||
// return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
|
||||
},
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
using CDEBlockTransferCluster =
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
|
||||
constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; //hack fix felix
|
||||
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
|
||||
ThisThreadBlock,
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
Tuple<EDataType>,
|
||||
decltype(c_ds_desc_refs),
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
|
||||
// support arbitrary type
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
|
||||
CDEBlockTransferCluster,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
|
||||
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
|
||||
3, // index_t SrcVectorDim,
|
||||
3, // index_t DstVectorDim,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
sequence_merge_t<
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
1, //ScatterDim
|
||||
true, //OutputScatter: false, only use scatter weights
|
||||
scatter_weight_idx // ScatterWeightIdx: ascale
|
||||
>
|
||||
{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
|
||||
c_element_op};
|
||||
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
1,
|
||||
1,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
1>>{};
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
// space filling curve for shuffled blockwise C/D/E
|
||||
constexpr auto sfc_cde_block =
|
||||
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
|
||||
|
||||
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
|
||||
constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
|
||||
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
|
||||
constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
|
||||
const float *p_sorted_weights_0 = p_ds_grid[I0];
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
// too hack here, 2 specific for topk weights, fixme
|
||||
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
|
||||
|
||||
auto dstidx = sfc_cde_block.GetIndex(access_id);
|
||||
const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
|
||||
// Fill the per-thread scatter tables for the current access: for each of the EMRepeats
// slots starting at c_token_pos, compute the output element offset and the weight that
// are later passed to cde_block_copy_lds_and_global.Run().
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
// The sorted-token entry is used as a packed value: low 24 bits are taken as the token
// index, the bits from 24 upwards as a second field.
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
// Weight comes from D tensor 0, indexed by sorted position scaled with StrideDs[0].
float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
|
||||
if constexpr (IsInputGemm)
|
||||
{
|
||||
// Input GEMM: the output row is token_index * TopK + (upper bits).
// NOTE(review): if index_t is signed, `>> 24` is an arithmetic shift - confirm the
// upper field never sets the sign bit.
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
} else {
|
||||
// Otherwise the row stays the plain token index and the weight is additionally
// multiplied by the entry of D tensor 2 at the same sorted position.
const float *p_sorted_weights_2 = p_ds_grid[I2];
|
||||
weight = weight * p_sorted_weights_2[c_token_pos + m0];
|
||||
}
|
||||
|
||||
// if(threadIdx.x % 8 == 0 && blockIdx.x == 0)
|
||||
// printf("init off tid %d access %d tpos %d m %d off %d wei %f\n", threadIdx.x, dstidx(I1), c_token_pos, m0(), token_offset, weight);
|
||||
// Convert the row index into an element offset by scaling with problem.N.
scatter_offsets(m0) = token_offset * problem.N;
|
||||
scatter_weights(m0) = weight;
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c_thread_buf,
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_shuffle_block_buf);
|
||||
|
||||
// make sure it's safe to read from LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each block copy its data from LDS to global
|
||||
cde_block_copy_lds_and_global.Run(
|
||||
c_ds_desc_refs,
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(c_grid_buf),
|
||||
scatter_offsets,
|
||||
scatter_weights
|
||||
);
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
@@ -1644,8 +2143,12 @@ struct GridwiseMoeGemm
|
||||
|
||||
// template <bool HasMainKBlockLoop,
|
||||
// InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
// bool IsInputGemm = true,
|
||||
// TailNumber TailNum = TailNumber::Odd>
|
||||
// __device__ static void Run_2Lds(const ADataType* p_a_grid,
|
||||
// __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
|
||||
// const index_t* p_sorted_expert_ids,
|
||||
// const index_t* p_max_token_id,
|
||||
// const ADataType* p_a_grid,
|
||||
// const BDataType* p_b_grid,
|
||||
// DsGridPointer& p_ds_grid,
|
||||
// CDataType* p_c_grid,
|
||||
@@ -1656,37 +2159,7 @@ struct GridwiseMoeGemm
|
||||
// BElementwiseOperation b_element_op,
|
||||
// CElementwiseOperation c_element_op)
|
||||
// {
|
||||
// // const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
|
||||
// // Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
// // p_a_grid,
|
||||
// // p_b_grid,
|
||||
// // p_ds_grid,
|
||||
// // p_c_grid,
|
||||
// // p_shared,
|
||||
// // p_shared1,
|
||||
// // problem,
|
||||
// // a_element_op,
|
||||
// // b_element_op,
|
||||
// // c_element_op,
|
||||
// // block_2_ctile_map);
|
||||
// }
|
||||
|
||||
// template <typename Block2CTileMap,
|
||||
// bool HasMainKBlockLoop,
|
||||
// InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
// TailNumber TailNum = TailNumber::Odd>
|
||||
// __device__ static void Run_2Lds(const ADataType* p_a_grid,
|
||||
// const BDataType* p_b_grid,
|
||||
// DsGridPointer& p_ds_grid,
|
||||
// CDataType* p_c_grid,
|
||||
// void* p_shared,
|
||||
// void* p_shared1,
|
||||
// const Problem& problem,
|
||||
// AElementwiseOperation a_element_op,
|
||||
// BElementwiseOperation b_element_op,
|
||||
// CElementwiseOperation c_element_op,
|
||||
// const Block2CTileMap& block_2_ctile_map)
|
||||
// {
|
||||
// }
|
||||
};
|
||||
|
||||
|
||||
@@ -100,14 +100,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
const StaticallyIndexedArray<Index, nSrc>& src_slice_origins,
|
||||
const DstDescs& dst_descs,
|
||||
const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
|
||||
const ElementwiseOperation& element_op,
|
||||
const StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
|
||||
const StaticallyIndexedArray<float, scatter_num> &scatter_weights)
|
||||
const ElementwiseOperation& element_op)
|
||||
: src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
|
||||
dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
|
||||
element_op_(element_op),
|
||||
scatter_offsets_(scatter_offsets),
|
||||
scatter_weights_(scatter_weights)
|
||||
element_op_(element_op)
|
||||
{
|
||||
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
|
||||
"wrong! cannot evenly divide");
|
||||
@@ -158,6 +154,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
StaticallyIndexedArray<float, scatter_num> &scatter_weights,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// loop over space-filling curve
|
||||
@@ -181,9 +178,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1, "scatter weight dim, should only one vec");
|
||||
constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
|
||||
// if(threadIdx.x % 8 ==0 )
|
||||
// printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights_(Number<iScatter>{}));
|
||||
// printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights(Number<iScatter>{}));
|
||||
static_for<0, SrcScalarPerVector, 1>{}(
|
||||
[&](auto j) { src_vectors(i).template AsType<float>()(j) = scatter_weights_(Number<iScatter>{}); });
|
||||
[&](auto j) { src_vectors(i).template AsType<float>()(j) = scatter_weights(Number<iScatter>{}); });
|
||||
}
|
||||
else if constexpr(SrcScalarPerVectors{}[i] == 1)
|
||||
{
|
||||
@@ -418,6 +415,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
OOBCheck(thread_scratch_id);
|
||||
@@ -430,13 +428,13 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
if constexpr (OutputScatter)
|
||||
{
|
||||
constexpr auto iScatter = DstSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
|
||||
scatter_offset = scatter_offsets_(Number<iScatter>{});
|
||||
scatter_offset = scatter_offsets(Number<iScatter>{});
|
||||
}
|
||||
// copy data from buf_vectors into dst_bufs
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
|
||||
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
|
||||
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();//hack felix, todo use coord
|
||||
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
|
||||
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
|
||||
// dst_coords_[i]);
|
||||
|
||||
@@ -449,11 +447,11 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
dst_offset,
|
||||
is_dst_valid,
|
||||
dst_vectors[i].template AsType<dst_vector_t>()[I0]);
|
||||
// if(1) {
|
||||
// static_for<0, DstScalarPerVector, 1>{}([&](auto idx) {
|
||||
// if(threadIdx.x%8 ==0 && blockIdx.x==0) {
|
||||
// static_for<0, 1, 1>{}([&](auto idx) {
|
||||
// using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
|
||||
// using print_vec_t = typename vector_type<DstData, 1>::type;
|
||||
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_coords_[i].GetOffset(), is_dst_valid,
|
||||
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset, is_dst_valid,
|
||||
// type_convert<float>(dst_vectors[i].template AsType<print_vec_t>()[idx]));
|
||||
// });
|
||||
// }
|
||||
@@ -509,10 +507,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
__device__ void Run(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs)
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num> &scatter_weights)
|
||||
{
|
||||
RunRead(src_descs, src_bufs);
|
||||
RunWrite(dst_descs, dst_bufs);
|
||||
RunRead(src_descs, src_bufs, scatter_weights);
|
||||
RunWrite(dst_descs, dst_bufs, scatter_offsets);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
@@ -683,8 +683,18 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
? dst_slice_origin_step_idx
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
auto adjusted_step_idx_scatter = [&]()
|
||||
{
|
||||
Index step_;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[Number<i>{}];
|
||||
});
|
||||
|
||||
return step_;
|
||||
}
|
||||
();
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx);
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx_scatter);
|
||||
|
||||
move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
|
||||
}
|
||||
@@ -709,8 +719,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
SrcCoords src_coords_;
|
||||
DstCoords dst_coords_;
|
||||
const ElementwiseOperation element_op_;
|
||||
StaticallyIndexedArray<index_t, scatter_num> scatter_offsets_;
|
||||
StaticallyIndexedArray<float, scatter_num> scatter_weights_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -11,6 +11,13 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Bitwise AND of `a` and `b`, issued as a single `v_and_b32` instruction via
// inline assembly. The destination and both sources use the "v" register
// constraint. Returns the value the instruction wrote to its destination.
inline __device__ int amd_assembly_and_b32(int a, int b)
{
    int result;
    asm volatile("v_and_b32 %[dst], %[lhs], %[rhs]"
                 : [dst] "=v"(result)
                 : [lhs] "v"(a), [rhs] "v"(b));
    return result;
}
|
||||
|
||||
inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
|
||||
{
|
||||
int c;
|
||||
@@ -32,7 +39,24 @@ inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
|
||||
return c;
|
||||
}
|
||||
|
||||
inline __device__ f8x8_t amd_assembly_i4_to_fp8x2(int a)
|
||||
// Runs the `v_cvt_off_f32_i4` instruction on `b` and returns the float it
// produces. Both operands use the "v" register constraint.
// NOTE(review): the instruction presumably consumes only the low 4 bits of
// `b` and applies an offset/scaled int4 -> f32 mapping -- confirm against the
// target ISA documentation before relying on values outside that range.
inline __device__ float amd_assemble_cvt_f32_i4(int b)
{
    float converted;
    asm volatile("v_cvt_off_f32_i4 %[dst], %[src]"
                 : [dst] "=v"(converted)
                 : [src] "v"(b));
    return converted;
}
|
||||
|
||||
// Packs four floats into one f8x4_t using two `v_cvt_pk_fp8_f32` instructions.
//
// NOTE: despite the function name, the conversion direction is f32 -> fp8
// (the inputs are floats and the result is f8x4_t). The name is kept so that
// existing callers continue to compile.
//
// The first instruction converts (b0, b1) into the destination register; the
// second converts (b2, b3) into the same register with op_sel:[0, 0, 1]
// (presumably selecting the other 16-bit half of the destination -- confirm
// against the ISA documentation).
//
// The output operand is marked early-clobber ("=&v"): the destination is
// written by the first instruction while b2/b3 are only read by the second.
// Without '&' the compiler is allowed to place the output in the same register
// as one of the inputs, which would overwrite b2 or b3 before they are
// consumed.
inline __device__ f8x4_t amd_assembly_cvt_f8_to_f32(float b0, float b1, float b2, float b3)
{
    f8x4_t a;
    asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2\n"
                 "v_cvt_pk_fp8_f32 %0, %3, %4, op_sel:[0, 0, 1]\n"
                 : "=&v"(a)
                 : "v"(b0), "v"(b1), "v"(b2), "v"(b3));
    return a;
}
|
||||
|
||||
inline __device__ f8x8_t amd_assembly_i4_to_fp8x8(int a)
|
||||
{
|
||||
uint32_t i4x8 = static_cast<uint32_t>(a);
|
||||
uint32_t fp8x4_0;
|
||||
@@ -60,14 +84,7 @@ inline __device__ f8x8_t amd_assembly_i4_to_fp8x2(int a)
|
||||
[v_src] "+v"(i4x8)
|
||||
:);
|
||||
|
||||
union
|
||||
{
|
||||
uint64_t as_uint64;
|
||||
f8x8_t as_f8x8;
|
||||
} convert;
|
||||
|
||||
convert.as_uint64 = (static_cast<uint64_t>(fp8x4_1) << 32) | fp8x4_0;
|
||||
return convert.as_f8x8;
|
||||
return bit_cast<f8x8_t>(((static_cast<uint64_t>(fp8x4_1) << 32) | fp8x4_0));
|
||||
}
|
||||
|
||||
// c0 += inner_product(a, b0)
|
||||
|
||||
@@ -1191,11 +1191,15 @@ struct vector_type<T, 32, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
StaticallyIndexedArray<d8_t, 4> d8x4_;
|
||||
StaticallyIndexedArray<d16_t, 2> d16x2_;
|
||||
StaticallyIndexedArray<d32_t, 1> d32x1_;
|
||||
} data_;
|
||||
} data_ = { .d32_ = {0} };
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
__attribute__((host)) __attribute__((device)) constexpr vector_type() { }
|
||||
|
||||
__attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { }
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
// __host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
|
||||
// __host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
|
||||
@@ -877,19 +877,11 @@ struct MoeSortingKernel
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (tid == 0) {
|
||||
//temp hack ptr for expert tile cnt
|
||||
p_total_tokens_post_pad[1] = 0;
|
||||
}
|
||||
for(int i_e = tid; i_e < num_experts; i_e += block_size)
|
||||
{
|
||||
int e_start = smem_cumsum(i_e);
|
||||
int e_end = smem_cumsum(i_e + 1);
|
||||
|
||||
//temp hack ptr for expert tile cnt
|
||||
index_t *p_sorted_expert_cnts = p_total_tokens_post_pad + 1 + i_e;
|
||||
p_sorted_expert_cnts[1] = unit_size_mdiv.div(e_end);
|
||||
|
||||
int expert_id = [&]() {
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
@@ -994,11 +986,18 @@ struct MoeSortingKernel
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
//temp hack ptr for expert tile cnt
|
||||
p_total_tokens_post_pad[1] = 0;
|
||||
}
|
||||
// add the skip number
|
||||
for(int eid = tid; eid < num_experts; eid += block_size)
|
||||
{
|
||||
//temp hack ptr for expert tile cnt
|
||||
index_t *p_sorted_expert_cnts = p_total_tokens_post_pad + 1 + eid;
|
||||
int e_start = smem_cumsum(eid);
|
||||
int e_end = smem_cumdup(eid + 1);
|
||||
p_sorted_expert_cnts[1] = unit_size_mdiv.div(e_end);
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
{
|
||||
if(e_start == e_end) // skip zero token expert
|
||||
@@ -1682,6 +1681,8 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
|
||||
if(position < kargs.num_experts)
|
||||
{
|
||||
index_t *p_sorted_expert_cnts = p_total_tokens_post_pad + 1 + position;//temp mock for p_sorted_expert_cnts, fixme:felix
|
||||
p_sorted_expert_cnts[0] = out_0;
|
||||
p_expert_cumsum[position] = out_0 * kargs.unit_size_mdiv.divisor;
|
||||
}
|
||||
|
||||
@@ -1710,6 +1711,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
{
|
||||
auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor;
|
||||
p_total_tokens_post_pad[0] = total_tokens_post_pad;
|
||||
p_total_tokens_post_pad[kargs.num_experts+1] = prev_cumsum_a; //temp mock for p_sorted_expert_cnts, fixme:felix
|
||||
p_expert_cumsum[kargs.num_experts] = total_tokens_post_pad;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,7 +91,11 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
i4 = (i4x2 >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
v_a = i4_to_f32_gfx9(i4);
|
||||
#else
|
||||
v_a = i4 - 8;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -106,7 +110,11 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
i4 = (i4x2 >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
v_b = i4_to_f32_gfx9(i4);
|
||||
#else
|
||||
v_b = i4 - 8;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -106,7 +106,11 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
i4 = (i4x2 >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
v_a = i4_to_f32_gfx9(i4);
|
||||
#else
|
||||
v_a = i4 - 8;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -120,7 +124,11 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
i4 = (i4x2 >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
v_b = i4_to_f32_gfx9(i4);
|
||||
#else
|
||||
v_b = i4 - 8;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -198,6 +206,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
return str.str();
|
||||
}
|
||||
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
static float i4_to_f32_gfx9(uint8_t i4)
|
||||
{
|
||||
static std::unordered_map<uint8_t, float> u = {{0b1000, -0.5000f},
|
||||
@@ -219,6 +228,8 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
|
||||
return u[i4];
|
||||
}
|
||||
#endif
|
||||
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
|
||||
Reference in New Issue
Block a user