mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
fix settings for example, fix some things in pipeline
This commit is contained in:
@@ -242,7 +242,7 @@ struct MXGemmPipelineAgBgCrV1
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
};
|
||||
|
||||
// Helper for Math Loop
|
||||
// Helper for Main Loop
|
||||
auto warp_gemm_loop = [&](auto& a_warp_window, auto& b_warp_window, auto& scale_a, auto& scale_b) {
|
||||
// Define register tiles types for double buffering
|
||||
using AValType = decltype(load_tile_with_offset(a_warp_window, tuple<number<0>, number<0>>{}));
|
||||
|
||||
@@ -227,31 +227,10 @@ struct MXGemmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
sequence<K_Thread / AK1, K_Lane, AK1>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 2>>,
|
||||
sequence<2, 2>,
|
||||
sequence<2, 2>, // K_Thread/AK1, AK1
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeMX_BDramTileDistribution()
|
||||
{
|
||||
constexpr index_t K2 = BK1; // f4=32; f8=16
|
||||
constexpr index_t K1 = kDramLoadPackBytes * BPackedSize / K2; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
|
||||
|
||||
constexpr index_t N2 = WaveSize / K1; // 8
|
||||
constexpr index_t N1 = BlockSize / WaveSize; // 4
|
||||
constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!");
|
||||
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>, // N0,K0,K2
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
@@ -294,6 +273,29 @@ struct MXGemmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
TensorView::DstInMemOp>{naive_view.buf_, desc};
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeMX_BDramTileDistribution()
|
||||
{
|
||||
// TODO: these could be replaced by the standard UniversalGEMM tile distributions??
|
||||
constexpr index_t K2 = BK1; // f4=32; f8=16
|
||||
constexpr index_t K1 = kDramLoadPackBytes * BPackedSize / K2; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
|
||||
|
||||
constexpr index_t N2 = WaveSize / K1; // 8
|
||||
constexpr index_t N1 = BlockSize / WaveSize; // 4
|
||||
constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!");
|
||||
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>, // N0,K0,K2
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t K2 = BK1; // f4=32; f8=16
|
||||
|
||||
Reference in New Issue
Block a user