mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-26 16:04:58 +00:00
Simulate TF32 with BF16x3 (#3142)
* tf32:bf16x3:use bf16x3 emulate tf32 gemm * change blockwiseGemm to demo bf16x3 * temp push * self review * self review * fix multi-device compile error * bug fix * code refactor * limit to gfx950 * enhance gemm gfx942 threshold * lower change from blockwise to warpwise * refact codes * refact codes * error fix * change threshold * bug fix * fix threshold error * change host reference implement to same as device * bug fix * bug fix * code refact * fix clang-format fail * code refine
This commit is contained in:
@@ -105,7 +105,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
using INT8 = int8_t;
|
||||
using F8 = ck::f8_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
#if defined(__gfx942__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
using TF32 = ck::tf32_t;
|
||||
#endif
|
||||
|
||||
@@ -228,7 +228,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
@@ -253,7 +253,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
@@ -280,7 +280,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
@@ -306,7 +306,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
@@ -331,7 +331,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
@@ -352,7 +352,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
@@ -373,7 +373,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
@@ -416,7 +416,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
@@ -439,7 +439,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user