mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Merge remote-tracking branch 'origin/lwpck-3447' into mmflat
This commit is contained in:
@@ -403,10 +403,10 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
|
||||
return (pass);
|
||||
};
|
||||
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
|
||||
@@ -314,11 +314,10 @@ bool bnorm_infer_nhwc_test(bool do_verification,
|
||||
return (pass);
|
||||
};
|
||||
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool pass = true;
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
|
||||
@@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
return (pass);
|
||||
};
|
||||
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool pass = true;
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
|
||||
@@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
return (pass);
|
||||
};
|
||||
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool pass = true;
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
|
||||
@@ -128,6 +128,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt)
|
||||
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
|
||||
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} )
|
||||
add_dependencies(examples ${EXAMPLE_NAME})
|
||||
|
||||
@@ -243,8 +243,8 @@ struct GemmConfigPreshufle_1 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
@@ -265,8 +265,8 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
|
||||
@@ -220,7 +220,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
bool preshuffle = GemmConfig::Preshuffle;
|
||||
|
||||
if(preshuffle && a_layout != "R" && b_layout != "C")
|
||||
if(preshuffle && (a_layout != "R" || b_layout != "C"))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
|
||||
|
||||
@@ -315,8 +315,16 @@ int run_gemm_example_with_layouts(int argc,
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n);
|
||||
if constexpr(preshuffle)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
}
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
|
||||
@@ -18,7 +18,7 @@ constexpr const char* DataTypeToString()
|
||||
{
|
||||
return "bf8";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
|
||||
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
|
||||
{
|
||||
return "bf16";
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user