mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Add instances/ckProfiler/client example for fp8/fp16 mixed precision Gemm (#853)
* Add ComputeType arg to splitk device and gridwise ops
* Update for gridwise op compatibility
* Update bf16 and int8 splitk gemm examples with ComputeType
* Add instances
* Update ckProfiler for mixed precision cases
* Add a mixed precision splitK gemm client example

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -45,7 +45,8 @@ __global__ void
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
typename ALayout,
|
||||
@@ -85,7 +86,8 @@ template <index_t BlockSize,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
|
||||
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
typename ComputeType = FloatC>
|
||||
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -113,8 +115,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
const FloatAB* p_a_grid;
|
||||
const FloatAB* p_b_grid;
|
||||
const FloatA* p_a_grid;
|
||||
const FloatB* p_b_grid;
|
||||
FloatC* p_c_grid;
|
||||
index_t M;
|
||||
index_t N;
|
||||
@@ -128,8 +130,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
index_t K0;
|
||||
index_t k_batch;
|
||||
|
||||
Argument(const FloatAB* p_a_grid_,
|
||||
const FloatAB* p_b_grid_,
|
||||
Argument(const FloatA* p_a_grid_,
|
||||
const FloatB* p_b_grid_,
|
||||
FloatC* p_c_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
@@ -365,7 +367,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
constexpr auto c_block_size =
|
||||
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
|
||||
|
||||
return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB),
|
||||
return math::max((a_block_space_size + b_block_space_size) * sizeof(ComputeType),
|
||||
c_block_size * sizeof(FloatC));
|
||||
}
|
||||
|
||||
@@ -577,8 +579,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
void* __restrict__ p_shared_block,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
const FloatAB* p_a_grid = karg.p_a_grid;
|
||||
const FloatAB* p_b_grid = karg.p_b_grid;
|
||||
const FloatA* p_a_grid = karg.p_a_grid;
|
||||
const FloatB* p_b_grid = karg.p_b_grid;
|
||||
FloatC* p_c_grid = karg.p_c_grid;
|
||||
const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
|
||||
karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded);
|
||||
@@ -698,8 +700,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
Sequence<1, K0PerBlock, MPerBlock, K1>,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatA,
|
||||
ComputeType,
|
||||
decltype(a_b_k0_m_k1_grid_desc),
|
||||
decltype(a_b_k0_m_k1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -728,8 +730,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
Sequence<1, K0PerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatB,
|
||||
ComputeType,
|
||||
decltype(b_b_k0_n_k1_grid_desc),
|
||||
decltype(b_b_k0_n_k1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
@@ -759,7 +761,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
ComputeType,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
@@ -776,8 +778,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block = static_cast<FloatAB*>(p_shared_block);
|
||||
FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;
|
||||
ComputeType* p_a_block = static_cast<ComputeType*>(p_shared_block);
|
||||
ComputeType* p_b_block = static_cast<ComputeType*>(p_shared_block) + a_block_space_size;
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
|
||||
|
||||
Reference in New Issue
Block a user