Small post-merge fixes

This commit is contained in:
kiefer
2025-12-15 16:27:17 +00:00
parent dbb2e39386
commit 291c6fef56
3 changed files with 42 additions and 25 deletions

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -454,6 +454,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
BComputeDataType,
false, // PermuteA
false, // PermuteB
false, // IsBPreShuffled
true>; // ForceThreadTileTransfer
// TODO: DoElementwiseBeforeCShuffle used to be available as a template parameter here;
// determine whether it needs to be reinstated.
@@ -527,6 +528,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
false, // PermuteB
false, // PermuteA
false, // IsBPreShuffled
true>; // ForceThreadTileTransfer
using GridwiseGemmCTranspose =

View File

@@ -865,6 +865,10 @@ struct GridwiseGemm_wmma_cshuffle_v3
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// AScale struct (Empty)
using AScale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = AScale{};
// BScale struct (Empty)
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
@@ -875,6 +879,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
@@ -894,6 +899,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args);
}

View File

@@ -1,3 +1,6 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include <iomanip>
#include <iostream>
@@ -89,12 +92,12 @@ bool profile_grouped_conv_fwd_scaleadd_ab_impl(int do_verification,
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
Tensor<InDataType> input(in_g_n_c_wis_desc);
Tensor<InDataType> input_bias(in_g_n_c_wis_desc);
Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
Tensor<WeiDataType> weight_bias(wei_g_k_c_xs_desc);
Tensor<OutDataType> host_output(out_g_n_k_wos_desc);
Tensor<OutDataType> device_output(out_g_n_k_wos_desc);
ck::Tensor<InDataType> input(in_g_n_c_wis_desc);
ck::Tensor<InDataType> input_bias(in_g_n_c_wis_desc);
ck::Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
ck::Tensor<WeiDataType> weight_bias(wei_g_k_c_xs_desc);
ck::Tensor<OutDataType> host_output(out_g_n_k_wos_desc);
ck::Tensor<OutDataType> device_output(out_g_n_k_wos_desc);
std::cout << "input: " << input.mDesc << std::endl;
std::cout << "weight: " << weight.mDesc << std::endl;
@@ -116,11 +119,12 @@ bool profile_grouped_conv_fwd_scaleadd_ab_impl(int do_verification,
weight_bias.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1.0, 1.0});
}
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize());
DeviceMem in_bias_device_buf(sizeof(InDataType) * input_bias.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize());
DeviceMem wei_bias_device_buf(sizeof(WeiDataType) * weight_bias.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize());
ck::DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize());
ck::DeviceMem in_bias_device_buf(sizeof(InDataType) * input_bias.mDesc.GetElementSpaceSize());
ck::DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize());
ck::DeviceMem wei_bias_device_buf(sizeof(WeiDataType) *
weight_bias.mDesc.GetElementSpaceSize());
ck::DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(input.mData.data());
in_bias_device_buf.ToDevice(input_bias.mData.data());
@@ -130,8 +134,8 @@ bool profile_grouped_conv_fwd_scaleadd_ab_impl(int do_verification,
// Run reference op
if(do_verification)
{
const std::array<Tensor<InDataType>, NumAs - 1> elementwise_a_tensors = {input_bias};
const std::array<Tensor<WeiDataType>, NumBs - 1> elementwise_b_tensors = {weight_bias};
const std::array<ck::Tensor<InDataType>, NumAs - 1> elementwise_a_tensors = {input_bias};
const std::array<ck::Tensor<WeiDataType>, NumBs - 1> elementwise_b_tensors = {weight_bias};
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
@@ -175,7 +179,7 @@ bool profile_grouped_conv_fwd_scaleadd_ab_impl(int do_verification,
// workspace_sz will be equal to 0 for layouts other than NGCHW
// TODO: Is workspace even necessary?
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
ck::DeviceMem workspace_dev(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
@@ -222,16 +226,21 @@ bool profile_grouped_conv_fwd_scaleadd_ab_impl(int do_verification,
if(do_log)
{
LogRangeAsType<float>(std::cout << "input : ", input.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "input_bias: ", input_bias.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "weight: ", weight.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "weight_bias: ", weight_bias.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "host_output : ", host_output.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "device_output: ", device_output.mData, ",")
<< std::endl;
printf("log\n");
// NOTE(review): the tensor dumps below are commented out and only the printf
// placeholder above runs when do_log is set - confirm whether this is intentional.
// LogRangeAsType<float>(std::cout << "input : ", input.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "input_bias: ", input_bias.mData, ",")
//     << std::endl;
// LogRangeAsType<float>(std::cout << "weight: ", weight.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "weight_bias: ", weight_bias.mData, ",")
//     << std::endl;
// LogRangeAsType<float>(std::cout << "host_output : ", host_output.mData, ",")
//     << std::endl;
// LogRangeAsType<float>(std::cout << "device_output: ", device_output.mData, ",")
//     << std::endl;
}
}
}