mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
[GEMM] Fix bWarpTile issue and remove redundant pipeline in BlockGemmPipeline
This commit is contained in:
@@ -43,7 +43,7 @@ struct BlockGemmASmemBSmemCReg
|
||||
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
// A block tile distribution for load from lds
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
@@ -92,16 +92,16 @@ struct BlockGemmASmemBSmemCReg
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
ALdsTile b_warp_tile_;
|
||||
ALdsTile aWarpTile;
|
||||
BLdsTile bWarpTile;
|
||||
|
||||
// Prefetch from LDS to warp register
|
||||
// Loads tiles from a_block_window and b_block_window into the A/B warp-tile members.
// NOTE(review): this span looks like a rendered diff whose +/- markers were lost, so it
// carries two forms of each load. Presumably only the second form (aWarpTile/bWarpTile)
// is live after the commit -- confirm against the upstream file before relying on it.
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
// Form 1: pass the member tile to load_tile as the destination argument.
// Note that b_warp_tile_ is declared with the A-side type (ALdsTile) above.
load_tile(a_warp_tile_, a_block_window);
|
||||
load_tile(b_warp_tile_, b_block_window);
|
||||
// Form 2: assign the value returned by load_tile to the member tile.
// bWarpTile is declared with the B-side type (BLdsTile) above.
aWarpTile = load_tile(a_block_window);
|
||||
bWarpTile = load_tile(b_block_window);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -178,23 +178,23 @@ struct BlockGemmASmemBSmemCReg
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
#pragma message ("local data share prefetch")
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
|
||||
a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
#else
|
||||
load_tile(a_warp_tensor, a_warp_windows(mIter)(kIter));
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
#endif
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
#else
|
||||
load_tile(b_warp_tensor, b_warp_windows(nIter)(kIter));
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
#endif
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -305,22 +305,22 @@ struct BlockGemmASmemBSmemCReg
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
#else
|
||||
load_tile(a_warp_tensor, a_warp_windows(mIter)(kIter));
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
#endif
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
#else
|
||||
load_tile(b_warp_tensor, b_warp_windows(nIter)(kIter));
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
#endif
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
@@ -17,19 +17,31 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
#if defined(ADJUST_BLOCK_TILE_SHAPE)
|
||||
constexpr index_t kMWarp = 2;
|
||||
constexpr index_t kNWarp = 2;
|
||||
#else
|
||||
constexpr index_t kMWarp = 4;
|
||||
constexpr index_t kNWarp = 1;
|
||||
#endif
|
||||
|
||||
#if defined(NAIVE_IMPLEMENTATION)
|
||||
#pragma message ("mfma m32 n32 k8")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 2, 2);
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{},
|
||||
kMWarp,
|
||||
kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 2, 2);
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{},
|
||||
kMWarp,
|
||||
kNWarp);
|
||||
}
|
||||
#elif defined(USING_MFMA_32x32x_8x2)
|
||||
#pragma message ("mfma m32 n32 k16")
|
||||
@@ -37,13 +49,17 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 2, 2);
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{},
|
||||
kMWarp,
|
||||
kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 2, 2);
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{},
|
||||
kMWarp,
|
||||
kNWarp);
|
||||
}
|
||||
#elif defined(USING_MFMA_16x16x16)
|
||||
#pragma message ("mfma m16 n16 k16")
|
||||
@@ -51,13 +67,17 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 2);
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{},
|
||||
kMWarp,
|
||||
kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, 2, 2);
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{},
|
||||
kMWarp,
|
||||
kNWarp);
|
||||
}
|
||||
#elif defined(USING_MFMA_16x16x_16x2)
|
||||
#pragma message ("mfma m16 n16 k32")
|
||||
@@ -65,13 +85,17 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 2);
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{},
|
||||
kMWarp,
|
||||
kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, 2, 2);
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{},
|
||||
kMWarp,
|
||||
kNWarp);
|
||||
}
|
||||
#endif
|
||||
else
|
||||
|
||||
@@ -42,15 +42,8 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
}
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
static constexpr index_t APackedSize =
|
||||
static constexpr index_t kPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -72,31 +65,31 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
constexpr index_t AB_LDS_RW_Width = GetSmemPack();
|
||||
|
||||
constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
|
||||
kMPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeA());
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
|
||||
kNPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeB());
|
||||
|
||||
constexpr index_t A_LDS_Write_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width);
|
||||
kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width);
|
||||
kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
|
||||
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width);
|
||||
WaveNumN * kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * NPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width);
|
||||
WaveNumM * kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) /
|
||||
constexpr index_t C_MFMA_Inst_Num = kMPerBlock * kNPerBlock * kKPerBlock /
|
||||
(kBlockSize / WaveSize) /
|
||||
(MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
// A/B split schedule
|
||||
// the compiler is likely to use ds_read2 when the instruction width is smaller than 16 bytes
|
||||
constexpr auto num_ds_read_inst_a =
|
||||
AB_LDS_RW_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num :
|
||||
AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? A_LDS_Read_Inst_Num :
|
||||
A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b =
|
||||
AB_LDS_RW_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num :
|
||||
AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? B_LDS_Read_Inst_Num :
|
||||
B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
|
||||
@@ -109,9 +102,9 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
AB_LDS_RW_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
|
||||
AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
AB_LDS_RW_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4;
|
||||
AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
@@ -266,15 +259,15 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0},
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0},
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()));
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0},
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0},
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()));
|
||||
#else
|
||||
// A LDS tile for block GEMM
|
||||
@@ -303,23 +296,23 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
|
||||
ABlockTile a_block_tile;
|
||||
BBlockTile b_block_tile;
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock);
|
||||
|
||||
// -------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock);
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
|
||||
// Initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// Prefetch
|
||||
// Global read 0
|
||||
load_tile(a_block_tile, a_copy_dram_window);
|
||||
load_tile(b_block_tile, b_copy_dram_window);
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
if (num_loop > 1)
|
||||
{
|
||||
@@ -331,8 +324,8 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 0
|
||||
load_tile(a_block_tile, a_copy_dram_window);
|
||||
load_tile(b_block_tile, b_copy_dram_window);
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
@@ -357,8 +350,8 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 0
|
||||
load_tile(a_block_tile, a_copy_dram_window);
|
||||
load_tile(b_block_tile, b_copy_dram_window);
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
@@ -369,12 +362,14 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
// Prefetch from LDS to warp register in block gemm
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
HotLoopScheduler();
|
||||
#endif
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 2));
|
||||
iCounter += 1;
|
||||
} while(iCounter < (num_loop - 2));
|
||||
}
|
||||
|
||||
// Tail
|
||||
@@ -388,84 +383,12 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
#elif defined(ENABLE_PREFETCH)
|
||||
// Prefetch
|
||||
// Global read 0
|
||||
load_tile(a_block_tile, a_copy_dram_window);
|
||||
load_tile(b_block_tile, b_copy_dram_window);
|
||||
|
||||
{
|
||||
// Move to 1
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// Initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
// Global read 1
|
||||
load_tile(a_block_tile, a_copy_dram_window);
|
||||
|
||||
// LDS write 0
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
// Global read 1
|
||||
load_tile(b_block_tile, b_copy_dram_window);
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 2;
|
||||
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// Move to i + 2
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
// Global read i + 2
|
||||
load_tile(a_block_tile, a_copy_dram_window);
|
||||
|
||||
// LDS write i + 1
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
// Global read i + 2
|
||||
load_tile(b_block_tile, b_copy_dram_window);
|
||||
|
||||
iCounter--;
|
||||
|
||||
} while(iCounter > 0);
|
||||
|
||||
// Tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 2
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write num_loop - 1
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 1
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
#else
|
||||
// non-prefetch
|
||||
load_tile(a_block_tile, a_copy_dram_window);
|
||||
load_tile(b_block_tile, b_copy_dram_window);
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
@@ -477,11 +400,10 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
|
||||
while (iCounter > 0)
|
||||
{
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
load_tile(a_block_tile, a_copy_dram_window);
|
||||
load_tile(b_block_tile, b_copy_dram_window);
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user