mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
wip: trying to add async
This commit is contained in:
@@ -24,7 +24,11 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
|
||||
#endif
|
||||
__global__ void kentry(Args... args)
|
||||
{
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
Kernel{}(args...);
|
||||
# else
|
||||
(..., (ignore = args));
|
||||
#endif
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@@ -156,7 +156,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
[[maybe_unused]] const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
@@ -249,7 +249,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
//auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
// TODO: gino: async
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
|
||||
@@ -260,6 +261,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor;
|
||||
// gino: vgpr?
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
@@ -274,6 +276,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
// gino: vgpr, so can't async?
|
||||
});
|
||||
});
|
||||
|
||||
@@ -288,18 +291,25 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
// TODO: gino: Does command support shuffle and async load?
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
|
||||
shuffle_tile(a_shuffle_tmp, a_block_tile);
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
assert(false); // not support
|
||||
// auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
// PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
|
||||
// shuffle_tile(a_shuffle_tmp, a_block_tile);
|
||||
// const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
|
||||
// store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
|
||||
//store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
|
||||
|
||||
async_load_tile(a_copy_lds_window, a_copy_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
//block_sync_lds();
|
||||
}
|
||||
// TODO: gino: use this to sync async load.
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
@@ -307,7 +317,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// global read i + 1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// TODO: gino: async
|
||||
//a_block_tile = load_tile(a_copy_dram_window);
|
||||
|
||||
// GEMM i
|
||||
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor);
|
||||
@@ -324,6 +335,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
// gino: to vgpr
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
@@ -332,15 +344,19 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
// auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
// store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
// TODO: gino: combine with async?
|
||||
async_load_tile(a_copy_lds_window, a_copy_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
HotLoopScheduler();
|
||||
block_sync_lds();
|
||||
|
||||
// iCounter--;
|
||||
|
||||
// global read i + 1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
//a_block_tile = load_tile(a_copy_dram_window);
|
||||
// TODO: gino: async
|
||||
|
||||
// GEMM i
|
||||
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2);
|
||||
@@ -355,6 +371,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
// TODO: gino: to vgpr, can't async
|
||||
});
|
||||
});
|
||||
|
||||
@@ -365,9 +382,12 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
//a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
//store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
// gino: combine with async
|
||||
|
||||
async_load_tile(a_copy_lds_window, a_copy_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
HotLoopScheduler();
|
||||
block_sync_lds();
|
||||
|
||||
@@ -377,7 +397,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
// tail
|
||||
{
|
||||
// global read i + 1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
//a_block_tile = load_tile(a_copy_dram_window);
|
||||
// TODO: gino: async
|
||||
|
||||
// GEMM i
|
||||
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor);
|
||||
@@ -392,6 +413,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
// gino: to vgpr, can't async
|
||||
});
|
||||
});
|
||||
|
||||
@@ -399,12 +421,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
// move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
//const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
//store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
// TODO: gino: combine with async
|
||||
|
||||
// move to next flat K
|
||||
// move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
async_load_tile(a_copy_lds_window, a_copy_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
HotLoopScheduler();
|
||||
block_sync_lds();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user