mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
feature:tf32:add initial conv3d fwd kernel support (#2763)
This commit is contained in:
@@ -21,14 +21,15 @@ enum struct ConvLayout
|
||||
|
||||
enum struct ConvDataType
|
||||
{
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_BF16_BF16, // 2
|
||||
INT8_INT8_INT8, // 3
|
||||
F8_F8_F8, // 4
|
||||
BF8_BF8_F8, // 5
|
||||
F8_BF8_F8, // 6
|
||||
BF8_F8_F8, // 7
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_BF16_BF16, // 2
|
||||
INT8_INT8_INT8, // 3
|
||||
F8_F8_F8, // 4
|
||||
BF8_BF8_F8, // 5
|
||||
F8_BF8_F8, // 6
|
||||
BF8_F8_F8, // 7
|
||||
F32_F32_F32_TF32, // 8
|
||||
};
|
||||
|
||||
enum struct IndexType
|
||||
@@ -53,6 +54,7 @@ static void print_helper_msg()
|
||||
<< " 5: Input bf8, Weight bf8, Output fp8\n"
|
||||
<< " 6: Input fp8, Weight bf8, Output fp8\n"
|
||||
<< " 7: Input bf8, Weight fp8, Output fp8)\n"
|
||||
<< " 8: Input fp32, Weight fp32, Output fp32, Compute tf32\n"
|
||||
<< "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
|
||||
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n"
|
||||
<< " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
|
||||
@@ -103,6 +105,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
using INT8 = int8_t;
|
||||
using F8 = ck::f8_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
using TF32 = ck::tf32_t;
|
||||
|
||||
//
|
||||
using GNWC = ck::tensor_layout::convolution::GNWC;
|
||||
@@ -261,6 +264,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
return profile(
|
||||
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
// NHWGC_GKYXC_NHWGK
|
||||
else if(num_dim_spatial == 1 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
@@ -367,6 +374,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, BF8{}, F8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
// NGCDHW_GKCZYX_NGKDHW
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
|
||||
@@ -384,6 +395,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
return profile(
|
||||
I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
Reference in New Issue
Block a user