Wmma support for grouped convolution bwd weight #2947
base: develop
Conversation
…/conv_bwd_weight_wmma' Convolution bwd weight device implementation See merge request amd/ai/composable_kernel!38
- rdna3 compilation error - gridwise layouts (these need to be correct so that CheckValidity() works correctly)
…re/conv_bwd_weight_wmma' Grouped conv: Instances and example bwd weight See merge request amd/ai/composable_kernel!47
Based on batched GEMM with multiple D
Device implementation of explicit gemm for grouped conv bwd weight See merge request amd/ai/composable_kernel!52
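To make the "explicit GEMM" wording in these commits concrete, below is a conceptual reference sketch (not code from this PR) of how grouped convolution backward weight maps onto a per-group GEMM: the M dimension runs over output channels K, the N dimension over C*Y*X, and the reduction dimension over N*Ho*Wo, which is the loop that split-k partitions. Tensor names, layouts, and the unit-stride/no-padding simplification are illustrative assumptions.

```cpp
#include <vector>

// Naive reference for 2D grouped conv bwd-weight viewed as a per-group GEMM
// (stride 1, no padding/dilation, NHWGC/GKYXC/NHWGK-style logical indexing).
// All names here are illustrative, not taken from the PR.
void conv2d_bwd_weight_reference(const std::vector<float>& in,   // [N, Hi, Wi, G, C]
                                 const std::vector<float>& dout, // [N, Ho, Wo, G, K]
                                 std::vector<float>& dwei,       // [G, K, Y, X, C]
                                 int N, int Hi, int Wi, int Ho, int Wo,
                                 int G, int K, int C, int Y, int X)
{
    for(int g = 0; g < G; ++g)
        for(int k = 0; k < K; ++k)                 // GEMM M
            for(int y = 0; y < Y; ++y)
                for(int x = 0; x < X; ++x)
                    for(int c = 0; c < C; ++c)     // (y, x, c) together form GEMM N
                    {
                        float acc = 0.f;           // accumulation in float, as in the tests below
                        for(int n = 0; n < N; ++n)
                            for(int ho = 0; ho < Ho; ++ho)
                                for(int wo = 0; wo < Wo; ++wo) // GEMM K; split-k divides this loop
                                {
                                    const int hi = ho + y;
                                    const int wi = wo + x;
                                    acc += dout[(((n * Ho + ho) * Wo + wo) * G + g) * K + k] *
                                           in[(((n * Hi + hi) * Wi + wi) * G + g) * C + c];
                                }
                        dwei[(((g * K + k) * Y + y) * X + x) * C + c] = acc;
                    }
}
```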
Pull Request Overview
This PR introduces WMMA (Wave Matrix Multiply Accumulate) support for grouped convolution backward weight operations, adding a comprehensive alternative to the existing XDL implementation for enhanced performance on supported hardware; a client-side usage sketch follows the summary bullets below.
- Adds complete WMMA device-level implementations for grouped convolution backward weight operations
- Introduces batched GEMM with multiple D tensors for explicit GEMM-based convolution implementations
- Extends support for occupancy-based split-k optimization to both one-stage and two-stage implementations
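As a rough illustration of how the new instances are reached from client code, the sketch below enumerates grouped conv bwd-weight instances through CK's existing DeviceOperationInstanceFactory for a 3D NDHWGC/GKZYXC/NDHWGK fp16 problem; with this PR the returned list would also contain the WMMA variants on supported hardware. Header paths and the template parameter order are assumptions based on the current develop branch, and argument construction is omitted.

```cpp
#include <memory>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// DeviceGroupedConvBwdWeight<NDimSpatial, InLayout, WeiLayout, OutLayout,
//                            InDataType, WeiDataType, OutDataType,
//                            InElementwiseOp, WeiElementwiseOp, OutElementwiseOp>
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
    3,
    ck::tensor_layout::convolution::NDHWGC,
    ck::tensor_layout::convolution::GKZYXC,
    ck::tensor_layout::convolution::NDHWGK,
    ck::half_t,
    ck::half_t,
    ck::half_t,
    PassThrough,
    PassThrough,
    PassThrough>;

int main()
{
    // Collect every registered instance (XDL and, after this PR, WMMA) for this problem type.
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    // A real client would build an Argument for its problem sizes, call IsSupportedArgument(),
    // and invoke the chosen instance; that part is omitted here.
    return op_ptrs.empty() ? 1 : 0;
}
```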
Reviewed Changes
Copilot reviewed 128 out of 128 changed files in this pull request and generated 6 comments.
Summary per file:

| File | Description |
|---|---|
| test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_scale.cpp | New test suite for grouped convolution backward weight scale operations |
| test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp | Improved error-threshold calculations for split-k operations |
| test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp | Simplified test configuration by removing GPU-architecture-specific constraints |
| test/grouped_convnd_bwd_weight/CMakeLists.txt | Build configuration updated to include new test executables and reorganize dependencies |
| profiler/src/CMakeLists.txt | Reorganized device-instance dependencies in the profiler build system |
| library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/ | Function names updated to distinguish XDL implementations from WMMA variants |
| library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/ | New WMMA-based explicit GEMM implementations for fp16 and bf16 data types |
| library/src/tensor_operation_instance/gpu/grouped_conv*d_bwd_weight/wmma/ | WMMA device implementations for 1D, 2D, and 3D grouped convolutions |
| library/include/ck/library/tensor_operation_instance/gpu/ | Headers updated with new WMMA instance declarations and factory methods |
```cpp
using AccDataType = float;
float max_accumulated_value =
    *std::max_element(wei_host.mData.begin(),wei_host.mData.end());
```
Missing space after comma in std::max_element call. Should be 'wei_host.mData.begin(), wei_host.mData.end()'.
```diff
- *std::max_element(wei_host.mData.begin(),wei_host.mData.end());
+ *std::max_element(wei_host.mData.begin(), wei_host.mData.end());
```
```cpp
using AccDataType = float;
float max_accumulated_value =
    *std::max_element(wei_host.mData.begin(),wei_host.mData.end());
```
Missing space after comma in std::max_element call. Should be 'wei_host.mData.begin(), wei_host.mData.end()'.
```diff
- *std::max_element(wei_host.mData.begin(),wei_host.mData.end());
+ *std::max_element(wei_host.mData.begin(), wei_host.mData.end());
```
```cpp
        ConvBwdWeightDefault,
        BlockGemmPipelineScheduler::Intrawave,
        BlockGemmPipelineVersion::v1>{});
    ;
```
Stray semicolon after function call. This should be removed.
```diff
-    ;
```
```cpp
    PassThrough,
    PassThrough>>>& instances)
{

```
Extra blank line before add_device_operation_instances call. Should be removed for consistency.
```cpp
    PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
```
[nitpick] Function declaration ordering is inconsistent. Two-stage functions should be grouped together after the regular instance functions for better readability and maintainability.
```cpp
    PassThrough,
    PassThrough,
    PassThrough>>>& instances);

```
[nitpick] Function declaration ordering is inconsistent. Two-stage functions should be grouped together after the regular instance functions for better readability and maintainability.
Review will be continued
```cpp
    1,              // CShuffleMRepeatPerShuffle
    1,              // CShuffleNRepeatPerShuffle
    S<1, 32, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
    128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CShuffleBlockTransferScalarPerVector_NPerBlock
```
Can we avoid extending the GNHWC layout format? Please change this example to NHWGC.
```cpp
                            ck::tensor_layout::convolution::GNDHWC>>,
    ck::tuple_element_t<NDimSpatial - 1,
                        ck::Tuple<ck::tensor_layout::convolution::GKXC,
                                  ck::tensor_layout::convolution::GKYXC,
```
Please change this example to NHWGC
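For reference, a minimal sketch of what the reviewer is asking for: selecting the channels-last NHWGC-family layouts for the example instead of the GNHWC/GNDHWC family, using the same ck::tuple_element_t dispatch pattern as the snippet above. The alias names here are hypothetical, and the surrounding example's existing includes are assumed.

```cpp
// Layout aliases keyed on NDimSpatial (1 -> NWGC, 2 -> NHWGC, 3 -> NDHWGC), mirroring
// the GNHWC-based selection shown above; only the layout tuples change.
template <ck::index_t NDimSpatial>
using InLayout = ck::tuple_element_t<NDimSpatial - 1,
                                     ck::Tuple<ck::tensor_layout::convolution::NWGC,
                                               ck::tensor_layout::convolution::NHWGC,
                                               ck::tensor_layout::convolution::NDHWGC>>;

template <ck::index_t NDimSpatial>
using WeiLayout = ck::tuple_element_t<NDimSpatial - 1,
                                      ck::Tuple<ck::tensor_layout::convolution::GKXC,
                                                ck::tensor_layout::convolution::GKYXC,
                                                ck::tensor_layout::convolution::GKZYXC>>;

template <ck::index_t NDimSpatial>
using OutLayout = ck::tuple_element_t<NDimSpatial - 1,
                                      ck::Tuple<ck::tensor_layout::convolution::NWGK,
                                                ck::tensor_layout::convolution::NHWGK,
                                                ck::tensor_layout::convolution::NDHWGK>>;
```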
```cpp
        ck::utils::get_absolute_threshold<InDataType, WeiDataType, AccDataType>(
            max_accumulated_value / num_accums_split_k,
            num_accums / num_accums_split_k);

```
Please use the default check_err for split_k == 1, because the calculated threshold could hide some errors.
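A minimal sketch of that suggestion, assuming hypothetical names for the test's own locals (wei_device_result, wei_host_result, split_k, rtol, and the threshold inputs stand in for whatever the test actually uses): fall back to check_err's default tolerances when split_k == 1 so the looser computed threshold cannot mask real errors, and keep the threshold-based comparison only for the split-k path.

```cpp
// Hypothetical verification branch following the reviewer's suggestion; variable names
// are placeholders for the test's existing locals.
bool pass = true;
if(split_k == 1)
{
    // Default tolerances of ck::utils::check_err are strict enough for a single accumulation pass.
    pass = ck::utils::check_err(wei_device_result.mData, wei_host_result.mData);
}
else
{
    // Split-k accumulates partial results (often via atomics), so a computed threshold is used.
    const auto atol = ck::utils::get_absolute_threshold<InDataType, WeiDataType, AccDataType>(
        max_accumulated_value / num_accums_split_k, num_accums / num_accums_split_k);
    pass = ck::utils::check_err(wei_device_result.mData,
                                wei_host_result.mData,
                                "Error: incorrect results!",
                                rtol, // whatever relative threshold the test already computes
                                atol);
}
```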
```cpp
#if defined(__gfx11__)
        // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
        using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
        if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
```
Isn't such a condition causing register spills? I remember a similar issue with if constexpr inside a global function.
```cpp
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
        // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
        using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
```
```diff
- using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
+ using EDataType = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
```
Proposed changes
Summary: this PR adds DeviceGroupedConvBwdWeight_Wmma_CShuffleV3, DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3, and DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3.
The implementations are based on CShuffleV3, but the functionality is the same as the XDL variants.
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- clang-format run on all changed files
Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered.