mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
tempsave, token = 2 failed, need to debug
This commit is contained in:
@@ -143,7 +143,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
@@ -151,14 +151,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, 256,
|
||||
MPerBlock, 128, KPerBlock,
|
||||
ScaleBlockSize, 64,
|
||||
MPerBlock, 32, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
2, 2,
|
||||
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 2, S<1, 8, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
@@ -170,14 +170,14 @@ int main(int argc, char* argv[])
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
constexpr ck::index_t sorted_tile_num = 13;
|
||||
constexpr ck::index_t sorted_tile_num = 2;
|
||||
constexpr ck::index_t valid_tile_num = sorted_tile_num;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t experts = 2;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
@@ -319,8 +319,8 @@ int main(int argc, char* argv[])
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
@@ -337,12 +337,26 @@ int main(int argc, char* argv[])
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{1});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 7:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 8:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
@@ -404,25 +418,40 @@ int main(int argc, char* argv[])
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
printf("a0_t_k_k:\n");
|
||||
// for(int t = 0; t < tokens; ++t)
|
||||
// {
|
||||
// for(int tk = 0; tk < topk; ++tk)
|
||||
// {
|
||||
// for(int k = 0; k < K; ++k)
|
||||
// {
|
||||
// auto f4x2 = a0_t_k_k(t, tk, k).data;
|
||||
// if(k % 2 == 0)
|
||||
// {
|
||||
// ck::f4_t f4 = (f4x2 >> 4) & 0xf;
|
||||
// printf("%.2f ", ck::type_convert<float>(f4));
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// ck::f4_t f4 = (f4x2 >> 0) & 0xf;
|
||||
// printf("%.2f ", ck::type_convert<float>(f4));
|
||||
// }
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int tk = 0; tk < topk; ++tk)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
for(int k = 0; k < K;)
|
||||
{
|
||||
auto f4x2 = a0_t_k_k(t, tk, k).data;
|
||||
if(k % 2 == 0)
|
||||
{
|
||||
ck::f4_t f4 = (f4x2 >> 4) & 0xf;
|
||||
printf("%.2f ", ck::type_convert<float>(f4));
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::f4_t f4 = (f4x2 >> 0) & 0xf;
|
||||
printf("%.2f ", ck::type_convert<float>(f4));
|
||||
}
|
||||
printf("0x%08x ",
|
||||
*(reinterpret_cast<const uint32_t*>(&(a0_t_k_k(t, tk, k).data)))); // 4 bytes
|
||||
k += 8;
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
@@ -464,23 +493,37 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
printf("b0_e_n_k:\n");
|
||||
// for(int e = 0; e < experts; ++e)
|
||||
// {
|
||||
// for(int n = 0; n < N; ++n)
|
||||
// {
|
||||
// for(int k = 0; k < K; ++k)
|
||||
// {
|
||||
// auto f4x2 = b0_e_n_k(e, k, n).data;
|
||||
// if(k % 2 == 0)
|
||||
// {
|
||||
// ck::f4_t f4 = f4x2 >> 4 & 0xf;
|
||||
// printf("%.2f ", ck::type_convert<float>(f4));
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// ck::f4_t f4 = f4x2 >> 0 & 0xf;
|
||||
// printf("%.2f ", ck::type_convert<float>(f4));
|
||||
// }
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
for(int e = 0; e < experts; ++e)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
for(int k = 0; k < K;)
|
||||
{
|
||||
auto f4x2 = b0_e_n_k(e, k, n).data;
|
||||
if(k % 2 == 0)
|
||||
{
|
||||
ck::f4_t f4 = f4x2 >> 4 & 0xf;
|
||||
printf("%.2f ", ck::type_convert<float>(f4));
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::f4_t f4 = f4x2 >> 0 & 0xf;
|
||||
printf("%.2f ", ck::type_convert<float>(f4));
|
||||
}
|
||||
printf("0x%08x ",
|
||||
*(reinterpret_cast<const uint32_t*>(&(b0_e_n_k(e, k, n).data)))); // 4 bytes
|
||||
k += 8;
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
@@ -509,6 +552,7 @@ int main(int argc, char* argv[])
|
||||
printf("%.2f ", ck::type_convert<float>(d2_e_n(i, n)));
|
||||
}
|
||||
}
|
||||
printf("\n");
|
||||
#endif
|
||||
|
||||
// do GEMM
|
||||
@@ -625,7 +669,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
printf("e_t_n_device_result:\n");
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
|
||||
@@ -472,6 +472,25 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
|
||||
#if 0
|
||||
printf("blkx: %u, blky: %u, tid: %u, a_block_bufs(0):<0x%08x, 0x%08x, 0x%08x, 0x%08x, "
|
||||
"0x%08x, 0x%08x, 0x%08x, 0x%08x, 0x%08x, 0x%08x>\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[0].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[16].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[32].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[48].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[64].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[80].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[96].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[112].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[1024 + 0].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(a_block_bufs(I0)[1024 + 112].data))));
|
||||
|
||||
#endif
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
|
||||
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
|
||||
@@ -1080,11 +1099,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
b_scale_thread_vec
|
||||
.template AsType<mfma_scale_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
#if 0
|
||||
#if 1
|
||||
printf(
|
||||
"blkIdx: %u, blkIdy: %u, tidx: %u, imxdl: %d, inxdl: "
|
||||
"%d, ikxdl: %d, a_thread_vec=<%.2f, %.2f, %.2f, %.2f>, "
|
||||
"b_thread_vec=<%.2f, %.2f, %.2f, %.2f>, a_scale=%08x, "
|
||||
"%d, ikxdl: %d, a_thread_vec=<%08x, %08x, %08x, %08x>, "
|
||||
"b_thread_vec=<%08x, %08x, %08x, %08x>, a_scale=%08x, "
|
||||
"b_scale=%08x, c_thread_buf=<%.2f, %.2f, %.2f, %.2f>\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
@@ -1092,38 +1111,22 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
imxdl.value,
|
||||
inxdl.value,
|
||||
ikxdl.value,
|
||||
type_convert<float>(
|
||||
a_thread_vec
|
||||
.template AsType<ComputeTypeA>()[Number<0>{}]
|
||||
.unpack(Number<0>{})),
|
||||
type_convert<float>(
|
||||
a_thread_vec
|
||||
.template AsType<ComputeTypeA>()[Number<0>{}]
|
||||
.unpack(Number<1>{})),
|
||||
type_convert<float>(
|
||||
a_thread_vec
|
||||
.template AsType<ComputeTypeA>()[Number<1>{}]
|
||||
.unpack(Number<0>{})),
|
||||
type_convert<float>(
|
||||
a_thread_vec
|
||||
.template AsType<ComputeTypeA>()[Number<1>{}]
|
||||
.unpack(Number<1>{})),
|
||||
type_convert<float>(
|
||||
b_thread_vec
|
||||
.template AsType<ComputeTypeB>()[Number<0>{}]
|
||||
.unpack(Number<0>{})),
|
||||
type_convert<float>(
|
||||
b_thread_vec
|
||||
.template AsType<ComputeTypeB>()[Number<0>{}]
|
||||
.unpack(Number<1>{})),
|
||||
type_convert<float>(
|
||||
b_thread_vec
|
||||
.template AsType<ComputeTypeB>()[Number<1>{}]
|
||||
.unpack(Number<0>{})),
|
||||
type_convert<float>(
|
||||
b_thread_vec
|
||||
.template AsType<ComputeTypeB>()[Number<1>{}]
|
||||
.unpack(Number<1>{})),
|
||||
*(reinterpret_cast<const uint32_t*>(&(
|
||||
a_thread_vec.template AsType<f4x8_t>()[Number<0>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(
|
||||
a_thread_vec.template AsType<f4x8_t>()[Number<1>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(
|
||||
a_thread_vec.template AsType<f4x8_t>()[Number<2>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(
|
||||
a_thread_vec.template AsType<f4x8_t>()[Number<3>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(
|
||||
b_thread_vec.template AsType<f4x8_t>()[Number<0>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(
|
||||
b_thread_vec.template AsType<f4x8_t>()[Number<1>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(
|
||||
b_thread_vec.template AsType<f4x8_t>()[Number<2>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(
|
||||
b_thread_vec.template AsType<f4x8_t>()[Number<3>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(
|
||||
a_scale_thread_vec
|
||||
.template AsType<AScaleDataType>()[Number<0>{}]))),
|
||||
|
||||
@@ -68,15 +68,20 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
|
||||
static constexpr auto block_slice_lengths = BlockSliceLengths{};
|
||||
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
|
||||
static constexpr auto wave_thread_cluster_lengths = Sequence<ThreadClusterLengths{}.At(I0), ThreadClusterLengths{}.At(I1)*64/ThreadGroup::GetNumOfThread(),1>{};
|
||||
static constexpr auto wave_cluster_lengths = Sequence<1, ThreadGroup::GetNumOfThread()/64, 1>{};
|
||||
static constexpr auto wave_thread_cluster_lengths =
|
||||
Sequence<ThreadClusterLengths{}.At(I0),
|
||||
ThreadClusterLengths{}.At(I1) * 64 / ThreadGroup::GetNumOfThread(),
|
||||
1>{};
|
||||
static constexpr auto wave_cluster_lengths =
|
||||
Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{};
|
||||
|
||||
static constexpr auto thread_single_load_size = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
|
||||
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
|
||||
// It makes the whole wavefront load contiguous memory, what is required for direct loads.
|
||||
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto wave_single_load_size= wave_thread_cluster_lengths*thread_single_load_size;
|
||||
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto wave_single_load_size =
|
||||
wave_thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
|
||||
|
||||
static __device__ constexpr bool AreThreadClusterLengthsValid()
|
||||
@@ -171,17 +176,17 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto wave_cluster_idx =
|
||||
wave_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()/64));
|
||||
|
||||
const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId() / 64));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
|
||||
const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size;
|
||||
const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size;
|
||||
|
||||
SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
|
||||
// We don't need threadwise offset for lds since it was calculate by HW
|
||||
// We still need input the wavewise offset.
|
||||
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin);
|
||||
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin);
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
@@ -240,6 +245,22 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
|
||||
dst_buf, src_offset, dst_offset, is_src_valid);
|
||||
|
||||
#if 0
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
printf("blkx: %u, blky: %u, tid: %u, src: %d, b_dst_offset: "
|
||||
"%d, b_dst_buffer=<%02x, %02x, %02x, %02x>\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
src_offset,
|
||||
dst_offset,
|
||||
static_cast<uint8_t>(dst_buf[dst_offset].data),
|
||||
static_cast<uint8_t>(dst_buf[dst_offset + 16].data),
|
||||
static_cast<uint8_t>(dst_buf[dst_offset + 32].data),
|
||||
static_cast<uint8_t>(dst_buf[dst_offset + 48].data));
|
||||
#endif
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
@@ -292,6 +313,23 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
});
|
||||
});
|
||||
|
||||
#if 0
|
||||
block_sync_lds();
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
// Print the contents of the destination buffer.
|
||||
printf("blkx: %u, blky: %u, tid: %u, B_dst_buffer=<%02x, %02x, %02x, %02x>\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
static_cast<uint8_t>(dst_buf[Number<0>{}].data),
|
||||
static_cast<uint8_t>(dst_buf[Number<16>{}].data),
|
||||
static_cast<uint8_t>(dst_buf[Number<32>{}].data),
|
||||
static_cast<uint8_t>(dst_buf[Number<48>{}].data));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// Reset the destination slice since the entire buffer has been already filled.
|
||||
ResetDstSliceWindow(dst_desc);
|
||||
}
|
||||
|
||||
@@ -66,15 +66,27 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr auto block_slice_lengths = BlockSliceLengths{};
|
||||
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
|
||||
static constexpr auto wave_thread_cluster_lengths =
|
||||
Sequence<ThreadClusterLengths{}.At(I0),
|
||||
ThreadClusterLengths{}.At(I1) * 64 / ThreadGroup::GetNumOfThread(),
|
||||
1>{};
|
||||
static constexpr auto wave_cluster_lengths =
|
||||
Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{};
|
||||
|
||||
static constexpr auto thread_single_load_size = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
// CK_PRINT<decltype(thread_single_load_size)>();
|
||||
|
||||
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
|
||||
// It makes the whole wavefront load contiguous memory, what is required for direct loads.
|
||||
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto wave_single_load_size =
|
||||
wave_thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
|
||||
static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});
|
||||
|
||||
@@ -172,10 +184,16 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId() / 64));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
|
||||
const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size;
|
||||
|
||||
SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
|
||||
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
|
||||
// We don't need threadwise offset for lds since it was calculate by HW
|
||||
// We still need input the wavewise offset.
|
||||
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin);
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
@@ -188,6 +206,9 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
return idx;
|
||||
}();
|
||||
|
||||
// CK_PRINT<decltype(adjusted_src_origin_idx)>();
|
||||
// CK_PRINT<decltype(src_slice_origin_idx)>();
|
||||
|
||||
src_coord_ = make_tensor_coordinate(src_desc, adjusted_src_origin_idx);
|
||||
src_slice_origin_ = adjusted_src_origin_idx;
|
||||
}
|
||||
@@ -230,20 +251,45 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
|
||||
// Loop over the destination block and copy data.
|
||||
static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
|
||||
// CK_PRINT<decltype(dst_access_lengths), decltype(ordered_dst_access_idx)>();
|
||||
auto gather_offset = gather_offsets_(Number<GatherDim>{});
|
||||
const auto src_offset = src_coord_.GetOffset() + gather_offset;
|
||||
const auto dst_offset = dst_coord_.GetOffset();
|
||||
// printf("Tid: %03d, src_offset: %d, dst_offset: %d\n", get_thread_local_1d_id(),
|
||||
// src_coord_.GetOffset(), dst_coord_.GetOffset());
|
||||
IndexType gather_offset = gather_offsets_[ordered_dst_access_idx[Number<GatherDim>{}]];
|
||||
const IndexType src_offset = src_coord_.GetOffset() + gather_offset;
|
||||
const IndexType dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset());
|
||||
|
||||
// Check if src data is not in the logic padding area.
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
|
||||
// Leave the HW for oob checking
|
||||
// const bool is_src_valid =
|
||||
// coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
|
||||
// src_coord_);
|
||||
|
||||
src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
|
||||
dst_buf, src_offset, dst_offset, is_src_valid);
|
||||
dst_buf, src_offset, dst_offset, true);
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
#if 1
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
printf("blkx: %u, blky: %u, tid: %u, red_id: %d src: %d (cal: %d, gather: %d), "
|
||||
"dst_offset: "
|
||||
"%d, a_dst_buffer=<0x%08x, 0x%08x, 0x%08x, 0x%08x>\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
static_cast<int>(ordered_dst_access_idx[Number<GatherDim>{}]),
|
||||
src_offset,
|
||||
src_coord_.GetOffset(),
|
||||
gather_offset,
|
||||
dst_offset,
|
||||
// *(reinterpret_cast<const uint32_t*>(&(dst_buf[dst_offset + 0].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(dst_buf[dst_offset + 16 * threadIdx.x].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(dst_buf[dst_offset + 16 * threadIdx.x].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(dst_buf[dst_offset + 32 * threadIdx.x].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(dst_buf[dst_offset + 48 * threadIdx.x].data))));
|
||||
#endif
|
||||
|
||||
constexpr auto move_src_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
@@ -260,6 +306,22 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
}
|
||||
();
|
||||
|
||||
constexpr auto move_dst_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
|
||||
});
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// Decide whether to move forward or backward.
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
@@ -280,22 +342,58 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
}();
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
// Move the source coordinate.
|
||||
if constexpr(move_src_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Move the destination coordinate.
|
||||
if constexpr(move_dst_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
#if 0
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
// Print the contents of the destination buffer.
|
||||
printf("blkx: %u, blky: %u, tid: %u, a_dst_buf_offset=<%d, %d, %d, %d>, "
|
||||
"a_dst_buffer=<%02x, %02x, %02x, %02x>\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
0,
|
||||
16,
|
||||
32,
|
||||
48,
|
||||
static_cast<uint8_t>(dst_buf[Number<0>{}].data),
|
||||
static_cast<uint8_t>(dst_buf[Number<16>{}].data),
|
||||
static_cast<uint8_t>(dst_buf[Number<32>{}].data),
|
||||
static_cast<uint8_t>(dst_buf[Number<48>{}].data));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// Reset the destination slice since the entire buffer has been already filled.
|
||||
ResetDstSliceWindow(dst_desc);
|
||||
}
|
||||
@@ -325,6 +423,8 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
static constexpr auto wave_cluster_desc_ =
|
||||
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
|
||||
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
|
||||
@@ -1387,6 +1387,7 @@ struct GridwiseMoeGemmMXBNS
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
|
||||
@@ -1657,35 +1658,22 @@ struct GridwiseMoeGemmMXBNS
|
||||
p_b_grid_up + expert_id * expert_stride,
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
auto b_blockwise_copy_up =
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
|
||||
BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
@@ -2167,6 +2155,7 @@ struct GridwiseMoeGemmMXBNS
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
|
||||
@@ -2253,7 +2242,7 @@ struct GridwiseMoeGemmMXBNS
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
printf("blkx: %u, blky: %u, tidx: %u, token_pos: %d, gather_offsets:<%d, %d, %d, %d>\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
|
||||
@@ -1504,6 +1504,23 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
#if 1
|
||||
printf("blky: %u, tid: %u, src_offset: %d, repeat_id: %d, dst_tmp_vec=<0x%08x, "
|
||||
"0x%08x, 0x%08x, "
|
||||
"0x%08x\n",
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
static_cast<int>(ordered_access_idx[Number<1>{}]),
|
||||
src_data_coord.GetOffset(),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(dst_tmp_vector.template AsType<f4x8_t>()[Number<0>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(dst_tmp_vector.template AsType<f4x8_t>()[Number<1>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(dst_tmp_vector.template AsType<f4x8_t>()[Number<2>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(dst_tmp_vector.template AsType<f4x8_t>()[Number<3>{}]))));
|
||||
#endif
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -202,6 +202,21 @@ struct DynamicBuffer
|
||||
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
|
||||
"Destination data must be stored in an LDS memory buffer.");
|
||||
|
||||
#if 0
|
||||
// if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0)
|
||||
// {
|
||||
// printf("DirectCopyToLds: src_offset=%d, dst_offset=%d\n", src_offset, dst_offset);
|
||||
// }
|
||||
printf("blkx: %u, blky: %u, tid: %u, src_offset: %d, dst_offset: %d, sizeof(src_offset): "
|
||||
"%lu, sizeof(dst_offset): %lu\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
src_offset,
|
||||
dst_offset,
|
||||
sizeof(src_offset),
|
||||
sizeof(dst_offset));
|
||||
#endif
|
||||
amd_direct_load_global_to_lds<T, NumElemsPerThread>(p_data_,
|
||||
src_offset,
|
||||
dst_buf.p_data_,
|
||||
|
||||
Reference in New Issue
Block a user