16x16x128 input size blockscale function passed

This commit is contained in:
mtgu0705
2025-05-14 03:20:59 -05:00
parent 7be8730247
commit 2700b217be
5 changed files with 357 additions and 139 deletions

View File

@@ -315,18 +315,32 @@ int main(int argc, char* argv[])
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break;
case 2:
// a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1.0, 1.0});
// b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1.0, 1.0});
ck::utils::FillConstant<A0DataType>{ck::type_convert<A0DataType>(ck::float2_t(1.0f))}(
a0_t_k_k);
ck::utils::FillConstant<B0DataType>{ck::type_convert<B0DataType>(ck::float2_t(1.0f))}(
b0_e_n_k);
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 3:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break;
case 4:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break;
default:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
@@ -372,13 +386,16 @@ int main(int argc, char* argv[])
{
for(int k = 0; k < K; ++k)
{
auto f4x2 = a0_t_k_k(t, tk, k).data;
if(k % 2 == 0)
{
printf("%f ", ck::type_convert<float>(a0_t_k_k(t, tk, k).data >> 4 & 0xf));
ck::f4_t f4 = (f4x2 >> 4) & 0xf;
printf("%f ", ck::type_convert<float>(f4));
}
else
{
printf("%f ", ck::type_convert<float>(a0_t_k_k(t, tk, k).data & 0xf));
ck::f4_t f4 = (f4x2 >> 0) & 0xf;
printf("%f ", ck::type_convert<float>(f4));
}
}
printf("\n");
@@ -407,13 +424,16 @@ int main(int argc, char* argv[])
{
for(int k = 0; k < K; ++k)
{
auto f4x2 = b0_e_n_k(e, k, n).data;
if(k % 2 == 0)
{
printf("%f ", ck::type_convert<float>(b0_e_n_k(e, k, n).data >> 4 & 0xf));
ck::f4_t f4 = f4x2 >> 4 & 0xf;
printf("%f ", ck::type_convert<float>(f4));
}
else
{
printf("%f ", ck::type_convert<float>(b0_e_n_k(e, k, n).data & 0xf));
ck::f4_t f4 = f4x2 >> 0 & 0xf;
printf("%f ", ck::type_convert<float>(f4));
}
}
printf("\n");

View File

@@ -279,34 +279,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Prefetch a_scales to buf 0
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
// auto a_scale_thread_buf_copy =
// make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
// a_scale_thread_desc_copy.GetElementSpaceSize());
// a_scale_thread_copy.Run(a_scale_grid_desc,
// a_scale_grid_buf,
// a_scale_thread_desc_copy,
// make_tuple(I0, I0),
// a_scale_thread_buf_copy);
a_scale_thread_bufs(I0)(Number<a_scale_offset>{}) =
type_convert<AScaleDataType>(1.0f);
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
});
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(I0, I0, I0),
a_scale_thread_bufs(I0));
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
make_multi_index(0, ScalesPerKBlockSize, 0));
// Prefetch b_scales to buf 0
static_for<0, NRepeat, 1>{}([&](auto n0) {
@@ -314,17 +295,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
// auto b_scale_thread_buf_copy =
// make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
// b_scale_thread_desc_copy.GetElementSpaceSize());
// b_scale_thread_copy.Run(b_scale_grid_desc,
// b_scale_grid_buf,
// b_scale_thread_desc_copy,
// make_tuple(I0, I0),
// b_scale_thread_buf_copy);
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs(I0)(Number<b_scale_offset>{}) =
type_convert<BScaleDataType>(1.0f);
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
@@ -337,7 +318,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
// restore col id and advance to the next set of scales
// NWaves * NPerXDL * NRepeat == NPerBlock
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
make_multi_index(0, ScalesPerKBlockSize));
__builtin_amdgcn_sched_barrier(0);
@@ -349,34 +330,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// Prefetch a_scales to buf 1
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
// auto a_scale_thread_buf_copy =
// make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
// a_scale_thread_desc_copy.GetElementSpaceSize());
// a_scale_thread_copy.Run(a_scale_grid_desc,
// a_scale_grid_buf,
// a_scale_thread_desc_copy,
// make_tuple(I0, I0),
// a_scale_thread_buf_copy);
a_scale_thread_bufs(I1)(Number<a_scale_offset>{}) =
type_convert<AScaleDataType>(1.0f);
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
});
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(I0, I0, I0),
a_scale_thread_bufs(I1));
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
make_multi_index(0, ScalesPerKBlockSize, 0));
// Prefetch b_scales to buf 1
static_for<0, NRepeat, 1>{}([&](auto n0) {
@@ -384,17 +346,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
// auto b_scale_thread_buf_copy =
// make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
// b_scale_thread_desc_copy.GetElementSpaceSize());
// b_scale_thread_copy.Run(b_scale_grid_desc,
// b_scale_grid_buf,
// b_scale_thread_desc_copy,
// make_tuple(I0, I0),
// b_scale_thread_buf_copy);
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs(I1)(Number<b_scale_offset>{}) =
type_convert<BScaleDataType>(1.0f);
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
@@ -538,35 +500,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
});
// Prefetch a_scales
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
// auto a_scale_thread_buf_copy =
// make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
// a_scale_thread_desc_copy.GetElementSpaceSize());
// a_scale_thread_copy.Run(a_scale_grid_desc,
// a_scale_grid_buf,
// a_scale_thread_desc_copy,
// make_tuple(I0, I0),
// a_scale_thread_buf_copy);
a_scale_thread_bufs(mfma_reg_buf)(Number<a_scale_offset>{}) =
type_convert<AScaleDataType>(1.0f);
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
});
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(I0, I0, I0),
a_scale_thread_bufs(mfma_reg_buf));
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize));
a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0));
// Prefetch b_scales
static_for<0, NRepeat, 1>{}([&](auto n0) {
@@ -574,17 +516,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
// auto b_scale_thread_buf_copy =
// make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
// b_scale_thread_desc_copy.GetElementSpaceSize());
// b_scale_thread_copy.Run(b_scale_grid_desc,
// b_scale_grid_buf,
// b_scale_thread_desc_copy,
// make_tuple(I0, I0),
// b_scale_thread_buf_copy);
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs(mfma_reg_buf)(Number<b_scale_offset>{}) =
type_convert<BScaleDataType>(1.0f);
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));

View File

@@ -211,6 +211,7 @@ struct GridwiseMoeGemmMX
static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
static constexpr index_t MWave = MPerBlock / MPerXdl / MXdlPerWave;
// static constexpr index_t NumTokens = 1;
static constexpr index_t SortedTileSize = MPerBlock;
@@ -512,9 +513,7 @@ struct GridwiseMoeGemmMX
__host__ __device__ static constexpr auto
MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
return MakeGemmMmaTileDescriptor<MXdlPerWave, MWave, MPerXdl>(ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
@@ -942,8 +941,6 @@ struct GridwiseMoeGemmMX
__device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
@@ -1249,8 +1246,9 @@ struct GridwiseMoeGemmMX
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
make_tuple(IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
math::integer_divide_ceil(problem.K, ScaleBlockSize)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1));
math::integer_divide_ceil(problem.K, ScaleBlockSize),
1),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1, 1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(problem.K, math::integer_divide_ceil(problem.K, ScaleBlockSize)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1));
@@ -1431,20 +1429,40 @@ struct GridwiseMoeGemmMX
auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl;
auto a_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<AScaleDataType,
AScaleDataType,
decltype(a_scale_grid_desc_am_ak),
decltype(BlockwiseGemmPipe::a_scale_thread_desc_copy),
Sequence<1, 1>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
1, // SrcVectorDim
1, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true>(
a_scale_grid_desc_am_ak,
make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, thread_offset_k));
// get each thread's offset in the scale tensor
const index_t token_scale_pos = block_m_id * MPerBlock;
if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
const index_t fused_token =
p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWave + a_thread_offset_m];
index_t token_offset = fused_token & 0xffffff;
if constexpr(!IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
scale_gather_offsets(m0) =
token_offset * math::integer_divide_ceil(problem.K, ScaleBlockSize);
});
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2_gather<
AScaleDataType,
AScaleDataType,
decltype(a_scale_grid_desc_am_ak),
decltype(BlockwiseGemmPipe::a_scale_thread_desc),
Sequence<1, 1, 1>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
1, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true,
MXdlPerWave,
KRepeat>(
a_scale_grid_desc_am_ak, make_multi_index(0, 0, thread_offset_k), scale_gather_offsets);
// B scale load
auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl;
auto b_scale_thread_copy =
@@ -1537,8 +1555,6 @@ struct GridwiseMoeGemmMX
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
@@ -2255,8 +2271,6 @@ struct GridwiseMoeGemmMX
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

View File

@@ -424,6 +424,248 @@ struct ThreadwiseTensorSliceTransfer_v2
SrcCoord src_coord_;
}; // namespace ck
// Threadwise slice transfer that reads a run-time (dynamic) source tensor through a set of
// per-row gather offsets and writes into a compile-time-known destination buffer.
//
// Run() iterates over `scale_gather_num` gather rows and `KRepeat` K-steps. For every access the
// source element is fetched at
//     src_coord_.GetOffset() / PackedSize + scale_gather_offsets_(gather_idx)
// and stored at the destination index (gather_idx, k0, 0) + slice-local index, relative to the
// destination slice origin.
//
// Template parameters:
//   SrcData / DstData            element types of the source / destination buffers
//   SrcDesc / DstDesc            tensor descriptors; DstDesc must be known at compile time
//   SliceLengths                 lengths of the slice copied per (gather_idx, k0) step
//   DimAccessOrder               dimension order used by the space-filling curve
//   SrcVectorDim                 source dimension along which vector loads are issued
//   SrcScalarPerVector           number of scalars per source vector load
//   SrcScalarStrideInVector      not referenced inside this struct
//   SrcResetCoordinateAfterRun   if true, Run() steps the source coordinate back by
//                                GetSrcCoordinateResetStep() before returning
//   scale_gather_num             number of gather offsets (outer loop trip count in Run())
//   KRepeat                      inner loop trip count in Run()
//   InvalidElementAsNaN          if true, invalid source elements are written as quiet NaN
//
// NOTE(review): Run() builds 3-component multi-indices, so it appears to assume nDim == 3 —
// confirm before instantiating with a different rank.
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t SrcVectorDim,
          index_t SrcScalarPerVector,
          index_t SrcScalarStrideInVector,
          bool SrcResetCoordinateAfterRun,
          index_t scale_gather_num,
          index_t KRepeat,
          bool InvalidElementAsNaN                                                = false,
          typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v2_gather
{
    static_assert((InvalidElementAsNaN && !ck::is_integral<DstData>::value) ||
                      (!InvalidElementAsNaN),
                  "Filling invalid element as NaN is only for floating point types");

    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));

    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));

    // Number of logical elements stored in one SrcData object (2 for the packed 4-bit types).
    static constexpr index_t PackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
                     is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
            return 2;
        else
            return 1;
    }();

    // Builds the source coordinate at `src_slice_origin_idx` and stores a copy of the gather
    // offsets that Run() adds to the source element offset.
    __device__ constexpr ThreadwiseTensorSliceTransfer_v2_gather(
        const SrcDesc& src_desc,
        const Index& src_slice_origin_idx,
        const StaticallyIndexedArray<index_t, scale_gather_num>& scale_gather_offsets)
        : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)),
          scale_gather_offsets_(scale_gather_offsets)
    {
        static_assert(DstDesc::IsKnownAtCompileTime(),
                      "wrong! DstDesc need to known at compile-time");

        static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
                      "wrong! Not divisible");

        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
                     is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
        {
            static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
        }
    }

    // Re-seats the source coordinate. Component 0 of the requested origin is forced to 0; the
    // remaining components are taken from `src_slice_origin_idx` unchanged.
    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {
        auto adjusted_origin_idx = [&]() {
            Index idx;

            static_for<0, nDim, 1>{}(
                [&](auto i) { idx(i) = i.value == 0 ? 0 : src_slice_origin_idx[Number<i>{}]; });

            return idx;
        }();

        src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx);
    }

    // Copies scale_gather_num x KRepeat slices from `src_buf` into `dst_buf`.
    // The destination descriptor and destination slice origin must be compile-time constants.
    template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstSliceOriginIdx&,
                        DstBuffer& dst_buf)
    {
        static_assert(DstDesc::IsKnownAtCompileTime(),
                      "wrong! DstDesc need to known at compile-time");

        static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
                      "wrong! DstSliceOrigin need to known at compile-time");

        // The message is a separate static_assert argument so that it shows up in diagnostics.
        static_assert(
            is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
            "wrong! inconsistent type");

        // DstDesc and dst_slice_origin_idx are known at compile-time
        constexpr auto dst_desc             = remove_cvref_t<DstDesc>{};
        constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_scalar_step_in_vector =
            generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});

        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(src_scalar_per_access)>>;

        // loop over tensor and copy
        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

        static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) { // MRepeat
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                // destination origin of the slice written for this (gather_idx, k0) pair
                constexpr auto current_dst_origin =
                    to_multi_index(dst_slice_origin_idx) + make_multi_index(gather_idx, k0, 0);

                // NOTE(review): the requested step is zero; MoveSrcSliceWindow only changes the
                // coordinate here when SrcResetCoordinateAfterRun is false (it then adds the
                // reset step). Confirm whether a non-zero K step was intended.
                MoveSrcSliceWindow(src_desc, make_multi_index(0, 0, 0));

                static_for<0, num_access, 1>{}([&](auto idx_1d) {
                    typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type
                        src_vector;

                    using src_vector_t =
                        typename vector_type_maker<SrcData,
                                                   SrcScalarPerVector / PackedSize>::type::type;

                    constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);

                    const bool is_src_valid =
                        coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
                                                                                    src_coord_);

                    // copy data from src_buf into src_vector; the gather offset selects the row
                    src_vector.template AsType<src_vector_t>()(Number<0>{}) =
                        src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize +
                                                               scale_gather_offsets_(gather_idx),
                                                           is_src_valid);

                    // copy data from src_vector into dst_buf
                    static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
                        // Single offset computation: current_dst_origin already contains
                        // dst_slice_origin_idx, so the origin is applied exactly once.
                        constexpr index_t full_dst_offset = dst_desc.CalculateOffset(
                            current_dst_origin + src_data_idx + i * src_scalar_step_in_vector);

                        if constexpr(InvalidElementAsNaN)
                        {
                            dst_buf(Number<full_dst_offset>{}) =
                                is_src_valid ? type_convert<DstData>(
                                                   src_vector.template AsType<SrcData>()[i])
                                             : NumericLimits<DstData>::QuietNaN();
                        }
                        else
                        {
                            dst_buf(Number<full_dst_offset>{}) =
                                type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
                        }
                    });

                    // advance along the space-filling curve, except after the last access
                    if constexpr(idx_1d.value != num_access - 1)
                    {
                        constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);

                        move_tensor_coordinate(src_desc,
                                               src_coord_,
                                               make_tensor_coordinate_step(src_desc, forward_step));
                    }
                });
            });

            // NOTE(review): steps dim 1 back by KRepeat after each gather row, although the
            // per-k0 move above requests a zero step — verify the intended net displacement.
            MoveSrcSliceWindow(src_desc, make_multi_index(0, -KRepeat, 0));
        });

        // move src coordinate back to slice origin (or not)
        if constexpr(SrcResetCoordinateAfterRun)
        {
            const auto src_reset_step =
                make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());

            move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
        }
    }

    // Returns the step from the last access of the space-filling curve back to its first access
    // (a default-constructed index when the curve has no accesses).
    __device__ static constexpr auto GetSrcCoordinateResetStep()
    {
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(src_scalar_per_access)>>;

        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
        if constexpr(num_access == 0)
        {
            return typename SpaceFillingCurve::Index{};
        }
        else
        {
            constexpr auto reset_step =
                SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});

            return reset_step;
        }
    }

    // Moves the source coordinate by `src_slice_origin_step_idx`.
    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
                                       const Index& src_slice_origin_step_idx)
    {
        // if src coord was not reset by Run(), then need to adjust the step here
        const auto adjusted_step_idx =
            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);

        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // Same as above, forwarding an extra step "hack" argument to make_tensor_coordinate_step.
    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
    template <typename SrcMoveSliceWindowStepHack>
    __device__ void
    MoveSrcSliceWindow(const SrcDesc& src_desc,
                       const Index& src_slice_origin_step_idx,
                       const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
    {
        // if src coord was not reset by Run(), then need to adjust the step here
        const auto adjusted_step_idx =
            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(
            src_desc, adjusted_step_idx, src_move_slice_window_step_hack);

        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    private:
    // current position in the source tensor
    SrcCoord src_coord_;
    // per-gather-row element offsets added to the source offset in Run()
    StaticallyIndexedArray<index_t, scale_gather_num> scale_gather_offsets_;
}; // struct ThreadwiseTensorSliceTransfer_v2_gather
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer

View File

@@ -762,7 +762,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
using arg_type = int32x8_t;
#if 1
#if 0
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},