mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Added structural sparsity blockwise gemm (#1435)
* Implemented smfmac xdlops * Added smfmac blockwise xdlops * fixes * add reviewers suggestions --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -35,10 +35,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32f16>
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t idx_part,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
|
||||
{
|
||||
intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
|
||||
intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
|
||||
a, b, idx, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -57,10 +63,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16f16>
|
||||
static constexpr index_t k_per_blk = 16;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t idx_part,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
|
||||
{
|
||||
intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
|
||||
intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
|
||||
a, b, idx, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -79,10 +91,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32bf16>
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t idx_part,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
|
||||
{
|
||||
intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
|
||||
intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
|
||||
a, b, idx, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -101,10 +119,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16bf16>
|
||||
static constexpr index_t k_per_blk = 16;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t idx_part,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
|
||||
{
|
||||
intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
|
||||
intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
|
||||
a, b, idx, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -305,8 +329,8 @@ struct SparseXdlopsGemm
|
||||
"base base_type must be half or bfloat16!");
|
||||
|
||||
static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
smfmac_instr.template run<MPerXdlops, NPerXdlops>(
|
||||
p_a_wave[k], p_b_wave[k], idx[k], p_c_thread);
|
||||
smfmac_instr.template run<MPerXdlops, NPerXdlops, k % 4>(
|
||||
p_a_wave[k], p_b_wave[k], idx[k / 4], p_c_thread);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user