mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
Fix issue with multiple targets and remove smfmac tests from unsupported test targets (#1372)
[ROCm/composable_kernel commit: 959073842c]
This commit is contained in:
@@ -47,12 +47,12 @@ __global__ void
|
||||
#endif
|
||||
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3(
|
||||
typename GridwiseGemm::Argument karg,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const index_t num_k_per_block)
|
||||
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
[[maybe_unused]] const index_t num_k_per_block)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
@@ -103,12 +103,12 @@ __global__ void
|
||||
#endif
|
||||
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds(
|
||||
typename GridwiseGemm::Argument karg,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const index_t num_k_per_block)
|
||||
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
[[maybe_unused]] const index_t num_k_per_block)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
|
||||
@@ -69,14 +69,15 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffset compute_ptr_offset_of_groups,
|
||||
const ComputePtrOffset compute_ptr_offset_of_n,
|
||||
const index_t groups_count)
|
||||
kernel_grouped_conv_fwd_xdl_cshuffle_v3(
|
||||
typename GridwiseGemm::Argument karg,
|
||||
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups,
|
||||
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n,
|
||||
[[maybe_unused]] const index_t groups_count)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// offset base pointer for each work-group
|
||||
@@ -132,13 +133,13 @@ __global__ void
|
||||
#endif
|
||||
kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds(
|
||||
typename GridwiseGemm::Argument karg,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffset compute_ptr_offset_of_groups,
|
||||
const ComputePtrOffset compute_ptr_offset_of_n,
|
||||
const index_t groups_count)
|
||||
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups,
|
||||
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n,
|
||||
[[maybe_unused]] const index_t groups_count)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// offset base pointer for each work-group
|
||||
|
||||
@@ -16,8 +16,15 @@ struct intrin_smfmac_f32_16x16x32f16<16, 16>
|
||||
// Device-side wrapper around __builtin_amdgcn_smfmac_f32_16x16x32_f16.
//   reg_a, reg_b : operands passed to the builtin unchanged.
//   reg_idx      : passed to the builtin as its fourth argument.
//   reg_c        : accumulator; its float4_t view at Number<0> is read as the
//                  builtin's third argument and overwritten with the result.
// When __gfx94__ is not defined, no instruction is issued and reg_c is left untouched.
__device__ static void
|
||||
Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
|
||||
{
|
||||
// The builtin is only referenced when the compile target defines __gfx94__.
#if defined(__gfx94__)
|
||||
// The two trailing builtin arguments are hard-coded to 0.
// NOTE(review): presumably the cbsz/abid modifiers -- confirm against the builtin docs.
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
|
||||
#else
|
||||
// Fallback for other targets: each argument is assigned to `ignore`
// (presumably to silence unused-parameter diagnostics -- confirm).
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
ignore = reg_idx;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -31,8 +38,15 @@ struct intrin_smfmac_f32_16x16x32bf16<16, 16>
|
||||
// Device-side wrapper around __builtin_amdgcn_smfmac_f32_16x16x32_bf16.
//   reg_a, reg_b : operands passed to the builtin unchanged.
//   reg_idx      : passed to the builtin as its fourth argument.
//   reg_c        : accumulator; its float4_t view at Number<0> is read as the
//                  builtin's third argument and overwritten with the result.
// When __gfx94__ is not defined, no instruction is issued and reg_c is left untouched.
__device__ static void
|
||||
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
|
||||
{
|
||||
// The builtin is only referenced when the compile target defines __gfx94__.
#if defined(__gfx94__)
|
||||
// The two trailing builtin arguments are hard-coded to 0.
// NOTE(review): presumably the cbsz/abid modifiers -- confirm against the builtin docs.
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
|
||||
#else
|
||||
// Fallback for other targets: each argument is assigned to `ignore`
// (presumably to silence unused-parameter diagnostics -- confirm).
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
ignore = reg_idx;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -46,8 +60,15 @@ struct intrin_smfmac_f32_32x32x16f16<32, 32>
|
||||
// Device-side wrapper around __builtin_amdgcn_smfmac_f32_32x32x16_f16.
//   reg_a, reg_b : operands passed to the builtin unchanged.
//   reg_idx      : passed to the builtin as its fourth argument.
//   reg_c        : accumulator; its float16_t view at Number<0> is read as the
//                  builtin's third argument and overwritten with the result.
// When __gfx94__ is not defined, no instruction is issued and reg_c is left untouched.
__device__ static void
|
||||
Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
|
||||
{
|
||||
// The builtin is only referenced when the compile target defines __gfx94__.
#if defined(__gfx94__)
|
||||
// The two trailing builtin arguments are hard-coded to 0.
// NOTE(review): presumably the cbsz/abid modifiers -- confirm against the builtin docs.
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
|
||||
#else
|
||||
// Fallback for other targets: each argument is assigned to `ignore`
// (presumably to silence unused-parameter diagnostics -- confirm).
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
ignore = reg_idx;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -61,8 +82,15 @@ struct intrin_smfmac_f32_32x32x16bf16<32, 32>
|
||||
// Device-side wrapper around __builtin_amdgcn_smfmac_f32_32x32x16_bf16.
//   reg_a, reg_b : operands passed to the builtin unchanged.
//   reg_idx      : passed to the builtin as its fourth argument.
//   reg_c        : accumulator; its float16_t view at Number<0> is read as the
//                  builtin's third argument and overwritten with the result.
// When __gfx94__ is not defined, no instruction is issued and reg_c is left untouched.
__device__ static void
|
||||
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
|
||||
{
|
||||
// The builtin is only referenced when the compile target defines __gfx94__.
#if defined(__gfx94__)
|
||||
// The two trailing builtin arguments are hard-coded to 0.
// NOTE(review): presumably the cbsz/abid modifiers -- confirm against the builtin docs.
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
|
||||
#else
|
||||
// Fallback for other targets: each argument is assigned to `ignore`
// (presumably to silence unused-parameter diagnostics -- confirm).
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
ignore = reg_idx;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user