mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Add Gemm instances for performance improvement (#1018)
* improve kpad * more tuning parameters * f16_f8_fp16 * cut test time * add f16_f8_fp16 * add f16_f8_f16 * testing instances for skinny cases * format * clean * add fp16_f8_fp16 * clang-format * add grouped gemm instances * fixed profile grouped_gemm * clean * clean * clean * clean * clean * add missing instance func * fixed interface --------- Co-authored-by: Jing Zhang <jizha@amd.com> Co-authored-by: root <root@sh5-1e707-rc06-38.mkm.dcgpu>
This commit is contained in:
@@ -278,6 +278,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
|
||||
// clang-format off
|
||||
str << "DeviceGemm_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
@@ -296,7 +297,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
|
||||
<< " LoopScheduler: "
|
||||
<< LoopSchedToString[LoopSched] << ", "
|
||||
<< "PipelineVersion: "
|
||||
<< PipelineVersionToString[PipelineVer];;
|
||||
<< PipelineVersionToString[PipelineVer];
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
|
||||
@@ -59,7 +59,8 @@ template <typename ADataType,
|
||||
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
|
||||
typename ComputeType = CDataType,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
|
||||
struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
@@ -79,7 +80,6 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
|
||||
// TODO: should be exposed as Tparams.
|
||||
static constexpr index_t NumGemmKPrefetchStage = 1;
|
||||
static constexpr LoopScheduler LoopSched = make_default_loop_scheduler();
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
|
||||
BlockSize,
|
||||
@@ -141,7 +141,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
index_t MPadded_,
|
||||
index_t NPadded_,
|
||||
index_t KPadded_,
|
||||
index_t K0_,
|
||||
index_t K0Padded_,
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
@@ -158,7 +158,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
MPadded_,
|
||||
NPadded_,
|
||||
KPadded_,
|
||||
K0_,
|
||||
K0Padded_,
|
||||
k_batch_),
|
||||
a_element_op(a_element_op_),
|
||||
b_element_op(b_element_op_),
|
||||
@@ -198,9 +198,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
const auto b2c_map = DefaultBlock2CTileMap{};
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
|
||||
const auto K0 = karg.K0;
|
||||
const auto K0Padded = karg.K0Padded;
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
@@ -342,7 +342,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
GridwiseGemm::CalculateMPadded(M),
|
||||
GridwiseGemm::CalculateNPadded(N),
|
||||
GridwiseGemm::CalculateKPadded(K, KBatch),
|
||||
GridwiseGemm::CalculateK0(K, KBatch),
|
||||
GridwiseGemm::CalculateK0Padded(K, KBatch),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
@@ -378,7 +378,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
GridwiseGemm::CalculateMPadded(M),
|
||||
GridwiseGemm::CalculateNPadded(N),
|
||||
GridwiseGemm::CalculateKPadded(K, KBatch),
|
||||
GridwiseGemm::CalculateK0(K, KBatch),
|
||||
GridwiseGemm::CalculateK0Padded(K, KBatch),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
@@ -392,7 +392,21 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); }
|
||||
// Builds a human-readable description of this device operation: the string
// returned by GridwiseGemm::GetTypeString(), followed by the names of the
// LoopSched and PipelineVer values.
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// Enum-to-name lookup tables; they are rebuilt on every call.
std::map<LoopScheduler, std::string> LoopSchedToString{
|
||||
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
|
||||
{PipelineVersion::v2, "v2"}};
|
||||
|
||||
// std::map::operator[] default-inserts an empty string for a key that is not
// listed above, so an enum value missing from these tables prints as "".
str << GridwiseGemm::GetTypeString() << " LoopScheduler: " << LoopSchedToString[LoopSched]
|
||||
<< ", PipelineVersion: " << PipelineVersionToString[PipelineVer];
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -265,10 +265,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
const index_t stride_b = gemm_descs[i].stride_B_;
|
||||
const index_t stride_c = gemm_descs[i].stride_C_;
|
||||
|
||||
const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
|
||||
const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
|
||||
const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
|
||||
const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH);
|
||||
const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
|
||||
const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
|
||||
const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
|
||||
const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH);
|
||||
|
||||
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
|
||||
|
||||
@@ -297,7 +297,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
m_padded,
|
||||
n_padded,
|
||||
k_padded,
|
||||
k0,
|
||||
k0_padded,
|
||||
K_BATCH};
|
||||
|
||||
gemm_kernel_args_.emplace_back(
|
||||
@@ -320,8 +320,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
|
||||
auto& karg = gemm_kernel_args_[i].karg_;
|
||||
|
||||
const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
|
||||
const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
|
||||
const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
|
||||
const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH);
|
||||
|
||||
const auto c_grid_desc_m_n =
|
||||
GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
|
||||
@@ -340,7 +340,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
|
||||
|
||||
karg.KPadded = k_padded;
|
||||
karg.K0 = k0;
|
||||
karg.K0Padded = k0_padded;
|
||||
karg.k_batch = K_BATCH;
|
||||
gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
|
||||
gemm_kernel_args_[i].block_start_ = block_start;
|
||||
@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
index_t K0 = arg.gemm_kernel_args_[0].karg_.K0;
|
||||
index_t K0 = arg.gemm_kernel_args_[0].karg_.K0Padded;
|
||||
bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
|
||||
bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
@@ -384,7 +384,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
K0 = karg.K0;
|
||||
K0 = karg.K0Padded;
|
||||
bool not_all_have_main_k0_block_loop_same =
|
||||
all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
|
||||
|
||||
Reference in New Issue
Block a user