Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-12 09:16:52 +00:00)
Enable grouped conv bwd wei bf16 NGCHW (#1589)
* Enable grouped conv bwd wei bf16 NGCHW
* fixes
* fixes
* Fixes
* fixes
* fixes
* Fixes
@@ -25,7 +25,8 @@ enum struct ConvDataType
     F16_F16_F16, // 1
     BF16_F32_BF16, // 2
     F16_F16_F16_BF8_F8, // 3
-    I8_I8_I8 // 4
+    I8_I8_I8, // 4
+    BF16_BF16_BF16, // 5
 };
 
 #define OP_NAME "grouped_conv_bwd_weight"
@@ -38,7 +39,8 @@ static void print_helper_msg()
         << " 1: Input fp16, Weight fp16, Output fp16\n"
         << " 2: Input bf16, Weight fp32, Output bf16\n"
         << " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8\n"
-        << " 4: Input int8, Weight int8, Output int8)\n"
+        << " 4: Input int8, Weight int8, Output int8\n"
+        << " 5: Input bf16, Weight bf16, Output bf16)\n"
         << "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, "
            "N, K, Ho, Wo]\n"
         << " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
@@ -187,6 +189,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
         {
             return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
         }
+        if(data_type == ConvDataType::BF16_BF16_BF16)
+        {
+            // fp32 atomic add is used for weight tensor in bf16 kernel
+            return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
+        }
     }
     if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
     {
@@ -203,6 +210,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
             // fp32 atomic add is used for weight tensor in bf16 kernel
             return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
         }
+        if(data_type == ConvDataType::BF16_BF16_BF16)
+        {
+            return profile(
+                I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
+        }
         else if(data_type == ConvDataType::I8_I8_I8)
         {
             return profile(
@@ -224,6 +236,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
             // fp32 atomic add is used for weight tensor in bf16 kernel
             return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
         }
+        if(data_type == ConvDataType::BF16_BF16_BF16)
+        {
+            return profile(
+                I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
+        }
         if(data_type == ConvDataType::F16_F16_F16_BF8_F8)
         {
             return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{});
@@ -240,6 +257,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
         {
             return profile(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
         }
+        if(data_type == ConvDataType::BF16_BF16_BF16)
+        {
+            return profile(
+                I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
+        }
     }
 
     std::cout << "this data_type & layout is not implemented" << std::endl;
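
The profile(...) calls above select a kernel instantiation purely from tag types: empty structs name the tensor layouts (NGCHW, GKYXC, NGKHW, ...) and element types, and the integral constant I2/I3 carries the number of spatial dimensions. Below is a minimal, self-contained C++17 sketch of that tag-dispatch pattern. It is an illustration under simplifying assumptions, not the profiler's code: the tag definitions and the name members are invented here, and only three data-type tags are kept where the diff passes five (the last two select the GEMM compute types, as the F16/BF8/F8 call shows).

// Minimal sketch of the tag-dispatch pattern (assumed, simplified
// signatures; the real profile() lives in the profiler sources).
#include <iostream>
#include <type_traits>

// Layout tags: names mirror the diff, definitions are illustrative.
struct NGCHW {};
struct GKYXC {};
struct NGKHW {};

// Data-type tag with a printable name (illustrative).
struct BF16 { static constexpr const char* name = "bf16"; };

template <int NDimSpatial,
          typename InLayout, typename WeiLayout, typename OutLayout,
          typename InT, typename WeiT, typename OutT>
int profile(std::integral_constant<int, NDimSpatial>,
            InLayout, WeiLayout, OutLayout, InT, WeiT, OutT)
{
    // The real profiler uses these tags to pick matching pre-built
    // device kernel instances; here we only print the selection.
    std::cout << NDimSpatial << "D grouped conv bwd-weight, "
              << InT::name << "/" << WeiT::name << "/" << OutT::name << '\n';
    return 0;
}

int main()
{
    constexpr std::integral_constant<int, 2> I2{};
    // Mirrors the new 2D NGCHW bf16 dispatch added by this commit.
    return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{});
}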
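
The recurring comment "// fp32 atomic add is used for weight tensor in bf16 kernel" marks the precision choice behind these new BF16_BF16_BF16 paths: backward-weight reduces many partial products into each weight-gradient element (the kernel does this with fp32 atomic adds on the weight buffer), and accumulating those partials directly in bf16 would round most of them away. The host-side sketch below demonstrates only that numeric effect; it is not kernel code, and round_to_bf16 is an illustrative stand-in for hardware bf16 rounding.

// Host-side demo: accumulating many small partial sums in fp32 and
// converting to bf16 once at the end (what the kernel's fp32 atomic
// adds achieve) versus rounding to bf16 after every addition.
// Illustrative only; round_to_bf16 is an assumption, not CK code.
#include <cstdint>
#include <cstring>
#include <iostream>

// Round a float to the nearest bf16 value (round-to-nearest-even),
// returned as a float for convenience.
float round_to_bf16(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    const std::uint32_t lsb = (bits >> 16) & 1u; // for ties-to-even
    bits = (bits + 0x7FFFu + lsb) & 0xFFFF0000u; // round, then truncate
    float y;
    std::memcpy(&y, &bits, sizeof(y));
    return y;
}

int main()
{
    constexpr int num_partials = 4096;   // e.g. split-K contributions
    constexpr float partial    = 0.001f; // one work-group's partial sum

    // fp32 accumulation, single final conversion to bf16.
    float acc_fp32 = 0.0f;
    for(int i = 0; i < num_partials; ++i)
        acc_fp32 += partial;

    // bf16 accumulation: every add rounds; once the running sum is big
    // enough, each small partial rounds away and the sum stalls.
    float acc_bf16 = 0.0f;
    for(int i = 0; i < num_partials; ++i)
        acc_bf16 = round_to_bf16(acc_bf16 + partial);

    std::cout << "exact sum:           " << num_partials * partial << '\n'
              << "fp32 acc -> bf16:    " << round_to_bf16(acc_fp32) << '\n'
              << "bf16 acc every step: " << acc_bf16 << '\n';
}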