tempsave, token = 2 failed, need to debug

This commit is contained in:
mtgu0705
2025-05-23 06:43:25 -05:00
parent 5ea3fe488d
commit b44af02096
7 changed files with 334 additions and 128 deletions

View File

@@ -143,7 +143,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MPerBlock = 32;
static constexpr bool MulRoutedWeight = true;
// clang-format off
@@ -151,14 +151,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 256,
MPerBlock, 128, KPerBlock,
ScaleBlockSize, 64,
MPerBlock, 32, KPerBlock,
16, 16,
16, 16,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
2, 2,
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 8, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
@@ -170,14 +170,14 @@ int main(int argc, char* argv[])
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t sorted_tile_num = 2;
constexpr ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t experts = 2;
ck::index_t tokens = 832;
ck::index_t topk = 2;
@@ -319,8 +319,8 @@ int main(int argc, char* argv[])
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 3:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
@@ -337,12 +337,26 @@ int main(int argc, char* argv[])
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{1});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 6:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 7:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 8:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
@@ -404,25 +418,40 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
#if 0
#if 1
printf("a0_t_k_k:\n");
// for(int t = 0; t < tokens; ++t)
// {
// for(int tk = 0; tk < topk; ++tk)
// {
// for(int k = 0; k < K; ++k)
// {
// auto f4x2 = a0_t_k_k(t, tk, k).data;
// if(k % 2 == 0)
// {
// ck::f4_t f4 = (f4x2 >> 4) & 0xf;
// printf("%.2f ", ck::type_convert<float>(f4));
// }
// else
// {
// ck::f4_t f4 = (f4x2 >> 0) & 0xf;
// printf("%.2f ", ck::type_convert<float>(f4));
// }
// }
// printf("\n");
// }
// printf("\n");
// }
for(int t = 0; t < tokens; ++t)
{
for(int tk = 0; tk < topk; ++tk)
{
for(int k = 0; k < K; ++k)
for(int k = 0; k < K;)
{
auto f4x2 = a0_t_k_k(t, tk, k).data;
if(k % 2 == 0)
{
ck::f4_t f4 = (f4x2 >> 4) & 0xf;
printf("%.2f ", ck::type_convert<float>(f4));
}
else
{
ck::f4_t f4 = (f4x2 >> 0) & 0xf;
printf("%.2f ", ck::type_convert<float>(f4));
}
printf("0x%08x ",
*(reinterpret_cast<const uint32_t*>(&(a0_t_k_k(t, tk, k).data)))); // 4 bytes
k += 8;
}
printf("\n");
}
@@ -464,23 +493,37 @@ int main(int argc, char* argv[])
}
printf("b0_e_n_k:\n");
// for(int e = 0; e < experts; ++e)
// {
// for(int n = 0; n < N; ++n)
// {
// for(int k = 0; k < K; ++k)
// {
// auto f4x2 = b0_e_n_k(e, k, n).data;
// if(k % 2 == 0)
// {
// ck::f4_t f4 = f4x2 >> 4 & 0xf;
// printf("%.2f ", ck::type_convert<float>(f4));
// }
// else
// {
// ck::f4_t f4 = f4x2 >> 0 & 0xf;
// printf("%.2f ", ck::type_convert<float>(f4));
// }
// }
// printf("\n");
// }
// printf("\n");
// }
for(int e = 0; e < experts; ++e)
{
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
for(int k = 0; k < K;)
{
auto f4x2 = b0_e_n_k(e, k, n).data;
if(k % 2 == 0)
{
ck::f4_t f4 = f4x2 >> 4 & 0xf;
printf("%.2f ", ck::type_convert<float>(f4));
}
else
{
ck::f4_t f4 = f4x2 >> 0 & 0xf;
printf("%.2f ", ck::type_convert<float>(f4));
}
printf("0x%08x ",
*(reinterpret_cast<const uint32_t*>(&(b0_e_n_k(e, k, n).data)))); // 4 bytes
k += 8;
}
printf("\n");
}
@@ -509,6 +552,7 @@ int main(int argc, char* argv[])
printf("%.2f ", ck::type_convert<float>(d2_e_n(i, n)));
}
}
printf("\n");
#endif
// do GEMM
@@ -625,7 +669,7 @@ int main(int argc, char* argv[])
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
#if 0
#if 1
printf("e_t_n_device_result:\n");
for(int t = 0; t < tokens; ++t)
{

View File

@@ -472,6 +472,25 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
#if 0
printf("blkx: %u, blky: %u, tid: %u, a_block_bufs(0):<0x%08x, 0x%08x, 0x%08x, 0x%08x, "
"0x%08x, 0x%08x, 0x%08x, 0x%08x, 0x%08x, 0x%08x>\n",
blockIdx.x,
blockIdx.y,
threadIdx.x,
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[0].data))),
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[16].data))),
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[32].data))),
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[48].data))),
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[64].data))),
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[80].data))),
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[96].data))),
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[112].data))),
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[1024 + 0].data))),
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[1024 + 112].data))));
#endif
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
@@ -1080,11 +1099,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
#if 0
#if 1
printf(
"blkIdx: %u, blkIdy: %u, tidx: %u, imxdl: %d, inxdl: "
"%d, ikxdl: %d, a_thread_vec=<%.2f, %.2f, %.2f, %.2f>, "
"b_thread_vec=<%.2f, %.2f, %.2f, %.2f>, a_scale=%08x, "
"%d, ikxdl: %d, a_thread_vec=<%08x, %08x, %08x, %08x>, "
"b_thread_vec=<%08x, %08x, %08x, %08x>, a_scale=%08x, "
"b_scale=%08x, c_thread_buf=<%.2f, %.2f, %.2f, %.2f>\n",
blockIdx.x,
blockIdx.y,
@@ -1092,38 +1111,22 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
imxdl.value,
inxdl.value,
ikxdl.value,
type_convert<float>(
a_thread_vec
.template AsType<ComputeTypeA>()[Number<0>{}]
.unpack(Number<0>{})),
type_convert<float>(
a_thread_vec
.template AsType<ComputeTypeA>()[Number<0>{}]
.unpack(Number<1>{})),
type_convert<float>(
a_thread_vec
.template AsType<ComputeTypeA>()[Number<1>{}]
.unpack(Number<0>{})),
type_convert<float>(
a_thread_vec
.template AsType<ComputeTypeA>()[Number<1>{}]
.unpack(Number<1>{})),
type_convert<float>(
b_thread_vec
.template AsType<ComputeTypeB>()[Number<0>{}]
.unpack(Number<0>{})),
type_convert<float>(
b_thread_vec
.template AsType<ComputeTypeB>()[Number<0>{}]
.unpack(Number<1>{})),
type_convert<float>(
b_thread_vec
.template AsType<ComputeTypeB>()[Number<1>{}]
.unpack(Number<0>{})),
type_convert<float>(
b_thread_vec
.template AsType<ComputeTypeB>()[Number<1>{}]
.unpack(Number<1>{})),
*(reinterpret_cast<const uint32_t*>(&(
a_thread_vec.template AsType<f4x8_t>()[Number<0>{}]))),
*(reinterpret_cast<const uint32_t*>(&(
a_thread_vec.template AsType<f4x8_t>()[Number<1>{}]))),
*(reinterpret_cast<const uint32_t*>(&(
a_thread_vec.template AsType<f4x8_t>()[Number<2>{}]))),
*(reinterpret_cast<const uint32_t*>(&(
a_thread_vec.template AsType<f4x8_t>()[Number<3>{}]))),
*(reinterpret_cast<const uint32_t*>(&(
b_thread_vec.template AsType<f4x8_t>()[Number<0>{}]))),
*(reinterpret_cast<const uint32_t*>(&(
b_thread_vec.template AsType<f4x8_t>()[Number<1>{}]))),
*(reinterpret_cast<const uint32_t*>(&(
b_thread_vec.template AsType<f4x8_t>()[Number<2>{}]))),
*(reinterpret_cast<const uint32_t*>(&(
b_thread_vec.template AsType<f4x8_t>()[Number<3>{}]))),
*(reinterpret_cast<const uint32_t*>(&(
a_scale_thread_vec
.template AsType<AScaleDataType>()[Number<0>{}]))),

View File

@@ -68,15 +68,20 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
static constexpr auto block_slice_lengths = BlockSliceLengths{};
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
static constexpr auto wave_thread_cluster_lengths = Sequence<ThreadClusterLengths{}.At(I0), ThreadClusterLengths{}.At(I1)*64/ThreadGroup::GetNumOfThread(),1>{};
static constexpr auto wave_cluster_lengths = Sequence<1, ThreadGroup::GetNumOfThread()/64, 1>{};
static constexpr auto wave_thread_cluster_lengths =
Sequence<ThreadClusterLengths{}.At(I0),
ThreadClusterLengths{}.At(I1) * 64 / ThreadGroup::GetNumOfThread(),
1>{};
static constexpr auto wave_cluster_lengths =
Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{};
static constexpr auto thread_single_load_size = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
// It makes the whole wavefront load contiguous memory, what is required for direct loads.
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
static constexpr auto wave_single_load_size= wave_thread_cluster_lengths*thread_single_load_size;
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
static constexpr auto wave_single_load_size =
wave_thread_cluster_lengths * thread_single_load_size;
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
static __device__ constexpr bool AreThreadClusterLengthsValid()
@@ -171,17 +176,17 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
const auto thread_cluster_idx =
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
const auto wave_cluster_idx =
wave_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()/64));
const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId() / 64));
const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size;
const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size;
SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
// We don't need threadwise offset for lds since it was calculate by HW
// We still need input the wavewise offset.
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin);
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin);
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -240,6 +245,22 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
dst_buf, src_offset, dst_offset, is_src_valid);
#if 0
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
printf("blkx: %u, blky: %u, tid: %u, src: %d, b_dst_offset: "
"%d, b_dst_buffer=<%02x, %02x, %02x, %02x>\n",
blockIdx.x,
blockIdx.y,
threadIdx.x,
src_offset,
dst_offset,
static_cast<uint8_t>(dst_buf[dst_offset].data),
static_cast<uint8_t>(dst_buf[dst_offset + 16].data),
static_cast<uint8_t>(dst_buf[dst_offset + 32].data),
static_cast<uint8_t>(dst_buf[dst_offset + 48].data));
#endif
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
@@ -292,6 +313,23 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
});
});
#if 0
block_sync_lds();
if(threadIdx.x == 0)
{
// Print the contents of the destination buffer.
printf("blkx: %u, blky: %u, tid: %u, B_dst_buffer=<%02x, %02x, %02x, %02x>\n",
blockIdx.x,
blockIdx.y,
threadIdx.x,
static_cast<uint8_t>(dst_buf[Number<0>{}].data),
static_cast<uint8_t>(dst_buf[Number<16>{}].data),
static_cast<uint8_t>(dst_buf[Number<32>{}].data),
static_cast<uint8_t>(dst_buf[Number<48>{}].data));
}
#endif
// Reset the destination slice since the entire buffer has been already filled.
ResetDstSliceWindow(dst_desc);
}

View File

@@ -66,15 +66,27 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto block_slice_lengths = BlockSliceLengths{};
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
static constexpr auto wave_thread_cluster_lengths =
Sequence<ThreadClusterLengths{}.At(I0),
ThreadClusterLengths{}.At(I1) * 64 / ThreadGroup::GetNumOfThread(),
1>{};
static constexpr auto wave_cluster_lengths =
Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{};
static constexpr auto thread_single_load_size = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
// CK_PRINT<decltype(thread_single_load_size)>();
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
// It makes the whole wavefront load contiguous memory, what is required for direct loads.
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
static constexpr auto wave_single_load_size =
wave_thread_cluster_lengths * thread_single_load_size;
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});
@@ -172,10 +184,16 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
const auto thread_cluster_idx =
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId() / 64));
const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size;
SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
// We don't need threadwise offset for lds since it was calculate by HW
// We still need input the wavewise offset.
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin);
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -188,6 +206,9 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
return idx;
}();
// CK_PRINT<decltype(adjusted_src_origin_idx)>();
// CK_PRINT<decltype(src_slice_origin_idx)>();
src_coord_ = make_tensor_coordinate(src_desc, adjusted_src_origin_idx);
src_slice_origin_ = adjusted_src_origin_idx;
}
@@ -230,20 +251,45 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
// Loop over the destination block and copy data.
static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// CK_PRINT<decltype(dst_access_lengths), decltype(ordered_dst_access_idx)>();
auto gather_offset = gather_offsets_(Number<GatherDim>{});
const auto src_offset = src_coord_.GetOffset() + gather_offset;
const auto dst_offset = dst_coord_.GetOffset();
// printf("Tid: %03d, src_offset: %d, dst_offset: %d\n", get_thread_local_1d_id(),
// src_coord_.GetOffset(), dst_coord_.GetOffset());
IndexType gather_offset = gather_offsets_[ordered_dst_access_idx[Number<GatherDim>{}]];
const IndexType src_offset = src_coord_.GetOffset() + gather_offset;
const IndexType dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset());
// Check if src data is not in the logic padding area.
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
// Leave the HW for oob checking
// const bool is_src_valid =
// coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
// src_coord_);
src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
dst_buf, src_offset, dst_offset, is_src_valid);
dst_buf, src_offset, dst_offset, true);
constexpr auto move_on_dim = [&]() constexpr
#if 1
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
printf("blkx: %u, blky: %u, tid: %u, red_id: %d src: %d (cal: %d, gather: %d), "
"dst_offset: "
"%d, a_dst_buffer=<0x%08x, 0x%08x, 0x%08x, 0x%08x>\n",
blockIdx.x,
blockIdx.y,
threadIdx.x,
static_cast<int>(ordered_dst_access_idx[Number<GatherDim>{}]),
src_offset,
src_coord_.GetOffset(),
gather_offset,
dst_offset,
// *(reinterpret_cast<const uint32_t*>(&(dst_buf[dst_offset + 0].data))),
*(reinterpret_cast<const uint32_t*>(
&(dst_buf[dst_offset + 16 * threadIdx.x].data))),
*(reinterpret_cast<const uint32_t*>(
&(dst_buf[dst_offset + 16 * threadIdx.x].data))),
*(reinterpret_cast<const uint32_t*>(
&(dst_buf[dst_offset + 32 * threadIdx.x].data))),
*(reinterpret_cast<const uint32_t*>(
&(dst_buf[dst_offset + 48 * threadIdx.x].data))));
#endif
constexpr auto move_src_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
@@ -260,6 +306,22 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
}
();
constexpr auto move_dst_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// Decide whether to move forward or backward.
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
@@ -280,22 +342,58 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
}();
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
// Move the source coordinate.
if constexpr(move_src_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]);
}
else
{
move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]);
}
}
// Move the destination coordinate.
if constexpr(move_dst_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
}
else
{
move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
}
}
});
});
#if 0
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
if(threadIdx.x == 0)
{
// Print the contents of the destination buffer.
printf("blkx: %u, blky: %u, tid: %u, a_dst_buf_offset=<%d, %d, %d, %d>, "
"a_dst_buffer=<%02x, %02x, %02x, %02x>\n",
blockIdx.x,
blockIdx.y,
threadIdx.x,
0,
16,
32,
48,
static_cast<uint8_t>(dst_buf[Number<0>{}].data),
static_cast<uint8_t>(dst_buf[Number<16>{}].data),
static_cast<uint8_t>(dst_buf[Number<32>{}].data),
static_cast<uint8_t>(dst_buf[Number<48>{}].data));
}
#endif
// Reset the destination slice since the entire buffer has been already filled.
ResetDstSliceWindow(dst_desc);
}
@@ -325,6 +423,8 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
static constexpr auto wave_cluster_desc_ =
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
SrcCoord src_coord_;
DstCoord dst_coord_;

View File

@@ -1387,6 +1387,7 @@ struct GridwiseMoeGemmMXBNS
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
ignore = a_element_op;
ignore = b_element_op;
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
@@ -1657,35 +1658,22 @@ struct GridwiseMoeGemmMXBNS
p_b_grid_up + expert_id * expert_stride,
b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_blockwise_copy_up =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
ThisThreadBlock,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -2167,6 +2155,7 @@ struct GridwiseMoeGemmMXBNS
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
ignore = a_element_op;
ignore = b_element_op;
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
@@ -2253,7 +2242,7 @@ struct GridwiseMoeGemmMXBNS
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
#if 0
#if 1
printf("blkx: %u, blky: %u, tidx: %u, token_pos: %d, gather_offsets:<%d, %d, %d, %d>\n",
blockIdx.x,
blockIdx.y,

View File

@@ -1504,6 +1504,23 @@ struct ThreadwiseTensorSliceTransfer_v4
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
#if 1
printf("blky: %u, tid: %u, src_offset: %d, repeat_id: %d, dst_tmp_vec=<0x%08x, "
"0x%08x, 0x%08x, "
"0x%08x\n",
blockIdx.y,
threadIdx.x,
static_cast<int>(ordered_access_idx[Number<1>{}]),
src_data_coord.GetOffset(),
*(reinterpret_cast<const uint32_t*>(
&(dst_tmp_vector.template AsType<f4x8_t>()[Number<0>{}]))),
*(reinterpret_cast<const uint32_t*>(
&(dst_tmp_vector.template AsType<f4x8_t>()[Number<1>{}]))),
*(reinterpret_cast<const uint32_t*>(
&(dst_tmp_vector.template AsType<f4x8_t>()[Number<2>{}]))),
*(reinterpret_cast<const uint32_t*>(
&(dst_tmp_vector.template AsType<f4x8_t>()[Number<3>{}]))));
#endif
}
});
}

View File

@@ -202,6 +202,21 @@ struct DynamicBuffer
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"Destination data must be stored in an LDS memory buffer.");
#if 0
// if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0)
// {
// printf("DirectCopyToLds: src_offset=%d, dst_offset=%d\n", src_offset, dst_offset);
// }
printf("blkx: %u, blky: %u, tid: %u, src_offset: %d, dst_offset: %d, sizeof(src_offset): "
"%lu, sizeof(dst_offset): %lu\n",
blockIdx.x,
blockIdx.y,
threadIdx.x,
src_offset,
dst_offset,
sizeof(src_offset),
sizeof(dst_offset));
#endif
amd_direct_load_global_to_lds<T, NumElemsPerThread>(p_data_,
src_offset,
dst_buf.p_data_,