mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 21:58:13 +00:00
initial draft
This commit is contained in:
@@ -283,24 +283,18 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
////////////// global window & register /////////////////
|
||||
// A DRAM tile window(s) for load
|
||||
// A DRAM tile window(s) for async byte-based load
|
||||
auto a_tile_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
return Policy::template MakeAAsyncLoadBytesDramWindow<Problem>(
|
||||
a_dram_block_window_tmp[number<idx>{}]);
|
||||
},
|
||||
number<AsLayout::size()>{});
|
||||
// B DRAM window(s) for load
|
||||
// B DRAM tile window(s) for async byte-based load
|
||||
auto b_tile_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
return Policy::template MakeBAsyncLoadBytesDramWindow<Problem>(
|
||||
b_dram_block_window_tmp[number<idx>{}]);
|
||||
},
|
||||
number<BsLayout::size()>{});
|
||||
|
||||
@@ -334,21 +328,24 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
|
||||
auto b_copy_lds_window1 = make_tile_window(b_lds_block1, b_lds_shape, {0, 0});
|
||||
|
||||
// initialize DRAM window steps, used to advance the DRAM windows
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
// initialize DRAM window steps for byte-based windows
|
||||
// Note: byte-based windows already account for data type packing
|
||||
const auto a_dram_tile_window_step =
|
||||
make_tuple(number<0>{}, number<KPerBlock / APackedSize>{});
|
||||
const auto b_dram_tile_window_step =
|
||||
make_tuple(number<0>{}, number<KPerBlock / BPackedSize>{});
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
// Define async load tile lambda
|
||||
auto async_load_tile_ = [](auto lds, auto dram) {
|
||||
async_load_tile(lds, dram, number<-1>{}, true_type{}, true_type{});
|
||||
};
|
||||
|
||||
// read A(0), B(0) from DRAM to LDS window(0)
|
||||
// and advance the DRAM windows
|
||||
Base::GlobalPrefetchAsync(
|
||||
a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(
|
||||
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
async_load_tile_(a_copy_lds_window0, a_tile_windows[number<0>{}]);
|
||||
move_tile_window(a_tile_windows[number<0>{}], a_dram_tile_window_step);
|
||||
async_load_tile_(b_copy_lds_window0, b_tile_windows[number<0>{}]);
|
||||
move_tile_window(b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
|
||||
// initialize block gemm
|
||||
auto block_gemm = BlockGemm();
|
||||
@@ -359,10 +356,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
|
||||
// read A(1), B(1) from DRAM to LDS window(1)
|
||||
// and advance the DRAM windows
|
||||
Base::GlobalPrefetchAsync(
|
||||
a_copy_lds_window1, a_tile_windows[number<0>{}], a_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(
|
||||
b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
async_load_tile_(a_copy_lds_window1, a_tile_windows[number<0>{}]);
|
||||
move_tile_window(a_tile_windows[number<0>{}], a_dram_tile_window_step);
|
||||
async_load_tile_(b_copy_lds_window1, b_tile_windows[number<0>{}]);
|
||||
move_tile_window(b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
|
||||
// tile distribution for the register tiles
|
||||
constexpr auto ALdsTileDistr =
|
||||
@@ -423,10 +420,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
block_sync_lds();
|
||||
// read A(2), B(2) from DRAM to LDS window(0)
|
||||
// and advance the DRAM windows
|
||||
Base::GlobalPrefetchAsync(
|
||||
a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(
|
||||
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
async_load_tile_(a_copy_lds_window0, a_tile_windows[number<0>{}]);
|
||||
move_tile_window(a_tile_windows[number<0>{}], a_dram_tile_window_step);
|
||||
async_load_tile_(b_copy_lds_window0, b_tile_windows[number<0>{}]);
|
||||
move_tile_window(b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
|
||||
if(HasHotLoop)
|
||||
{
|
||||
@@ -445,12 +442,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
block_sync_lds();
|
||||
// read A(i), B(i) from DRAM to LDS window(1)
|
||||
// and advance the DRAM windows
|
||||
Base::GlobalPrefetchAsync(a_copy_lds_window1,
|
||||
a_tile_windows[number<0>{}],
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(b_copy_lds_window1,
|
||||
b_tile_windows[number<0>{}],
|
||||
b_dram_tile_window_step);
|
||||
async_load_tile_(a_copy_lds_window1, a_tile_windows[number<0>{}]);
|
||||
move_tile_window(a_tile_windows[number<0>{}], a_dram_tile_window_step);
|
||||
async_load_tile_(b_copy_lds_window1, b_tile_windows[number<0>{}]);
|
||||
move_tile_window(b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
// C(i-3) = A(i-3) @ B(i-3)
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
HotLoopScheduler();
|
||||
@@ -466,12 +461,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
block_sync_lds();
|
||||
// read A(i+1), B(i+1) from DRAM to LDS window(0)
|
||||
// and advance the DRAM windows
|
||||
Base::GlobalPrefetchAsync(a_copy_lds_window0,
|
||||
a_tile_windows[number<0>{}],
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(b_copy_lds_window0,
|
||||
b_tile_windows[number<0>{}],
|
||||
b_dram_tile_window_step);
|
||||
async_load_tile_(a_copy_lds_window0, a_tile_windows[number<0>{}]);
|
||||
move_tile_window(a_tile_windows[number<0>{}], a_dram_tile_window_step);
|
||||
async_load_tile_(b_copy_lds_window0, b_tile_windows[number<0>{}]);
|
||||
move_tile_window(b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
// C(i-2) = A(i-2) @ B(i-2)
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
HotLoopScheduler();
|
||||
|
||||
@@ -18,6 +18,9 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked;
|
||||
static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked;
|
||||
|
||||
static constexpr index_t kDramLoadPackBytes = 128;
|
||||
static constexpr index_t DWORDx4 = 16;
|
||||
|
||||
template <typename Problem,
|
||||
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
@@ -93,6 +96,141 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
}
|
||||
}
|
||||
|
||||
// Methods for async byte-based loading (similar to mx_flatmm)
|
||||
template <typename Problem, typename WindowTmp>
|
||||
CK_TILE_DEVICE static constexpr auto MakeAAsyncLoadBytesDramWindow(const WindowTmp& window_tmp)
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr auto ndims = std::decay_t<decltype(window_tmp)>::get_num_of_dimension();
|
||||
static_assert(ndims == 2, "only support 2D tensor");
|
||||
auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view();
|
||||
const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
|
||||
|
||||
constexpr index_t K2 = DWORDx4;
|
||||
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4;
|
||||
const index_t K0 = cols / (K1 * K2 * APackedSize);
|
||||
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
|
||||
|
||||
constexpr index_t M1 = 4;
|
||||
const index_t M0 = integer_divide_ceil(rows, M1);
|
||||
const auto row_lens = make_tuple(M0, number<M1>{});
|
||||
|
||||
const auto d0 = make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
|
||||
const auto desc_0 = decltype(d0)(
|
||||
d0.get_transforms(), tensor_view_tmp.get_tensor_descriptor().get_element_space_size());
|
||||
const auto desc_1 = transform_tensor_descriptor(
|
||||
desc_0,
|
||||
make_tuple(make_pass_through_transform(M0),
|
||||
make_xor_transform(make_tuple(number<M1>{}, number<K1>{})),
|
||||
make_pass_through_transform(K0),
|
||||
make_pass_through_transform(number<K2>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
|
||||
const auto desc =
|
||||
transform_tensor_descriptor(desc_1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(row_lens),
|
||||
make_merge_transform_v3_division_mod(col_lens)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
|
||||
auto&& byte_tensor_view = make_tensor_view<address_space_enum::global>(byte_ptr, desc);
|
||||
|
||||
auto&& origin_tmp = window_tmp.get_window_origin();
|
||||
|
||||
// Create tile distribution inline (reuse K2, K1, K0 from above)
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t M2_dstr = WaveSize / K1;
|
||||
constexpr index_t M1_dstr = BlockSize / WaveSize;
|
||||
constexpr index_t M0_dstr = MPerBlock / (M2_dstr * M1_dstr);
|
||||
|
||||
constexpr auto tile_dstr = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<M0_dstr, M1_dstr, M2_dstr>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
|
||||
return make_tile_window(byte_tensor_view,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
|
||||
{origin_tmp[0], origin_tmp[1] / APackedSize},
|
||||
tile_dstr);
|
||||
}
|
||||
|
||||
template <typename Problem, typename WindowTmp>
|
||||
CK_TILE_DEVICE static constexpr auto MakeBAsyncLoadBytesDramWindow(const WindowTmp& window_tmp)
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr auto ndims = std::decay_t<decltype(window_tmp)>::get_num_of_dimension();
|
||||
static_assert(ndims == 2, "only support 2D tensor");
|
||||
auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view();
|
||||
const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
|
||||
|
||||
constexpr index_t K2 = DWORDx4;
|
||||
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4;
|
||||
const index_t K0 = cols / (K1 * K2 * BPackedSize);
|
||||
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
|
||||
|
||||
constexpr index_t N1 = 4;
|
||||
const index_t N0 = integer_divide_ceil(rows, N1);
|
||||
const auto row_lens = make_tuple(N0, number<N1>{});
|
||||
|
||||
const auto d0 = make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
|
||||
const auto desc_0 = decltype(d0)(
|
||||
d0.get_transforms(), tensor_view_tmp.get_tensor_descriptor().get_element_space_size());
|
||||
const auto desc_1 = transform_tensor_descriptor(
|
||||
desc_0,
|
||||
make_tuple(make_pass_through_transform(N0),
|
||||
make_xor_transform(make_tuple(number<N1>{}, number<K1>{})),
|
||||
make_pass_through_transform(K0),
|
||||
make_pass_through_transform(number<K2>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
|
||||
const auto desc =
|
||||
transform_tensor_descriptor(desc_1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(row_lens),
|
||||
make_merge_transform_v3_division_mod(col_lens)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
|
||||
auto&& byte_tensor_view = make_tensor_view<address_space_enum::global>(byte_ptr, desc);
|
||||
|
||||
auto&& origin_tmp = window_tmp.get_window_origin();
|
||||
|
||||
// Create tile distribution inline (reuse K2, K1, K0 from above)
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t N2_dstr = WaveSize / K1;
|
||||
constexpr index_t N1_dstr = BlockSize / WaveSize;
|
||||
constexpr index_t N0_dstr = NPerBlock / (N2_dstr * N1_dstr);
|
||||
|
||||
constexpr auto tile_dstr = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<N0_dstr, N1_dstr, N2_dstr>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
|
||||
return make_tile_window(byte_tensor_view,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{}),
|
||||
{origin_tmp[0], origin_tmp[1] / BPackedSize},
|
||||
tile_dstr);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user