mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
16x16x128 input size blockscale function passed
This commit is contained in:
@@ -315,18 +315,32 @@ int main(int argc, char* argv[])
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
|
||||
break;
|
||||
case 2:
|
||||
// a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1.0, 1.0});
|
||||
// b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1.0, 1.0});
|
||||
ck::utils::FillConstant<A0DataType>{ck::type_convert<A0DataType>(ck::float2_t(1.0f))}(
|
||||
a0_t_k_k);
|
||||
ck::utils::FillConstant<B0DataType>{ck::type_convert<B0DataType>(ck::float2_t(1.0f))}(
|
||||
b0_e_n_k);
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
@@ -372,13 +386,16 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
auto f4x2 = a0_t_k_k(t, tk, k).data;
|
||||
if(k % 2 == 0)
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(a0_t_k_k(t, tk, k).data >> 4 & 0xf));
|
||||
ck::f4_t f4 = (f4x2 >> 4) & 0xf;
|
||||
printf("%f ", ck::type_convert<float>(f4));
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(a0_t_k_k(t, tk, k).data & 0xf));
|
||||
ck::f4_t f4 = (f4x2 >> 0) & 0xf;
|
||||
printf("%f ", ck::type_convert<float>(f4));
|
||||
}
|
||||
}
|
||||
printf("\n");
|
||||
@@ -407,13 +424,16 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
auto f4x2 = b0_e_n_k(e, k, n).data;
|
||||
if(k % 2 == 0)
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(b0_e_n_k(e, k, n).data >> 4 & 0xf));
|
||||
ck::f4_t f4 = f4x2 >> 4 & 0xf;
|
||||
printf("%f ", ck::type_convert<float>(f4));
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(b0_e_n_k(e, k, n).data & 0xf));
|
||||
ck::f4_t f4 = f4x2 >> 0 & 0xf;
|
||||
printf("%f ", ck::type_convert<float>(f4));
|
||||
}
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
@@ -279,34 +279,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Prefetch a_scales to buf 0
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
|
||||
// auto a_scale_thread_buf_copy =
|
||||
// make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
// a_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
// a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
// a_scale_grid_buf,
|
||||
// a_scale_thread_desc_copy,
|
||||
// make_tuple(I0, I0),
|
||||
// a_scale_thread_buf_copy);
|
||||
|
||||
a_scale_thread_bufs(I0)(Number<a_scale_offset>{}) =
|
||||
type_convert<AScaleDataType>(1.0f);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_scale_thread_bufs(I0));
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
|
||||
make_multi_index(0, ScalesPerKBlockSize, 0));
|
||||
|
||||
// Prefetch b_scales to buf 0
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
@@ -314,17 +295,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
// auto b_scale_thread_buf_copy =
|
||||
// make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
// b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
// b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
// b_scale_grid_buf,
|
||||
// b_scale_thread_desc_copy,
|
||||
// make_tuple(I0, I0),
|
||||
// b_scale_thread_buf_copy);
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(I0)(Number<b_scale_offset>{}) =
|
||||
type_convert<BScaleDataType>(1.0f);
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
@@ -337,7 +318,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
make_multi_index(0, ScalesPerKBlockSize));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -349,34 +330,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// Prefetch a_scales to buf 1
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
|
||||
// auto a_scale_thread_buf_copy =
|
||||
// make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
// a_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
// a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
// a_scale_grid_buf,
|
||||
// a_scale_thread_desc_copy,
|
||||
// make_tuple(I0, I0),
|
||||
// a_scale_thread_buf_copy);
|
||||
|
||||
a_scale_thread_bufs(I1)(Number<a_scale_offset>{}) =
|
||||
type_convert<AScaleDataType>(1.0f);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_scale_thread_bufs(I1));
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
|
||||
make_multi_index(0, ScalesPerKBlockSize, 0));
|
||||
|
||||
// Prefetch b_scales to buf 1
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
@@ -384,17 +346,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
// auto b_scale_thread_buf_copy =
|
||||
// make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
// b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
// b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
// b_scale_grid_buf,
|
||||
// b_scale_thread_desc_copy,
|
||||
// make_tuple(I0, I0),
|
||||
// b_scale_thread_buf_copy);
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(I1)(Number<b_scale_offset>{}) =
|
||||
type_convert<BScaleDataType>(1.0f);
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
@@ -538,35 +500,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
|
||||
});
|
||||
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
|
||||
// auto a_scale_thread_buf_copy =
|
||||
// make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
// a_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
// a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
// a_scale_grid_buf,
|
||||
// a_scale_thread_desc_copy,
|
||||
// make_tuple(I0, I0),
|
||||
// a_scale_thread_buf_copy);
|
||||
|
||||
a_scale_thread_bufs(mfma_reg_buf)(Number<a_scale_offset>{}) =
|
||||
type_convert<AScaleDataType>(1.0f);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_scale_thread_bufs(mfma_reg_buf));
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize));
|
||||
a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0));
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
@@ -574,17 +516,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
// auto b_scale_thread_buf_copy =
|
||||
// make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
// b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
// b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
// b_scale_grid_buf,
|
||||
// b_scale_thread_desc_copy,
|
||||
// make_tuple(I0, I0),
|
||||
// b_scale_thread_buf_copy);
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(mfma_reg_buf)(Number<b_scale_offset>{}) =
|
||||
type_convert<BScaleDataType>(1.0f);
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
|
||||
@@ -211,6 +211,7 @@ struct GridwiseMoeGemmMX
|
||||
static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
|
||||
static constexpr index_t NLane = NPerXdl;
|
||||
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
|
||||
static constexpr index_t MWave = MPerBlock / MPerXdl / MXdlPerWave;
|
||||
// static constexpr index_t NumTokens = 1;
|
||||
static constexpr index_t SortedTileSize = MPerBlock;
|
||||
|
||||
@@ -512,9 +513,7 @@ struct GridwiseMoeGemmMX
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
|
||||
{
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
|
||||
return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
|
||||
return MakeGemmMmaTileDescriptor<MXdlPerWave, MWave, MPerXdl>(ABlockDesc_AK0_M_AK1{});
|
||||
}
|
||||
|
||||
template <typename BBlockDesc_BK0_N_BK1>
|
||||
@@ -942,8 +941,6 @@ struct GridwiseMoeGemmMX
|
||||
|
||||
__device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
|
||||
{
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1,
|
||||
@@ -1249,8 +1246,9 @@ struct GridwiseMoeGemmMX
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockSize)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1));
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockSize),
|
||||
1),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1, 1));
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(problem.K, math::integer_divide_ceil(problem.K, ScaleBlockSize)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1));
|
||||
@@ -1431,20 +1429,40 @@ struct GridwiseMoeGemmMX
|
||||
|
||||
auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl;
|
||||
|
||||
auto a_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<AScaleDataType,
|
||||
AScaleDataType,
|
||||
decltype(a_scale_grid_desc_am_ak),
|
||||
decltype(BlockwiseGemmPipe::a_scale_thread_desc_copy),
|
||||
Sequence<1, 1>, // SliceLengths
|
||||
Sequence<0, 1>, // DimAccessOrder
|
||||
1, // SrcVectorDim
|
||||
1, // SrcScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
true>(
|
||||
a_scale_grid_desc_am_ak,
|
||||
make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, thread_offset_k));
|
||||
// get each thread's offset int the scale tensor
|
||||
const index_t token_scale_pos = block_m_id * MPerBlock;
|
||||
if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
|
||||
return;
|
||||
|
||||
StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
|
||||
const index_t fused_token =
|
||||
p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWave + a_thread_offset_m];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
if constexpr(!IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
scale_gather_offsets(m0) =
|
||||
token_offset * math::integer_divide_ceil(problem.K, ScaleBlockSize);
|
||||
});
|
||||
|
||||
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2_gather<
|
||||
AScaleDataType,
|
||||
AScaleDataType,
|
||||
decltype(a_scale_grid_desc_am_ak),
|
||||
decltype(BlockwiseGemmPipe::a_scale_thread_desc),
|
||||
Sequence<1, 1, 1>, // SliceLengths
|
||||
Sequence<0, 1, 2>, // DimAccessOrder
|
||||
2, // SrcVectorDim
|
||||
1, // SrcScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
true,
|
||||
MXdlPerWave,
|
||||
KRepeat>(
|
||||
a_scale_grid_desc_am_ak, make_multi_index(0, 0, thread_offset_k), scale_gather_offsets);
|
||||
|
||||
// B scale load
|
||||
auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl;
|
||||
|
||||
auto b_scale_thread_copy =
|
||||
@@ -1537,8 +1555,6 @@ struct GridwiseMoeGemmMX
|
||||
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
@@ -2255,8 +2271,6 @@ struct GridwiseMoeGemmMX
|
||||
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
|
||||
@@ -424,6 +424,248 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
SrcCoord src_coord_;
|
||||
}; // namespace ck
|
||||
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
bool SrcResetCoordinateAfterRun,
|
||||
index_t scale_gather_num,
|
||||
index_t KRepeat,
|
||||
bool InvalidElementAsNaN = false,
|
||||
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v2_gather
|
||||
{
|
||||
static_assert((InvalidElementAsNaN && !ck::is_integral<DstData>::value) ||
|
||||
(!InvalidElementAsNaN),
|
||||
"Filling invalid element as NaN is only for floating point types");
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
|
||||
static constexpr index_t PackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v2_gather(
|
||||
const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_idx,
|
||||
const StaticallyIndexedArray<index_t, scale_gather_num>& scale_gather_offsets)
|
||||
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)),
|
||||
scale_gather_offsets_(scale_gather_offsets)
|
||||
{
|
||||
static_assert(DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
|
||||
{
|
||||
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
{
|
||||
auto adjusted_origin_idx = [&]() {
|
||||
Index idx;
|
||||
|
||||
static_for<0, nDim, 1>{}(
|
||||
[&](auto i) { idx(i) = i.value == 0 ? 0 : src_slice_origin_idx[Number<i>{}]; });
|
||||
|
||||
return idx;
|
||||
}();
|
||||
|
||||
src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc&,
|
||||
const DstSliceOriginIdx&,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
static_assert(DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! DstDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
|
||||
"wrong! DstSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
// DstDesc and dst_slice_origin_idx are known at compile-time
|
||||
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
|
||||
constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
|
||||
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_scalar_step_in_vector =
|
||||
generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(src_scalar_per_access)>>;
|
||||
|
||||
// loop over tensor and copy
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) { // MRepeate
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
constexpr auto current_dst_origin =
|
||||
to_multi_index(dst_slice_origin_idx) + make_multi_index(gather_idx, k0, 0);
|
||||
MoveSrcSliceWindow(src_desc, make_multi_index(0, 0, 0));
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type
|
||||
src_vector;
|
||||
|
||||
using src_vector_t =
|
||||
typename vector_type_maker<SrcData,
|
||||
SrcScalarPerVector / PackedSize>::type::type;
|
||||
constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
|
||||
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
|
||||
src_coord_);
|
||||
|
||||
// copy data from src_buf into src_vector
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize +
|
||||
scale_gather_offsets_(gather_idx),
|
||||
is_src_valid);
|
||||
|
||||
// copy data from src_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset =
|
||||
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) +
|
||||
src_data_idx + i * src_scalar_step_in_vector);
|
||||
constexpr auto full_dst_offset =
|
||||
dst_desc.CalculateOffset(current_dst_origin) + dst_offset;
|
||||
|
||||
if constexpr(InvalidElementAsNaN)
|
||||
{
|
||||
dst_buf(full_dst_offset) =
|
||||
is_src_valid ? type_convert<DstData>(
|
||||
src_vector.template AsType<SrcData>()[i])
|
||||
: NumericLimits<DstData>::QuietNaN();
|
||||
}
|
||||
else
|
||||
{
|
||||
dst_buf(Number<full_dst_offset>{}) =
|
||||
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(idx_1d.value != num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
|
||||
|
||||
move_tensor_coordinate(src_desc,
|
||||
src_coord_,
|
||||
make_tensor_coordinate_step(src_desc, forward_step));
|
||||
}
|
||||
});
|
||||
});
|
||||
MoveSrcSliceWindow(src_desc, make_multi_index(0, -KRepeat, 0));
|
||||
});
|
||||
|
||||
// printf("blockIdx.y: %d, tid: %d, dst_buf<%f>\n",
|
||||
// blockIdx.y,
|
||||
// threadIdx.x,
|
||||
// dst_buf(Number<0>{}));
|
||||
|
||||
// move src coordinate back to slice origin (or not)
|
||||
if constexpr(SrcResetCoordinateAfterRun)
|
||||
{
|
||||
const auto src_reset_step =
|
||||
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(src_scalar_per_access)>>;
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
if constexpr(num_access == 0)
|
||||
{
|
||||
return typename SpaceFillingCurve::Index{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto reset_step =
|
||||
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
|
||||
|
||||
return reset_step;
|
||||
}
|
||||
}
|
||||
|
||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_step_idx)
|
||||
{
|
||||
// if src coord was not reset by Run(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
template <typename SrcMoveSliceWindowStepHack>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_step_idx,
|
||||
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||
{
|
||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(
|
||||
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
private:
|
||||
SrcCoord src_coord_;
|
||||
StaticallyIndexedArray<index_t, scale_gather_num> scale_gather_offsets_;
|
||||
}; // namespace ck
|
||||
|
||||
// Assume:
|
||||
// 1. src_desc and dst_desc are not known at compile-time
|
||||
// 2. SrcBuffer and DstBuffer are DynamicBuffer
|
||||
|
||||
@@ -762,7 +762,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
|
||||
Reference in New Issue
Block a user