Merge remote-tracking branch 'origin/lwpck-3447' into mmflat

This commit is contained in:
AviralGoelAMD
2025-07-16 21:32:40 +00:00
23 changed files with 90 additions and 79 deletions

View File

@@ -403,10 +403,10 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
return (pass);
};
static const double epsilon = std::numeric_limits<float>::epsilon();
int main(int argc, char* argv[])
{
static const double epsilon = std::numeric_limits<float>::epsilon();
bool pass = true;
if(argc > 1)

View File

@@ -314,11 +314,10 @@ bool bnorm_infer_nhwc_test(bool do_verification,
return (pass);
};
static const double epsilon = std::numeric_limits<float>::epsilon();
int main(int argc, char* argv[])
{
bool pass = true;
static const double epsilon = std::numeric_limits<float>::epsilon();
bool pass = true;
if(argc > 1)
{

View File

@@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
return (pass);
};
const double epsilon = std::numeric_limits<float>::epsilon();
static const double averageFactor = 0.1;
int main(int argc, char* argv[])
{
bool pass = true;
const double epsilon = std::numeric_limits<float>::epsilon();
static const double averageFactor = 0.1;
bool pass = true;
if(argc > 1)
{

View File

@@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
return (pass);
};
const double epsilon = std::numeric_limits<float>::epsilon();
static const double averageFactor = 0.1;
int main(int argc, char* argv[])
{
bool pass = true;
const double epsilon = std::numeric_limits<float>::epsilon();
static const double averageFactor = 0.1;
bool pass = true;
if(argc > 1)
{

View File

@@ -128,6 +128,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt)
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} )
add_dependencies(examples ${EXAMPLE_NAME})

View File

@@ -243,8 +243,8 @@ struct GemmConfigPreshufle_1 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
@@ -265,8 +265,8 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;

View File

@@ -220,7 +220,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
auto [result, arg_parser] = create_args(argc, argv);
bool preshuffle = GemmConfig::Preshuffle;
if(preshuffle && a_layout != "R" && b_layout != "C")
if(preshuffle && (a_layout != "R" || b_layout != "C"))
{
throw std::runtime_error(
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");

View File

@@ -315,8 +315,16 @@ int run_gemm_example_with_layouts(int argc,
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n);
if constexpr(preshuffle)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
}
}
else if(init_method == 1)
{

View File

@@ -18,7 +18,7 @@ constexpr const char* DataTypeToString()
{
return "bf8";
}
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
return "bf16";
}