Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-16 10:59:55 +00:00
Add Gemm instances for performance improvement (#1018)
* improve kpad
* more tuning parameters
* f16_f8_fp16
* cut test time
* add f16_f8_fp16
* add f16_f8_f16
* testing instances for skinny cases
* format
* clean
* add fp16_f8_fp16
* clang-format
* add grouped gemm instances
* fixed profile grouped_gemm
* clean
* clean
* clean
* clean
* clean
* add missing instance func
* fixed interface
---------
Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: root <root@sh5-1e707-rc06-38.mkm.dcgpu>
[ROCm/composable_kernel commit: 98fd41f597]
@@ -278,6 +278,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
         // clang-format off
         str << "DeviceGemm_Xdl_CShuffle"
             << "<"
+            << getGemmSpecializationString(GemmSpec) << ", "
             << BlockSize << ", "
             << MPerBlock << ", "
             << NPerBlock << ", "
@@ -296,7 +297,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
             << " LoopScheduler: "
             << LoopSchedToString[LoopSched] << ", "
             << "PipelineVersion: "
-            << PipelineVersionToString[PipelineVer];;
+            << PipelineVersionToString[PipelineVer];
         // clang-format on

         return str.str();
@@ -59,7 +59,8 @@ template <typename ADataType,
           typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
           typename ComputeType = CDataType,
-          PipelineVersion PipelineVer = PipelineVersion::v1>
+          PipelineVersion PipelineVer = PipelineVersion::v1,
+          LoopScheduler LoopSched = make_default_loop_scheduler()>

 struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                                                              BLayout,
@@ -79,7 +80,6 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,

     // TODO: should be exposed as Tparams.
     static constexpr index_t NumGemmKPrefetchStage = 1;
-    static constexpr LoopScheduler LoopSched = make_default_loop_scheduler();

     using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
         BlockSize,
@@ -141,7 +141,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                  index_t MPadded_,
                  index_t NPadded_,
                  index_t KPadded_,
-                 index_t K0_,
+                 index_t K0Padded_,
                  index_t k_batch_,
                  AElementwiseOperation a_element_op_,
                  BElementwiseOperation b_element_op_,
@@ -158,7 +158,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                        MPadded_,
                        NPadded_,
                        KPadded_,
-                       K0_,
+                       K0Padded_,
                        k_batch_),
              a_element_op(a_element_op_),
              b_element_op(b_element_op_),
@@ -198,9 +198,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
             const auto b2c_map = DefaultBlock2CTileMap{};
             index_t gdx, gdy, gdz;
             std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
-            const auto K0 = karg.K0;
+            const auto K0Padded = karg.K0Padded;

-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded);

             float ave_time = 0;
@@ -342,7 +342,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                         GridwiseGemm::CalculateMPadded(M),
                         GridwiseGemm::CalculateNPadded(N),
                         GridwiseGemm::CalculateKPadded(K, KBatch),
-                        GridwiseGemm::CalculateK0(K, KBatch),
+                        GridwiseGemm::CalculateK0Padded(K, KBatch),
                         KBatch,
                         a_element_op,
                         b_element_op,
@@ -378,7 +378,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                         GridwiseGemm::CalculateMPadded(M),
                         GridwiseGemm::CalculateNPadded(N),
                         GridwiseGemm::CalculateKPadded(K, KBatch),
-                        GridwiseGemm::CalculateK0(K, KBatch),
+                        GridwiseGemm::CalculateK0Padded(K, KBatch),
                         KBatch,
                         a_element_op,
                         b_element_op,
@@ -392,7 +392,21 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
     }

     // polymorphic
-    std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); }
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        std::map<LoopScheduler, std::string> LoopSchedToString{
+            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
+
+        std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
+                                                                       {PipelineVersion::v2, "v2"}};
+
+        str << GridwiseGemm::GetTypeString() << " LoopScheduler: " << LoopSchedToString[LoopSched]
+            << ", PipelineVersion: " << PipelineVersionToString[PipelineVer];
+
+        return str.str();
+    }
 };

 } // namespace device
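The hunks above move the loop scheduler from a hard-coded member constant into a template parameter with the same default. A minimal standalone sketch of that refactor pattern (not CK code; names are abbreviated stand-ins):

    enum class LoopScheduler { Default, Interwave };

    constexpr LoopScheduler make_default_loop_scheduler() { return LoopScheduler::Default; }

    // Before: the scheduler was fixed inside the struct:
    //   static constexpr LoopScheduler LoopSched = make_default_loop_scheduler();
    // After: it is a template parameter, so a tuned instance can select
    // Interwave without touching the kernel body.
    template <LoopScheduler LoopSched = make_default_loop_scheduler()>
    struct KernelSketch
    {
        static constexpr bool uses_interwave = (LoopSched == LoopScheduler::Interwave);
    };

    static_assert(!KernelSketch<>::uses_interwave);                        // default instance
    static_assert(KernelSketch<LoopScheduler::Interwave>::uses_interwave); // tuned instance

    int main() {}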
@@ -265,10 +265,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
             const index_t stride_b = gemm_descs[i].stride_B_;
             const index_t stride_c = gemm_descs[i].stride_C_;

             const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
             const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
             const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
-            const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH);
+            const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH);

             const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
@@ -297,7 +297,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
                                      m_padded,
                                      n_padded,
                                      k_padded,
-                                     k0,
+                                     k0_padded,
                                      K_BATCH};

             gemm_kernel_args_.emplace_back(
@@ -320,8 +320,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,

             auto& karg = gemm_kernel_args_[i].karg_;

             const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
-            const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
+            const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH);

             const auto c_grid_desc_m_n =
                 GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
@@ -340,7 +340,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
                 GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);

             karg.KPadded = k_padded;
-            karg.K0 = k0;
+            karg.K0Padded = k0_padded;
             karg.k_batch = K_BATCH;
             gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
             gemm_kernel_args_[i].block_start_ = block_start;
@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
        {
            float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
            {
-                index_t K0 = arg.gemm_kernel_args_[0].karg_.K0;
+                index_t K0 = arg.gemm_kernel_args_[0].karg_.K0Padded;
                 bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
                 bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
@@ -384,7 +384,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
                     throw std::runtime_error(err.str());
                 }

-                K0 = karg.K0;
+                K0 = karg.K0Padded;
                 bool not_all_have_main_k0_block_loop_same =
                     all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
                 bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
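The grouped-gemm Run() above requires every group to agree on whether a main K0 loop exists and on k_batch > 1. A standalone sketch of that consistency check (not CK code; has_main_k0_block_loop models GridwiseGemmPipe::CalculateHasMainLoop, assumed here to mean "more than one K0PerBlock chunk"):

    #include <stdexcept>
    #include <utility>
    #include <vector>

    // Models GridwiseGemm::CalculateHasMainK0BlockLoop (assumption: a main
    // loop exists when the per-batch K0 spans more than one K0PerBlock chunk).
    bool has_main_k0_block_loop(int k0_padded, int k0_per_block)
    {
        return (k0_padded / k0_per_block) > 1;
    }

    // Every group's (K0Padded, k_batch) must give the same answers as group 0.
    void validate_groups(const std::vector<std::pair<int, int>>& groups, int k0_per_block)
    {
        const bool main0 = has_main_k0_block_loop(groups.front().first, k0_per_block);
        const bool kb0   = groups.front().second > 1;
        for(const auto& [k0_padded, k_batch] : groups)
        {
            if(has_main_k0_block_loop(k0_padded, k0_per_block) != main0 || (k_batch > 1) != kb0)
                throw std::runtime_error("grouped gemm: inconsistent k_batch / main-loop settings");
        }
    }

    int main()
    {
        validate_groups({{16, 2}, {24, 2}}, 4); // ok: both groups have a main loop and split K
        // validate_groups({{16, 2}, {4, 2}}, 4); // would throw: second group has no main loop
    }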
@@ -16,6 +16,57 @@ namespace element_wise {
 extern "C" __device__ float __ocml_native_recip_f32(float);
 #endif

+struct PassThroughPack2
+{
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const;
+
+    __host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::half2_t& x) const
+    {
+        // fake conversion
+        uint16_t t = ck::bit_cast<uint32_t>(x);
+        y = ck::bit_cast<ck::f8x2_t>(t);
+    }
+
+    __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const
+    {
+        auto t = type_convert<float2_t>(x);
+        y = type_convert<half2_t>(t);
+    }
+
+    __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::half2_t& x) const
+    {
+        y = x;
+    }
+
+    __host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::f8x2_t& x) const
+    {
+        y = x;
+    }
+
+    __host__ __device__ constexpr void operator()(ck::float2_t& y, const ck::float2_t& x) const
+    {
+        y = x;
+    }
+
+    __host__ __device__ constexpr void operator()(ck::int8x2_t& y, const ck::int8x2_t& x) const
+    {
+        y = x;
+    }
+
+    __host__ __device__ constexpr void operator()(ck::bhalf2_t& y, const ck::bhalf2_t& x) const
+    {
+        y = x;
+    }
+
+    __host__ __device__ constexpr void operator()(ck::double2_t& y, const ck::double2_t& x) const
+    {
+        y = x;
+    }
+
+    constexpr const static bool is_pack2_invocable = true;
+};
+
 struct PassThrough
 {
     template <typename Y, typename X>
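PassThroughPack2 lets an elementwise pass move data two elements per call; note the fp16-to-fp8 overload is explicitly a placeholder ("fake conversion"). A standalone sketch of the pack-2 calling convention (not CK code; Half2 stands in for ck::half2_t):

    #include <array>
    #include <cstdio>

    struct Half2 { float lo, hi; }; // stand-in for a packed pair of fp16 values

    struct PassThroughPack2Sketch
    {
        // One call copies a whole packed pair, halving the number of functor calls.
        void operator()(Half2& y, const Half2& x) const { y = x; }
        static constexpr bool is_pack2_invocable = true; // mirrors the flag in the diff
    };

    int main()
    {
        std::array<Half2, 2> src{{{1.0f, 2.0f}, {3.0f, 4.0f}}};
        std::array<Half2, 2> dst{};

        PassThroughPack2Sketch op;
        for(std::size_t i = 0; i < src.size(); ++i)
            op(dst[i], src[i]); // 2 packed calls instead of 4 scalar calls

        std::printf("%g %g %g %g\n", dst[0].lo, dst[0].hi, dst[1].lo, dst[1].hi);
    }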
@@ -136,7 +136,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         index_t MPadded;
         index_t NPadded;
         index_t KPadded;
-        index_t K0;
+        index_t K0Padded;
         index_t k_batch;

         Argument(const FloatA* p_a_grid_,
@@ -151,7 +151,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                  index_t MPadded_,
                  index_t NPadded_,
                  index_t KPadded_,
-                 index_t K0_,
+                 index_t K0Padded_,
                  index_t k_batch_)
             : p_a_grid(p_a_grid_),
               p_b_grid(p_b_grid_),
@@ -165,7 +165,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
               MPadded(MPadded_),
               NPadded(NPadded_),
               KPadded(KPadded_),
-              K0(K0_),
+              K0Padded(K0Padded_),
               k_batch(k_batch_)
         {
         }
@@ -182,7 +182,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                       << "MP:" << MPadded << ", "
                       << "NP:" << NPadded << ", "
                       << "KP:" << KPadded << ", "
-                      << "K0:" << K0 << ", "
+                      << "K0Padded:" << K0Padded << ", "
                       << "KB:" << k_batch << "}" << std::endl;
         }
     };
@@ -205,7 +205,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         return math::integer_least_multiple(N, NPerBlock);
     }

-    __host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
+    __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1)
     {
         // k_batch * k0 * k0_per_block * k1
         auto K_t = K_Batch * K0PerBlock * K1;
@@ -214,8 +214,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2

     __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
     {
-        auto K0 = CalculateK0(K, K_Batch);
-        return K_Batch * K0 * K1;
+        auto K0Padded = CalculateK0Padded(K, K_Batch);
+        return K_Batch * K0Padded * K1;
     }

     __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M,
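The two helpers above define the split-K padding arithmetic: K0Padded rounds the per-batch K0 up to a whole number of K0PerBlock chunks, and KPadded is the padded K that the grid descriptors use (GetKPad further below spells out the same formula). A standalone sketch with example numbers (not CK code; the K1/K0PerBlock values are hypothetical tuning parameters):

    #include <cstdio>

    int main()
    {
        const int K1 = 8, K0PerBlock = 4; // hypothetical tuning parameters
        const int K = 1000, KBatch = 3;   // example problem size and split-K factor

        // K0Padded = ceil(K / (K1 * K0PerBlock * KBatch)) * K0PerBlock
        const int K_t      = KBatch * K0PerBlock * K1;
        const int K0Padded = (K + K_t - 1) / K_t * K0PerBlock;

        // KPadded = KBatch * K0Padded * K1, a multiple of KBatch * K0PerBlock * K1,
        // so each of the KBatch slices iterates whole K0PerBlock chunks.
        const int KPadded = KBatch * K0Padded * K1;

        std::printf("K0Padded = %d, KPadded = %d\n", K0Padded, KPadded); // 44, 1056
    }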
@@ -223,7 +223,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                                                                         index_t K,
                                                                         index_t StrideA,
                                                                         index_t KBatch,
-                                                                        index_t K0,
+                                                                        index_t K0Padded,
                                                                         index_t KPad)
     {
         const auto a_grid_desc_m_k = [&]() {
@@ -237,21 +237,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             }
         }();

-        const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
-            a_grid_desc_m_k,
-            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
         if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                      GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                      GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
                      GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
         {
+            const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
+                a_grid_desc_m_k,
+                make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
             // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             return transform_tensor_descriptor(
                 a_grid_desc_m_kpad,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
+                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
                            make_right_pad_transform(M, MPad - M)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
         }
+        else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
+                          GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
+        {
+            // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
+            return transform_tensor_descriptor(
+                a_grid_desc_m_k,
+                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
+                           make_right_pad_transform(M, MPad - M)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -259,8 +271,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         else
         {
             return transform_tensor_descriptor(
-                a_grid_desc_m_kpad,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
+                a_grid_desc_m_k,
+                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
                            make_pass_through_transform(M)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -272,7 +284,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                                                                         index_t N,
                                                                         index_t StrideB,
                                                                         index_t KBatch,
-                                                                        index_t K0,
+                                                                        index_t K0Padded,
                                                                         index_t KPad)
     {
         const auto b_grid_desc_k_n = [&]() {
@@ -286,21 +298,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             }
         }();

-        const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
-            b_grid_desc_k_n,
-            make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
         if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                      GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                      GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
                      GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
         {
+            const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
+                b_grid_desc_k_n,
+                make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
             // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
             return transform_tensor_descriptor(
                 b_grid_desc_kpad_n,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
+                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
                            make_right_pad_transform(N, NPad - N)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
         }
+        else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
+                          GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
+        {
+            // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
+            return transform_tensor_descriptor(
+                b_grid_desc_k_n,
+                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
+                           make_right_pad_transform(N, NPad - N)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -308,8 +332,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         else
         {
             return transform_tensor_descriptor(
-                b_grid_desc_kpad_n,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
+                b_grid_desc_k_n,
+                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
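The unmerge transform in these descriptor hunks views the (possibly padded) K axis as a KBatch x K0Padded x K1 box, which is what hands each split-K batch a contiguous slice of K. A toy sketch of that index mapping (not CK code):

    #include <cstdio>

    int main()
    {
        const int KBatch = 2, K0Padded = 3, K1 = 4; // toy sizes

        // Unmerge: flat k -> (kb, k0, k1) with k = (kb * K0Padded + k0) * K1 + k1,
        // so batch kb owns the contiguous range [kb * K0Padded * K1, (kb + 1) * K0Padded * K1).
        for(int k = 0; k < KBatch * K0Padded * K1; ++k)
        {
            const int kb = k / (K0Padded * K1); // split-K batch index
            const int k0 = (k / K1) % K0Padded; // K0 position inside the batch
            const int k1 = k % K1;              // innermost (vector) index
            std::printf("k=%2d -> (kb=%d, k0=%d, k1=%d)\n", k, kb, k0, k1);
        }
    }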
@@ -398,6 +422,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                 return false;
             }
         }
+
         if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                        GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                        GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
@@ -410,6 +435,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                           << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                           << std::endl;
 #endif // DEBUG_LOG
                 return false;
             }
         }

+        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+        {
+            auto K_t = karg.k_batch * K0PerBlock * K1;
+            if(!(karg.K % K_t == 0))
+            {
+#if DEBUG_LOG
+                std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
+                          << karg.K << " " << __FILE__ << ":" << __LINE__
+                          << ", in function: " << __func__ << std::endl;
+#endif // DEBUG_LOG
+                return false;
+            }
+        }
@@ -478,11 +522,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
             {
 #if DEBUG_LOG
-                std::cout
-                    << "Arg N (" << karg.N
-                    << ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL ("
-                    << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
-                    << __LINE__ << ", in function: " << __func__ << std::endl;
+                std::cout << "Arg N (" << karg.N
+                          << ") value is not a multiple of "
+                             "CBlockTransferScalarPerVector_NWaveNPerXDL ("
+                          << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
+                          << __LINE__ << ", in function: " << __func__ << std::endl;

 #endif // DEBUG_LOG
                 return false;
@@ -493,25 +537,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
             {
 #if DEBUG_LOG
-                std::cout
-                    << "Arg M (" << karg.M
-                    << ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL ("
-                    << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
-                    << __LINE__ << ", in function: " << __func__ << std::endl;
+                std::cout << "Arg M (" << karg.M
+                          << ") value is not a multiple of "
+                             "CBlockTransferScalarPerVector_NWaveNPerXDL ("
+                          << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
+                          << __LINE__ << ", in function: " << __func__ << std::endl;

 #endif // DEBUG_LOG
                 return false;
             }
         }

-        const auto num_k_loop = karg.K0 / K0PerBlock;
+        const auto num_k_loop = karg.K0Padded / K0PerBlock;
         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
 #if DEBUG_LOG
             std::cout << "The number of k loops (" << num_k_loop
                       << ") value is not supported by GridwiseGemm Pipeline."
-                      << " K0: " << karg.K0 << ", K0PerBlock: " << K0PerBlock << " " << __FILE__
-                      << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
+                      << " K0Padded: " << karg.K0Padded << ", K0PerBlock: " << K0PerBlock << " "
+                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
 #endif // DEBUG_LOG
             return false;
         }
@@ -521,14 +565,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2

     __host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
     {
-        const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
-        const index_t KPad = KBatch * K0 * K1;
+        const index_t K0Padded =
+            math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
+        const index_t KPad = KBatch * K0Padded * K1;
         return KPad;
     }

-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded)
     {
-        const index_t num_loop = K0 / K0PerBlock;
+        const index_t num_loop = K0Padded / K0PerBlock;
         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }
@@ -595,9 +640,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         const FloatB* p_b_grid = karg.p_b_grid;
         FloatC* p_c_grid = karg.p_c_grid;
         const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
-            karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded);
+            karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded);
         const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
-            karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded);
+            karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded);
         const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);

         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
@@ -140,6 +140,36 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
 #endif
 }

+template <>
+inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
+{
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    const auto i16val = bit_cast<uint16_t>(x);
+    return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
+#else
+    constexpr bool negative_zero_nan = true;
+    const auto f8x2_v = vector_type<f8_t, 2>(x);
+    vector_type<float, 2> f32x2_v;
+    f32x2_v.template AsType<float>()(Number<0>{}) =
+        utils::cast_from_f8<f8_t, float, negative_zero_nan>(
+            f8x2_v.template AsType<f8_t>()[Number<0>{}]);
+    f32x2_v.template AsType<float>()(Number<1>{}) =
+        utils::cast_from_f8<f8_t, float, negative_zero_nan>(
+            f8x2_v.template AsType<f8_t>()[Number<1>{}]);
+    return f32x2_v.template AsType<float2_t>()[Number<0>{}];
+#endif
+}
+
+template <>
+inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
+{
+    const vector_type<float, 2> f32x2_v(x);
+    const auto y = __builtin_amdgcn_cvt_pkrtz(f32x2_v.template AsType<float>()[Number<0>{}],
+                                              f32x2_v.template AsType<float>()[Number<1>{}]);
+    return bit_cast<half2_t>(y);
+}
+
+// convert fp16 to fp8
+template <>
+inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
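On gfx94x the new specialization unpacks two fp8 values with a single __builtin_amdgcn_cvt_pk_f32_fp8; elsewhere it decodes lane by lane via cast_from_f8. A standalone sketch of that per-lane decode (not CK code; it assumes the common e4m3fn layout, whereas CK's f8_t with negative_zero_nan uses a slightly different NaN convention):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Decode one byte as fp8 E4M3 (bias 7), roughly what cast_from_f8 does per
    // lane in the fallback branch above. NaN/denormal conventions differ
    // between fp8 flavors; this sketch assumes e4m3fn.
    float decode_e4m3(uint8_t v)
    {
        const int sign = v >> 7;
        const int exp  = (v >> 3) & 0xF;
        const int man  = v & 0x7;
        float mag;
        if(exp == 0)
            mag = std::ldexp(man / 8.0f, -6); // denormal: man * 2^-9
        else if(exp == 15 && man == 7)
            return NAN;                       // e4m3fn NaN encoding
        else
            mag = std::ldexp(1.0f + man / 8.0f, exp - 7);
        return sign ? -mag : mag;
    }

    int main()
    {
        const uint16_t packed = 0x4838; // two fp8 lanes: 0x38 (= 1.0) and 0x48 (= 4.0)
        const float lo = decode_e4m3(packed & 0xFF);
        const float hi = decode_e4m3(packed >> 8);
        std::printf("%g %g\n", lo, hi); // 1 4
    }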