mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
fixed c_buffer alloc
This commit is contained in:
@@ -38,7 +38,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
StaticBufferOfVectorTypeV2<AddressSpaceEnum_t::Vgpr,
|
||||
vector_type<FloatAcc, 16>,
|
||||
vector_type<FloatAcc, xdlops_gemm.GetRegSizePerXdlops()>,
|
||||
MRepeat * NRepeat,
|
||||
true>
|
||||
c_thread_buf_;
|
||||
@@ -136,7 +136,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
|
||||
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
|
||||
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N));
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
|
||||
|
||||
@@ -557,6 +557,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
|
||||
@@ -569,10 +572,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
|
||||
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
|
||||
@@ -507,7 +507,7 @@ struct MfmaSelector
|
||||
|
||||
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
|
||||
|
||||
__host__ __device__ static constexpr void mfma_check()
|
||||
__host__ __device__ constexpr MfmaSelector()
|
||||
{
|
||||
static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk ==
|
||||
selected_mfma.num_regs_per_blk,
|
||||
@@ -533,8 +533,6 @@ struct MfmaSelector
|
||||
"is_k_reduction wrong!");
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr MfmaSelector() { mfma_check(); }
|
||||
|
||||
static constexpr bool IsABroadcast()
|
||||
{
|
||||
static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
|
||||
@@ -621,6 +619,8 @@ struct XdlopsGemm
|
||||
return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; }
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
|
||||
{
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
#define USE_DYNAMIC_MODE 0
|
||||
#define USE_DYNAMIC_MODE 1
|
||||
#define USE_CONV_FWD_V4R4_NCHW 0
|
||||
#define USE_CONV_FWD_V4R4R2_NHWC 0
|
||||
#define USE_CONV_FWD_V6R1_NCHW 0
|
||||
|
||||
Reference in New Issue
Block a user