mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
temp save
This commit is contained in:
@@ -155,7 +155,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
// clang-format on
|
||||
|
||||
#else
|
||||
static constexpr ck::index_t MPerBlock = 16;
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
@@ -163,14 +163,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, 64,
|
||||
MPerBlock, 16, 128,
|
||||
ScaleBlockSize, 256,
|
||||
MPerBlock, 128, 128,
|
||||
32, 32,
|
||||
16, 16,
|
||||
1, 1,
|
||||
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
1, 1, S<1, 8, 1, 8>, S<2, 1, 1, 1>,
|
||||
8, 2,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
1, 1, S<1, 16, 1, 16>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
#endif
|
||||
@@ -183,14 +183,14 @@ int main(int argc, char* argv[])
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
constexpr ck::index_t sorted_tile_num = 2;
|
||||
constexpr ck::index_t valid_tile_num = 2;
|
||||
constexpr ck::index_t sorted_tile_num = 8;
|
||||
constexpr ck::index_t valid_tile_num = 8;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 2;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
@@ -341,6 +341,24 @@ int main(int argc, char* argv[])
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // to be removed
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // to be removed
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // to be removed
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // to be removed
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // to be removed
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
default:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
@@ -378,7 +396,7 @@ int main(int argc, char* argv[])
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
printf("a0_t_k_k:\n");
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
|
||||
@@ -318,7 +318,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(0, ScalesPerKBlockSize));
|
||||
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
|
||||
@@ -212,6 +212,11 @@ struct GridwiseMoeGemmMX
|
||||
static constexpr index_t NLane = NPerXdl;
|
||||
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
|
||||
static constexpr index_t MWave = MPerBlock / MPerXdl / MXdlPerWave;
|
||||
static constexpr auto ScalesPerXdlopsRun =
|
||||
(KPack * mfma_selector::selected_mfma.num_input_blks) / ScaleBlockSize;
|
||||
static constexpr auto ScalesPerXdlopsRunPerThread =
|
||||
ScalesPerXdlopsRun / mfma_selector::selected_mfma.num_input_blks;
|
||||
|
||||
// static constexpr index_t NumTokens = 1;
|
||||
static constexpr index_t SortedTileSize = MPerBlock;
|
||||
|
||||
@@ -1246,9 +1251,12 @@ struct GridwiseMoeGemmMX
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockSize),
|
||||
1),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1, 1));
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockSize) /
|
||||
ScalesPerXdlopsRunPerThread,
|
||||
ScalesPerXdlopsRunPerThread),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize),
|
||||
ScalesPerXdlopsRunPerThread,
|
||||
1));
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(problem.K, math::integer_divide_ceil(problem.K, ScaleBlockSize)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1));
|
||||
@@ -1460,7 +1468,7 @@ struct GridwiseMoeGemmMX
|
||||
true,
|
||||
MXdlPerWave,
|
||||
KRepeat>(
|
||||
a_scale_grid_desc_am_ak, make_multi_index(0, 0, thread_offset_k), scale_gather_offsets);
|
||||
a_scale_grid_desc_am_ak, make_multi_index(0, thread_offset_k, 0), scale_gather_offsets);
|
||||
|
||||
// B scale load
|
||||
auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl;
|
||||
|
||||
@@ -529,8 +529,8 @@ struct ThreadwiseTensorSliceTransfer_v2_gather
|
||||
// loop over tensor and copy
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) { // MRepeate
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) { // MRepeat
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) { // KRepeat
|
||||
constexpr auto current_dst_origin =
|
||||
to_multi_index(dst_slice_origin_idx) + make_multi_index(gather_idx, k0, 0);
|
||||
MoveSrcSliceWindow(src_desc, make_multi_index(0, 0, 0));
|
||||
@@ -584,9 +584,14 @@ struct ThreadwiseTensorSliceTransfer_v2_gather
|
||||
src_coord_,
|
||||
make_tensor_coordinate_step(src_desc, forward_step));
|
||||
}
|
||||
|
||||
MoveSrcSliceWindow(
|
||||
src_desc,
|
||||
make_multi_index(
|
||||
0, 4, 0)); // hacky fix: 4 means xdlops_gemm.KPerXdlops / ScaleBlockSize
|
||||
});
|
||||
});
|
||||
MoveSrcSliceWindow(src_desc, make_multi_index(0, -KRepeat, 0));
|
||||
MoveSrcSliceWindow(src_desc, make_multi_index(0, -(KRepeat * 4), 0));
|
||||
});
|
||||
|
||||
// printf("blockIdx.y: %d, tid: %d, dst_buf<%f>\n",
|
||||
|
||||
@@ -762,7 +762,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
@@ -788,9 +788,9 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
"v"(scale_b));
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
printf("bidx: %u, bidy: %u, tid: %u, A: %08x, %08x, %08x, %08x,"
|
||||
"B:%08x, %08x, %08x, %08x, a_scale: %08x, b_scale: %08x, "
|
||||
"B:%08x, %08x, %08x, %08x, a_scale: %.f, b_scale: %.f, "
|
||||
"reg_c: %f, %f, %f, %f\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
@@ -803,8 +803,10 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
bit_cast<uint32_t>(arg_b[1]),
|
||||
bit_cast<uint32_t>(arg_b[2]),
|
||||
bit_cast<uint32_t>(arg_b[3]),
|
||||
*(reinterpret_cast<const uint32_t*>(&(scale_a))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(scale_b))),
|
||||
// *(reinterpret_cast<const uint32_t*>(&(scale_a))),
|
||||
// *(reinterpret_cast<const uint32_t*>(&(scale_b))),
|
||||
type_convert<float>(scale_a),
|
||||
type_convert<float>(scale_b),
|
||||
reg_c.template AsType<float>()[Number<0>{}],
|
||||
reg_c.template AsType<float>()[Number<1>{}],
|
||||
reg_c.template AsType<float>()[Number<2>{}],
|
||||
|
||||
@@ -89,10 +89,6 @@ struct ReferenceMoeMXGemm2 : public device::BaseOperator
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = arg.a_t_k_k_.mDesc.GetLengths()[2];
|
||||
const ck::index_t SCALE_BLOCK = K / arg.b_e_n_k_scale_.mDesc.GetLengths()[1];
|
||||
if(m == 0 && n == 0)
|
||||
{
|
||||
printf("SCALE_BLOCK: %d\n", SCALE_BLOCK);
|
||||
}
|
||||
AccDataType v_acc{0};
|
||||
ComputeTypeA v_a{0};
|
||||
ComputeTypeB v_b{0};
|
||||
|
||||
Reference in New Issue
Block a user