mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
compiles again using get_y_sliced_thread_data in warpgemm loop
This commit is contained in:
@@ -299,7 +299,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
|
||||
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
},
|
||||
@@ -309,7 +309,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{}),
|
||||
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
},
|
||||
@@ -364,6 +364,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
scale_b_dram_window.get_load_offset(tuple<number<0>, number<NWarp * NPerXdl>>{}));
|
||||
|
||||
// this pipeline has a pair of LDS buffers per logical tile
|
||||
// TODO: check for packed size - are these blocks too big?
|
||||
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
|
||||
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
|
||||
|
||||
@@ -372,14 +373,14 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
if constexpr(is_a_load_tr_v)
|
||||
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
|
||||
return make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{});
|
||||
}();
|
||||
|
||||
constexpr auto b_lds_shape = []() {
|
||||
if constexpr(is_b_load_tr_v)
|
||||
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
|
||||
return make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{});
|
||||
}();
|
||||
|
||||
// LDS tile windows for storing, one per LDS buffer
|
||||
@@ -439,6 +440,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
|
||||
constexpr index_t ScaleKPackedPerIter = (KIterPerWarp * KPerXdl) / (ScaleBlockSize * KXdlPack);
|
||||
static_assert(ScaleKPackedPerIter > 0, "ScaleKPackedPerIter is wrong!");
|
||||
|
||||
// Load a sample scale tile to get the type
|
||||
auto scale_a_sample = load_tile_with_offset(scale_a_dram_window, tuple<number<0>, number<0>>{});
|
||||
@@ -520,7 +522,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
});
|
||||
|
||||
// Warp GEMM loop with MX scaling
|
||||
auto warp_gemm_loop = [&](auto& a_block_tile, auto& b_block_tile, auto& scale_a, auto& scale_b) {
|
||||
auto warp_gemm_loop = [&](const auto& a_block_tile, const auto& b_block_tile, const auto& scale_a, const auto& scale_b) {
|
||||
// Extract A/B values from block tiles to warp iteration structure
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(typename WarpGemm::AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
@@ -537,25 +539,22 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto m_iter) {
|
||||
constexpr auto OpSelA = kScaleInPack;
|
||||
|
||||
// read A warp tensor from A block tensor
|
||||
typename WarpGemm::AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<m_iter, k_iter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto n_iter) {
|
||||
constexpr auto OpSelB = kScaleInPack;
|
||||
|
||||
// Extract A/B values for this iteration - create warp tensors
|
||||
typename WarpGemm::AWarpTensor a_warp_tensor{};
|
||||
const auto a_thread_data = a_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<m_iter, k_iter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
static_for<0, a_warp_tensor.get_thread_buffer_size(), 1>{}([&](auto i) {
|
||||
a_warp_tensor.get_thread_buffer()(i) = a_thread_data[i];
|
||||
});
|
||||
|
||||
typename WarpGemm::BWarpTensor b_warp_tensor{};
|
||||
const auto b_thread_data = b_block_tile.get_y_sliced_thread_data(
|
||||
// read B warp tensor from B block tensor
|
||||
typename WarpGemm::BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<n_iter, k_iter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
static_for<0, b_warp_tensor.get_thread_buffer_size(), 1>{}([&](auto i) {
|
||||
b_warp_tensor.get_thread_buffer()(i) = b_thread_data[i];
|
||||
});
|
||||
|
||||
WarpGemm{}.template operator()<OpSelA, OpSelB>(
|
||||
c_warp_tensors(m_iter)(n_iter),
|
||||
|
||||
@@ -29,9 +29,9 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
template <typename Problem, bool IsWave32Host = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
// Get packed sizes for A/B
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
// Force 16-byte vector loads for optimal async buffer performance
|
||||
// For fp4 (1 byte): 16 elements = 16 bytes
|
||||
// For fp8 (1 byte): 16 elements = 16 bytes
|
||||
@@ -53,9 +53,9 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
template <typename Problem, bool IsWave32Host = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
// Get packed sizes for A/B
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
// Force 16-byte vector loads for optimal async buffer performance
|
||||
// For fp4 (1 byte): 16 elements = 16 bytes
|
||||
// For fp8 (1 byte): 16 elements = 16 bytes
|
||||
@@ -86,13 +86,17 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
|
||||
using ALayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
|
||||
// Get packed sizes for A/B
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
KPerBlock / APackedSize,
|
||||
VecLoadSize,
|
||||
getATileAccessPattern(),
|
||||
NumWaveGroups>;
|
||||
@@ -123,6 +127,11 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
|
||||
using BLayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
|
||||
|
||||
// Get packed sizes for A/B
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -141,7 +150,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
KPerBlock / BPackedSize,
|
||||
VecLoadSize,
|
||||
getBTileAccessPattern(),
|
||||
NumWaveGroups>;
|
||||
@@ -153,8 +162,13 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
// Get packed sizes for A/B
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize;
|
||||
if constexpr(is_a_load_tr<Problem>)
|
||||
{
|
||||
// TODO: better LDS descriptor for performance
|
||||
@@ -191,8 +205,13 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
// Get packed sizes for A/B
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize;
|
||||
if constexpr(is_b_load_tr<Problem>)
|
||||
{
|
||||
// TODO: better LDS descriptor for performance
|
||||
@@ -300,10 +319,6 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
constexpr index_t NWarp = BlockWarps::at(number<1>{});
|
||||
constexpr index_t NPerXdl = WarpTile::at(number<1>{});
|
||||
constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 4 for 16x16 mfma
|
||||
|
||||
static_assert(K_Lane == 4, "K_Lane must be 4 for 16x16 mfma");
|
||||
static_assert(NPerXdl == 16, "NPerXdl must be 16 for 16x16 mfma");
|
||||
static_assert(MWarp == 1, "MWarp must be 1 for 16x16 mfma");
|
||||
|
||||
// Scale B: [K/32/KXdlPack, NWarp * NPerXdl] for warp-level tile
|
||||
// Layout is [K, N] where K is packed int32
|
||||
|
||||
Reference in New Issue
Block a user