mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
Add conv bwd weight fp16 comp bf8 fp8 op, instances and example (#945)
* Add f8 bf8 gemm example * Add element-wise ops * Add intrinsics * Update reference calculation * Add an additional type option for xdlops gemm * Fix build process * Add bf8 to buffer addressing * Update blockwise op, split typeA and typeB * Update for compatibility * Uppdate naming to f8->fp8 * Update naming * Format * Update naming (#937) * Add a client example * Add computetypes to device and gridwise ops * Add instances, update instance factory * Format * Fix a flag * Add ckProfiler mode * Fix typos * Add an example * Add bf8 generator * add bf8 mfma; fixed type_convert for bf8 * move verfication ahead of timing * Update reference calculation * Fix reference * Narrow down float init range * Fix bf8 bf8 mfma * Add bf8 @ fp8 mfma * Update example * Update instances * Update profiler api * Update for compatibility * Format * Remove extra example * Clean up * workaround convert --------- Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -139,7 +139,8 @@ __host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLen
|
||||
}
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AGridDesc_B_K0_M_K1,
|
||||
typename BGridDesc_B_K0_N_K1,
|
||||
@@ -153,8 +154,8 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_xdlops_bwd_weight(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
kernel_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
|
||||
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
|
||||
@@ -181,21 +182,22 @@ __global__ void
|
||||
c_element_op,
|
||||
c_block_cluster_adaptor);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = a_b_k0_m_k1_grid_desc;
|
||||
ignore = b_b_k0_n_k1_grid_desc;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = c_block_cluster_adaptor;
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = a_b_k0_m_k1_grid_desc;
|
||||
ignore = b_b_k0_n_k1_grid_desc;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = c_block_cluster_adaptor;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
@@ -242,7 +244,9 @@ template <index_t BlockSize,
|
||||
bool ABlockLdsExtraM1Wrw = false,
|
||||
bool BBlockLdsExtraN1Wrw = false,
|
||||
index_t NumGemmKPrefetchStage = 1,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
typename ComputeTypeA = FloatA,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -265,11 +269,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
// denorm test fix, required to work around fp16 mfma issue
|
||||
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
|
||||
// when mfma if fixed, remove this section and update
|
||||
// FloatABAdjusted -> FloatAB throughout this file
|
||||
// FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB,
|
||||
// throughout this file
|
||||
#if CK_WORKAROUND_DENORM_FIX
|
||||
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
|
||||
using FloatAAdjusted =
|
||||
conditional_t<is_same_v<ComputeTypeA, ck::half_t>, ck::bhalf_t, ComputeTypeA>;
|
||||
using FloatBAdjusted =
|
||||
conditional_t<is_same_v<ComputeTypeB, ck::half_t>, ck::bhalf_t, ComputeTypeB>;
|
||||
#else
|
||||
using FloatABAdjusted = FloatAB;
|
||||
using FloatAAdjusted = ComputeTypeA;
|
||||
using FloatBAdjusted = ComputeTypeB;
|
||||
#endif
|
||||
|
||||
// M0/M1/M1Padding
|
||||
@@ -506,7 +515,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
constexpr auto c_block_size =
|
||||
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
|
||||
|
||||
return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB),
|
||||
return math::max((a_block_space_size * sizeof(FloatAAdjusted) +
|
||||
b_block_space_size * sizeof(FloatBAdjusted)),
|
||||
c_block_size * sizeof(FloatC));
|
||||
}
|
||||
|
||||
@@ -610,8 +620,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
|
||||
|
||||
template <bool HasMainKBlockLoop>
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
|
||||
@@ -673,8 +683,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
Sequence<1, K0PerBlock, MPerBlock, K1>,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatABAdjusted,
|
||||
FloatA,
|
||||
FloatAAdjusted,
|
||||
decltype(a_b_k0_m_k1_grid_desc),
|
||||
decltype(a_b_k0_m_k1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -703,8 +713,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
Sequence<1, K0PerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatABAdjusted,
|
||||
FloatB,
|
||||
FloatBAdjusted,
|
||||
decltype(b_b_k0_n_k1_grid_desc),
|
||||
decltype(b_b_k0_n_k1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
@@ -733,12 +743,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
// sanity check
|
||||
|
||||
constexpr index_t KPack =
|
||||
math::max(K1, MfmaSelector<FloatABAdjusted, MPerXDL, NPerXDL>::selected_mfma.k_per_blk);
|
||||
math::max(K1,
|
||||
MfmaSelector<FloatAAdjusted, MPerXDL, NPerXDL, FloatBAdjusted>::selected_mfma
|
||||
.k_per_blk);
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatABAdjusted,
|
||||
FloatABAdjusted,
|
||||
FloatAAdjusted,
|
||||
FloatBAdjusted,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
@@ -758,10 +770,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatABAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
|
||||
static_cast<FloatAAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size,
|
||||
static_cast<FloatBAdjusted*>(p_shared) + a_block_space_size,
|
||||
b_k0_n_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
// gridwise GEMM pipeline
|
||||
|
||||
Reference in New Issue
Block a user