mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
debug 16x16 load
This commit is contained in:
@@ -131,13 +131,18 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
|
||||
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>;
|
||||
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, F16>;
|
||||
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
|
||||
32, 128, 128,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
//threadnum, mblock, nblock, kblock
|
||||
256, 32, 128, 128,
|
||||
// ak1, bk1
|
||||
8, 8,
|
||||
// mn_perxdl
|
||||
32, 32,
|
||||
// mn_xdlperwave
|
||||
1, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
// a,b: loadtranfer cluster, cluster order, srcorder, srcpervec, dstpervec, lds_extra
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
@@ -162,7 +167,7 @@ int main(int argc, char* argv[])
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 8192;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 8;
|
||||
ck::index_t sorted_tile_num = 1;
|
||||
ck::index_t sorted_tile_size = 32;
|
||||
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
|
||||
ck::index_t tokens = 32;
|
||||
|
||||
@@ -45,11 +45,12 @@ template <typename ThreadGroup,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadGroupTensorSliceTransfer_v4r1_mod8
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
|
||||
static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});
|
||||
static constexpr index_t mod_num = ThreadClusterLengths{}.At(I0); // Dirty HACK FELIX, TODO fix
|
||||
using Index = MultiIndex<nDim>;
|
||||
// using GatherIndex = MultiIndex<gather_num>;
|
||||
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_mod8(
|
||||
const SrcDesc& src_desc,
|
||||
@@ -86,7 +87,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId() % 8));
|
||||
make_multi_index(ThreadGroup::GetThreadId() % mod_num));
|
||||
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
|
||||
src_block_slice_origin + src_thread_cluster_idx * thread_slice_lengths);
|
||||
|
||||
@@ -104,7 +105,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId() % 8));
|
||||
make_multi_index(ThreadGroup::GetThreadId() % mod_num));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
|
||||
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
|
||||
|
||||
@@ -1127,16 +1127,18 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]);
|
||||
|
||||
// constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
|
||||
constexpr auto MLoadThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
|
||||
constexpr auto KLoadThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0) * ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
|
||||
constexpr auto MLoadRepeats = MPerBlock / MLoadThreads;
|
||||
static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
|
||||
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / KLoadThreads;
|
||||
StaticallyIndexedArray<index_t, MLoadRepeats> token_offsets; //= p_sorted_token_ids[token_pos];
|
||||
static_for<0, MLoadRepeats, 1>{}([&](auto m0) {
|
||||
token_offsets(m0) = p_sorted_token_ids[token_pos + MLoadThreads * m0] * problem.K;
|
||||
constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
|
||||
constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
|
||||
constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
|
||||
constexpr auto AKThreads = AK0Threads * AK1Threads;
|
||||
constexpr auto AMRepeats = MPerBlock / AMThreads;
|
||||
// static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
|
||||
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
|
||||
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
|
||||
static_for<0, AMRepeats, 1>{}([&](auto m0) {
|
||||
gather_offsets(m0) = p_sorted_token_ids[token_pos + m0] * problem.K;
|
||||
printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
|
||||
});
|
||||
// printf("threadIdx.x %d off %d\n", threadIdx.x, token_offsets(I0));
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
|
||||
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
|
||||
@@ -1194,7 +1196,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
token_offsets);
|
||||
gather_offsets);
|
||||
|
||||
// Thread-wise copy
|
||||
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
|
||||
@@ -1222,7 +1224,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(AK0Threads, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
|
||||
|
||||
// Blockwise GEMM pipeline
|
||||
|
||||
@@ -178,15 +178,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
|
||||
// maintain a container record is_src_valid, waiting for RunWrite use.
|
||||
const index_t ld_offset = src_coord_.GetOffset() + gather_offset;
|
||||
const bool is_src_valid = ld_offset < src_desc.GetElementSpaceSize() * sizeof(SrcData);//hack felix, todo use coord
|
||||
const bool is_src_valid = ld_offset < src_desc.GetElementSpaceSize();//hack felix, todo use coord
|
||||
//coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_) && (gather_offset < 32*512);
|
||||
src_oob_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
|
||||
|
||||
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
|
||||
using src_vector_t = typename src_vector_type::type;
|
||||
// if(blockIdx.x+blockIdx.y==0)
|
||||
// printf("tid %d off %d %d\n", threadIdx.x, src_coord_.GetOffset(), gather_offset );
|
||||
if(threadIdx.x==0)
|
||||
printf("use tid %d num %d off %d %d\n", threadIdx.x, ordered_src_access_idx[Number<ordered_gather_dim>{}](), src_coord_.GetOffset(), gather_offset );
|
||||
auto src_vector_container =
|
||||
src_vector_type{src_buf.template Get<src_vector_t>(ld_offset, true)};
|
||||
|
||||
@@ -235,7 +235,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
// printf("tid %d %f\n",threadIdx.x, type_convert<float>(src_vector_container.template AsType<print_vec_t>()[idx]));
|
||||
// });
|
||||
// }
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
@@ -246,15 +246,20 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
move_on_dim_(i) &=
|
||||
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
|
||||
});
|
||||
move_on_dim_(i) &= i.value != ordered_gather_dim;
|
||||
|
||||
// if(threadIdx.x==0)
|
||||
// printf("i %d %d ordered_gather_dim %d\n", i.value, move_on_dim_(i), ordered_gather_dim);
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// move src coord
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
if(threadIdx.x==0)
|
||||
printf("use tid %d ori cord: %d i %d mov %d\n", threadIdx.x, src_coord_.GetOffset(), i.value, move_on_dim[i]);
|
||||
if (move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
@@ -267,7 +272,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
if(threadIdx.x==0)
|
||||
printf("use tid %d moved cord: %d\n", threadIdx.x, src_coord_.GetOffset());
|
||||
});
|
||||
|
||||
});
|
||||
|
||||
// move src coordinate back to slice origin (or not)
|
||||
|
||||
@@ -423,14 +423,14 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
dst_coords_[i].GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vectors[i].template AsType<dst_vector_t>()[I0]);
|
||||
if(1) {
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto idx) {
|
||||
using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
|
||||
using print_vec_t = typename vector_type<DstData, 1>::type;
|
||||
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_coords_[i].GetOffset(), is_dst_valid,
|
||||
// type_convert<float>(dst_vectors[i].template AsType<print_vec_t>()[idx]));
|
||||
});
|
||||
}
|
||||
// if(1) {
|
||||
// static_for<0, DstScalarPerVector, 1>{}([&](auto idx) {
|
||||
// using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
|
||||
// using print_vec_t = typename vector_type<DstData, 1>::type;
|
||||
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_coords_[i].GetOffset(), is_dst_valid,
|
||||
// type_convert<float>(dst_vectors[i].template AsType<print_vec_t>()[idx]));
|
||||
// });
|
||||
// }
|
||||
});
|
||||
|
||||
// move coordinate
|
||||
|
||||
Reference in New Issue
Block a user