wip: trying to add async

This commit is contained in:
Gino Lu
2025-07-08 22:34:30 -05:00
parent 5557eadce6
commit 2b95d3a0aa
2 changed files with 46 additions and 17 deletions

View File

@@ -24,7 +24,11 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
#endif
__global__ void kentry(Args... args)
{
#if defined(__HIP_DEVICE_COMPILE__)
Kernel{}(args...);
# else
(..., (ignore = args));
#endif
}
//

View File

@@ -156,7 +156,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
[[maybe_unused]] const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
@@ -249,7 +249,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// prefetch
// global read 0
auto a_block_tile = load_tile(a_copy_dram_window);
//auto a_block_tile = load_tile(a_copy_dram_window);
// TODO: gino: async
statically_indexed_array<
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
@@ -260,6 +261,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
b_warp_tensor;
// gino: vgpr?
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
@@ -274,6 +276,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
// gino: vgpr, so can't async?
});
});
@@ -288,18 +291,25 @@ struct FlatmmPipelineAGmemBGmemCRegV1
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
// TODO: gino: Does command support shuffle and async load?
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
shuffle_tile(a_shuffle_tmp, a_block_tile);
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
store_tile(a_copy_lds_window, a_block_tile_tmp);
assert(false); // not support
// auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
// PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
// shuffle_tile(a_shuffle_tmp, a_block_tile);
// const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
// store_tile(a_copy_lds_window, a_block_tile_tmp);
}
else
{
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
//store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
async_load_tile(a_copy_lds_window, a_copy_dram_window);
__builtin_amdgcn_s_waitcnt(3952);
//block_sync_lds();
}
// TODO: gino: use this to sync async load.
block_sync_lds();
}
@@ -307,7 +317,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1
while(iCounter > 0)
{
// global read i + 1
a_block_tile = load_tile(a_copy_dram_window);
// TODO: gino: async
//a_block_tile = load_tile(a_copy_dram_window);
// GEMM i
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor);
@@ -324,6 +335,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// gino: to vgpr
// move to i + 2
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
@@ -332,15 +344,19 @@ struct FlatmmPipelineAGmemBGmemCRegV1
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// LDS write i + 1
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
// auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
// store_tile(a_copy_lds_window, a_block_tile_tmp);
// TODO: gino: combine with async?
async_load_tile(a_copy_lds_window, a_copy_dram_window);
__builtin_amdgcn_s_waitcnt(3952);
HotLoopScheduler();
block_sync_lds();
// iCounter--;
// global read i + 1
a_block_tile = load_tile(a_copy_dram_window);
//a_block_tile = load_tile(a_copy_dram_window);
// TODO: gino: async
// GEMM i
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2);
@@ -355,6 +371,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
// TODO: gino: to vgpr, can't async
});
});
@@ -365,9 +382,12 @@ struct FlatmmPipelineAGmemBGmemCRegV1
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// LDS write i + 1
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
//a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
//store_tile(a_copy_lds_window, a_block_tile_tmp);
// gino: combine with async
async_load_tile(a_copy_lds_window, a_copy_dram_window);
__builtin_amdgcn_s_waitcnt(3952);
HotLoopScheduler();
block_sync_lds();
@@ -377,7 +397,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// tail
{
// global read i + 1
a_block_tile = load_tile(a_copy_dram_window);
//a_block_tile = load_tile(a_copy_dram_window);
// TODO: gino: async
// GEMM i
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor);
@@ -392,6 +413,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
// gino: to vgpr, can't async
});
});
@@ -399,12 +421,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// LDS write i + 1
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
//const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
//store_tile(a_copy_lds_window, a_block_tile_tmp);
// TODO: gino: combine with async
// move to next flat K
// move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
async_load_tile(a_copy_lds_window, a_copy_dram_window);
__builtin_amdgcn_s_waitcnt(3952);
HotLoopScheduler();
block_sync_lds();