feature:tf32:add initial conv3d fwd kernel support (#2763)

This commit is contained in:
lym
2025-09-15 21:03:00 +08:00
committed by GitHub
parent e5d73da2da
commit c51102144f
44 changed files with 1085 additions and 175 deletions

View File

@@ -119,7 +119,9 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer>;
PipelineVer,
ComputeDataType>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
@@ -214,6 +216,14 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
return false;
}
if constexpr(is_same_v<ComputeDataType, ck::tf32_t>)
{
if(!is_tf32_supported())
{
return false;
}
}
// Check vector load/store.
{
using Row = ck::tensor_layout::gemm::RowMajor;

View File

@@ -1003,11 +1003,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
void Print() const
{
std::cout << "AComputeDataType: " << get_type_name<AComputeDataType>()
<< "; BComputeDataType: " << get_type_name<BComputeDataType>()
<< "; EDataType: " << get_type_name<EDataType>() << std::endl;
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
std::cout << "a grid desc" << a_grid_desc_ak0_m_ak1_ << std::endl;
std::cout << "b grid desc" << b_grid_desc_bk0_n_bk1_ << std::endl;
std::cout << "e grid desc" << e_grid_desc_mblock_mperblock_nblock_nperblock_
<< std::endl;
}
// private:
@@ -1198,7 +1207,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
isMultiA,
isMultiB,
CTranspose>;
return launch_and_time_kernel(
stream_config,
kernel,
@@ -1281,7 +1289,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float avg_time = 0.f;
if constexpr(NeedTransposeKernel)
{
const index_t a_grid_size =
@@ -1686,7 +1693,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
return false;
}
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
is_same_v<BComputeDataType, ck::tf32_t>)
{
if(!is_tf32_supported())
{
return false;
}
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "ComputeDataType for A and B should be same while using TF32"
<< std::endl;
}
return false;
}
}
// check Gridwise GEMM
if(get_warp_size() == 64)
{
@@ -1766,6 +1789,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
is_same_v<BComputeDataType, ck::tf32_t>)
{
if(!(ck::get_device_name() == "gfx942"))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "TF32 is enabled on gfx942 only" << std::endl;
}
return false;
}
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "ComputeDataType for A and B should be same while using TF32"
<< std::endl;
}
return false;
}
}
return false;
}