tmp save between remotes

This commit is contained in:
Jakub Piasecki
2026-01-28 16:07:11 +00:00
parent 181c075794
commit eb3eacebce
5 changed files with 35 additions and 5 deletions

View File

@@ -187,6 +187,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
if(threadIdx.x == 0) printf("intra\n");
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
@@ -212,6 +213,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
index_t i = 0;
do
{
if(threadIdx.x == 0) printf("hotloop: %d\n", i);
// -------------------------------------------------------------------------------------------
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
@@ -280,6 +282,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
// tail
if constexpr(TailNum == TailNumber::Full)
{
if(threadIdx.x == 0) printf("tail\n");
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
@@ -919,6 +922,7 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
if(threadIdx.x == 0) printf("v1 intra directload, num_loop: %d\n", num_loop);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
@@ -942,6 +946,7 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
index_t i = 0;
do
{
if(threadIdx.x == 0) printf("has Main loop %d\n", i);
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
@@ -981,6 +986,14 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
if(threadIdx.x == 0) {
printf("a: %f b: %f\n",
static_cast<float>(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]),
static_cast<float>(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}]));
}
});
using mfma_input_type =
@@ -1007,6 +1020,7 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
// tail
if constexpr(TailNum == TailNumber::Full)
{
if(threadIdx.x == 0) printf("Tail full\n");
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
@@ -1039,6 +1053,14 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
if(threadIdx.x == 0) {
printf("a: %f b: %f\n",
static_cast<float>(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]),
static_cast<float>(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}]));
}
});
using mfma_input_type =

View File

@@ -909,6 +909,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3
const bool HasMainKBlockLoop =
GridwiseGemmCTranspose::CalculateHasMainKBlockLoop(K_split);
printf("GemmK: %d split_k: %d, KPerBlock: %d, k_grain: %d, k_split: %d", GemmK, split_k, KPerBlock, k_grain, K_split);
gemm_kernel_args_[gemms_count_ /
MaxGroupedGemmGroupsNum][gemms_count_ %
MaxGroupedGemmGroupsNum] =

View File

@@ -1640,7 +1640,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcVectorDim, // enforced earlier
1, // enforced earlier
BBlockTransferSrcScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(num_bk0_per_block * k_idx, n_block_data_idx_on_grid, 0),
@@ -2297,6 +2297,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
if(threadIdx.x == 0) {
printf("num_k block main loop: %d\n m_block_data_idx_on_grid: %d\n n_block_data_idx_on_grid: %d\n", num_k_block_main_loop, m_block_data_idx_on_grid, n_block_data_idx_on_grid);
}
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,