mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
feature:tf32:add initial conv3d fwd kernel support (#2763)
This commit is contained in:
@@ -119,7 +119,9 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
PipelineVer,
|
||||
ComputeDataType>;
|
||||
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
|
||||
@@ -214,6 +216,14 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<ComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check vector load/store.
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
@@ -1003,11 +1003,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
void Print() const
|
||||
{
|
||||
std::cout << "AComputeDataType: " << get_type_name<AComputeDataType>()
|
||||
<< "; BComputeDataType: " << get_type_name<BComputeDataType>()
|
||||
<< "; EDataType: " << get_type_name<EDataType>() << std::endl;
|
||||
|
||||
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
|
||||
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
|
||||
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
|
||||
|
||||
std::cout << "a grid desc" << a_grid_desc_ak0_m_ak1_ << std::endl;
|
||||
std::cout << "b grid desc" << b_grid_desc_bk0_n_bk1_ << std::endl;
|
||||
std::cout << "e grid desc" << e_grid_desc_mblock_mperblock_nblock_nperblock_
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// private:
|
||||
@@ -1198,7 +1207,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
isMultiA,
|
||||
isMultiB,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
@@ -1281,7 +1289,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float avg_time = 0.f;
|
||||
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const index_t a_grid_size =
|
||||
@@ -1686,7 +1693,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
|
||||
is_same_v<BComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// check Gridwise GEMM
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
@@ -1766,6 +1789,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
|
||||
is_same_v<BComputeDataType, ck::tf32_t>)
|
||||
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx942"))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "TF32 is enabled on gfx942 only" << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user