Simulate TF32 with BF16x3 (#3142)

* tf32:bf16x3: use bf16x3 to emulate tf32 gemm

* change blockwiseGemm to demo bf16x3

* temp push

* self review

* self review

* fix multi-device compile error

* bug fix

* code refactor

* limit to gfx950

* enhance gemm gfx942 threshold

* lower change from blockwise to warpwise

* refactor code

* refactor code

* error fix

* change threshold

* bug fix

* fix threshold error

* change host reference implementation to match the device

* bug fix

* bug fix

* code refactor

* fix clang-format fail

* code refine
This commit is contained in:
yinglu
2025-11-14 08:21:09 +08:00
committed by GitHub
parent f2cfc6b94e
commit 2a73eb3bc0
16 changed files with 419 additions and 49 deletions

View File

@@ -105,7 +105,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
using INT8 = int8_t;
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
#if defined(__gfx942__)
#if defined(__gfx942__) || defined(__gfx950__)
using TF32 = ck::tf32_t;
#endif
@@ -228,7 +228,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
@@ -253,7 +253,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
@@ -280,7 +280,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
@@ -306,7 +306,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
@@ -331,7 +331,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
@@ -352,7 +352,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
@@ -373,7 +373,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
@@ -416,7 +416,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
@@ -439,7 +439,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}