temp save

This commit is contained in:
mtgu0705
2025-05-14 08:13:47 -05:00
parent 2700b217be
commit 102151ebcf
6 changed files with 57 additions and 28 deletions

View File

@@ -155,7 +155,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
// clang-format on
#else
static constexpr ck::index_t MPerBlock = 16;
static constexpr ck::index_t MPerBlock = 128;
static constexpr bool MulRoutedWeight = true;
// clang-format off
@@ -163,14 +163,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 64,
MPerBlock, 16, 128,
ScaleBlockSize, 256,
MPerBlock, 128, 128,
32, 32,
16, 16,
1, 1,
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
1, 1, S<1, 8, 1, 8>, S<2, 1, 1, 1>,
8, 2,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
1, 1, S<1, 16, 1, 16>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>;
// clang-format on
#endif
@@ -183,14 +183,14 @@ int main(int argc, char* argv[])
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 2;
constexpr ck::index_t valid_tile_num = 2;
constexpr ck::index_t sorted_tile_num = 8;
constexpr ck::index_t valid_tile_num = 8;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 2;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
@@ -341,6 +341,24 @@ int main(int argc, char* argv[])
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break;
case 5:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 6:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
default:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
@@ -378,7 +396,7 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
#if 1
#if 0
printf("a0_t_k_k:\n");
for(int t = 0; t < tokens; ++t)
{

View File

@@ -318,7 +318,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
// restore col id and advance to the next set of scales
// NWaves * NPerXDL * NRepeat == NPerBlock
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(0, ScalesPerKBlockSize));
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
__builtin_amdgcn_sched_barrier(0);

View File

@@ -212,6 +212,11 @@ struct GridwiseMoeGemmMX
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
static constexpr index_t MWave = MPerBlock / MPerXdl / MXdlPerWave;
static constexpr auto ScalesPerXdlopsRun =
(KPack * mfma_selector::selected_mfma.num_input_blks) / ScaleBlockSize;
static constexpr auto ScalesPerXdlopsRunPerThread =
ScalesPerXdlopsRun / mfma_selector::selected_mfma.num_input_blks;
// static constexpr index_t NumTokens = 1;
static constexpr index_t SortedTileSize = MPerBlock;
@@ -1246,9 +1251,12 @@ struct GridwiseMoeGemmMX
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
make_tuple(IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
math::integer_divide_ceil(problem.K, ScaleBlockSize),
1),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1, 1));
math::integer_divide_ceil(problem.K, ScaleBlockSize) /
ScalesPerXdlopsRunPerThread,
ScalesPerXdlopsRunPerThread),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize),
ScalesPerXdlopsRunPerThread,
1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(problem.K, math::integer_divide_ceil(problem.K, ScaleBlockSize)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1));
@@ -1460,7 +1468,7 @@ struct GridwiseMoeGemmMX
true,
MXdlPerWave,
KRepeat>(
a_scale_grid_desc_am_ak, make_multi_index(0, 0, thread_offset_k), scale_gather_offsets);
a_scale_grid_desc_am_ak, make_multi_index(0, thread_offset_k, 0), scale_gather_offsets);
// B scale load
auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl;

View File

@@ -529,8 +529,8 @@ struct ThreadwiseTensorSliceTransfer_v2_gather
// loop over tensor and copy
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) { // MRepeate
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) { // MRepeat
static_for<0, KRepeat, 1>{}([&](auto k0) { // KRepeat
constexpr auto current_dst_origin =
to_multi_index(dst_slice_origin_idx) + make_multi_index(gather_idx, k0, 0);
MoveSrcSliceWindow(src_desc, make_multi_index(0, 0, 0));
@@ -584,9 +584,14 @@ struct ThreadwiseTensorSliceTransfer_v2_gather
src_coord_,
make_tensor_coordinate_step(src_desc, forward_step));
}
MoveSrcSliceWindow(
src_desc,
make_multi_index(
0, 4, 0)); // hacky fix: 4 means xdlops_gemm.KPerXdlops / ScaleBlockSize
});
});
MoveSrcSliceWindow(src_desc, make_multi_index(0, -KRepeat, 0));
MoveSrcSliceWindow(src_desc, make_multi_index(0, -(KRepeat * 4), 0));
});
// printf("blockIdx.y: %d, tid: %d, dst_buf<%f>\n",

View File

@@ -762,7 +762,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
using arg_type = int32x8_t;
#if 0
#if 1
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
@@ -788,9 +788,9 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
"v"(scale_b));
#endif
#if 1
#if 0
printf("bidx: %u, bidy: %u, tid: %u, A: %08x, %08x, %08x, %08x,"
"B:%08x, %08x, %08x, %08x, a_scale: %08x, b_scale: %08x, "
"B:%08x, %08x, %08x, %08x, a_scale: %.f, b_scale: %.f, "
"reg_c: %f, %f, %f, %f\n",
blockIdx.x,
blockIdx.y,
@@ -803,8 +803,10 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
bit_cast<uint32_t>(arg_b[1]),
bit_cast<uint32_t>(arg_b[2]),
bit_cast<uint32_t>(arg_b[3]),
*(reinterpret_cast<const uint32_t*>(&(scale_a))),
*(reinterpret_cast<const uint32_t*>(&(scale_b))),
// *(reinterpret_cast<const uint32_t*>(&(scale_a))),
// *(reinterpret_cast<const uint32_t*>(&(scale_b))),
type_convert<float>(scale_a),
type_convert<float>(scale_b),
reg_c.template AsType<float>()[Number<0>{}],
reg_c.template AsType<float>()[Number<1>{}],
reg_c.template AsType<float>()[Number<2>{}],

View File

@@ -89,10 +89,6 @@ struct ReferenceMoeMXGemm2 : public device::BaseOperator
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_t_k_k_.mDesc.GetLengths()[2];
const ck::index_t SCALE_BLOCK = K / arg.b_e_n_k_scale_.mDesc.GetLengths()[1];
if(m == 0 && n == 0)
{
printf("SCALE_BLOCK: %d\n", SCALE_BLOCK);
}
AccDataType v_acc{0};
ComputeTypeA v_a{0};
ComputeTypeB v_b{0};