mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
Add grouped conv bwd weight wmma (#985)
* Add grouped conv bwd weight wmma
* Update README, changelog, profiler
* Minor fixes
* Fix grouped conv bwd wei dl kernel
* Minor fixes
* Minor stylistic fixes
[ROCm/composable_kernel commit: 16d7c4d2f7]
This commit is contained in:
@@ -20,10 +20,11 @@ enum struct ConvLayout
|
||||
|
||||
enum struct ConvDataType
|
||||
{
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_F32_BF16, // 2
|
||||
F16_F16_F16_BF8_F8 // 3
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_F32_BF16, // 2
|
||||
F16_F16_F16_BF8_F8, // 3
|
||||
I8_I8_I8 // 4
|
||||
};
|
||||
|
||||
#define OP_NAME "grouped_conv_bwd_weight"
|
||||
@@ -35,7 +36,8 @@ static void print_helper_msg()
|
||||
<< "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n"
|
||||
<< " 1: Input fp16, Weight fp16, Output fp16\n"
|
||||
<< " 2: Input bf16, Weight fp32, Output bf16\n"
|
||||
<< " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8)\n"
|
||||
<< " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8\n"
|
||||
<< " 4: Input int8, Weight int8, Output int8)\n"
|
||||
<< "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, "
|
||||
"N, K, Ho, Wo]\n"
|
||||
<< " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
|
||||
@@ -196,6 +198,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::I8_I8_I8)
|
||||
{
|
||||
return profile(
|
||||
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
@@ -216,6 +223,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::I8_I8_I8)
|
||||
{
|
||||
return profile(
|
||||
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
Reference in New Issue
Block a user