
Conversation

EnricoDeg
Contributor

Proposed changes

Summary:

  • Modify the gridwise implementation to work with convolution (grid descriptors are no longer created internally but are passed in from the device level)
  • Add device-level implementations: DeviceGroupedConvBwdWeight_Wmma_CShuffleV3, DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 and DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
  • Add a device implementation of batched GEMM with multiple D tensors (needed for the explicit-GEMM path of conv bwd weight)
  • Adapt the existing explicit-GEMM device implementation to work with both the XDL and WMMA implementations of batched GEMM with multiple Ds
  • Add support for occupancy-based split-k for the one-stage and two-stage implementations of grouped conv bwd weight (a rough sketch of the occupancy idea follows this summary)
  • Create instances
  • Add examples
  • Remove old instances (they don't support split-k)
  • Add tests for bwd weight scale

The implementations are based on CShuffleV3, but the functionality is the same as the XDL versions.
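
The occupancy-based split-k mentioned in the summary is essentially about picking the split factor so that the launched grid saturates the device. The sketch below only illustrates that idea; the names (choose_split_k, tiles_per_gemm, num_cus, waves_per_cu) are hypothetical and are not the PR's actual API, which lives in the device-level implementations.

#include <algorithm>
#include <cstdint>

// Hypothetical helper: pick split_k so that the number of workgroups roughly
// saturates the device. Illustration of the occupancy-based idea only.
inline uint32_t choose_split_k(uint32_t tiles_per_gemm,   // M x N tile count of one GEMM
                               uint32_t gemm_batch,       // number of GEMMs launched together
                               uint32_t k_iterations,     // reduction length in K tiles
                               uint32_t num_cus,          // compute units on the device
                               uint32_t waves_per_cu = 4) // target occupancy per CU
{
    const uint32_t target_workgroups = num_cus * waves_per_cu;
    const uint32_t base_workgroups   = tiles_per_gemm * gemm_batch;
    if(base_workgroups == 0 || base_workgroups >= target_workgroups)
        return 1; // already enough parallelism, keep a single accumulation pass
    // Split the K dimension just enough to reach the occupancy target,
    // but never beyond the number of available K iterations.
    const uint32_t wanted = (target_workgroups + base_workgroups - 1) / base_workgroups;
    return std::min(wanted, std::max<uint32_t>(k_iterations, 1));
}

With split_k > 1 the partial results still have to be combined, typically either via atomic adds (one-stage) or a separate reduction pass (two-stage), which is presumably why both variants are listed above.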

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers to understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

…/conv_bwd_weight_wmma'

Convolution bwd weight device implementation

See merge request amd/ai/composable_kernel!38
 - rdna3 compilation error
 - gridwise layouts (need to be correct to ensure that CheckValidity() works correctly)
…re/conv_bwd_weight_wmma'

Grouped conv: Instances and example bwd weight

See merge request amd/ai/composable_kernel!47
Device implementation of explicit gemm for grouped conv bwd weight

See merge request amd/ai/composable_kernel!52
Contributor

Copilot AI left a comment


Pull Request Overview

This PR introduces WMMA (Wave Matrix Multiply Accumulate) support for grouped convolution backward weight operations, adding an alternative to the existing XDL implementation for hardware that exposes WMMA instead of XDL (gfx11/gfx12).

  • Adds complete WMMA device-level implementations for grouped convolution backward weight operations
  • Introduces batched GEMM with multiple D tensors for explicit GEMM-based convolution implementations (a sketch of the usual dimension mapping follows this list)
  • Extends support for occupancy-based split-k optimization to both one-stage and two-stage implementations
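
For context on the explicit-GEMM bullet above, the usual dimension mapping for backward-weight convolution as a batched GEMM looks roughly as follows. This is an illustrative sketch only; the struct and function names are made up, and the PR expresses the same mapping through its grid/tensor descriptors.

#include <cstdint>

// Illustration: how grouped conv bwd-weight is commonly phrased as a batched GEMM.
// dW[g, k, c, y, x] accumulates dOut[n, ho, wo, g, k] * In[n, hi, wi, g, c]
// over n and the output spatial positions, so the reduction dimension is large
// and is the natural candidate for split-k.
struct ConvBwdWeightAsGemm
{
    uint64_t batch; // G: one GEMM per group
    uint64_t m;     // K: output channels
    uint64_t n;     // C * Y * X: input channels times filter taps
    uint64_t k;     // N * Ho * Wo: reduction dimension, shared across split-k partials
};

inline ConvBwdWeightAsGemm map_conv2d_bwd_weight(uint64_t G, uint64_t N, uint64_t K,
                                                 uint64_t C, uint64_t Y, uint64_t X,
                                                 uint64_t Ho, uint64_t Wo)
{
    return ConvBwdWeightAsGemm{G, K, C * Y * X, N * Ho * Wo};
}

The "multiple D" part refers to the extra epilogue operands (for example the bilinear and scale tensors exercised by the new tests) fused into this batched GEMM.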

Reviewed Changes

Copilot reviewed 128 out of 128 changed files in this pull request and generated 6 comments.

Summary per file:

  • test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_scale.cpp: new comprehensive test suite for grouped convolution backward weight scale operations
  • test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp: enhanced test implementation with improved error threshold calculations for split-k operations
  • test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp: simplified test configuration by removing GPU architecture-specific constraints
  • test/grouped_convnd_bwd_weight/CMakeLists.txt: updated build configuration to include new test executables and reorganize dependencies
  • profiler/src/CMakeLists.txt: reorganized device instance dependencies in the profiler build system
  • library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/: function name updates to distinguish XDL implementations from WMMA variants
  • library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/: new WMMA-based explicit GEMM implementations for fp16 and bf16 data types
  • library/src/tensor_operation_instance/gpu/grouped_conv*d_bwd_weight/wmma/: comprehensive WMMA device implementations for 1D, 2D, and 3D grouped convolutions
  • library/include/ck/library/tensor_operation_instance/gpu/: updated header files with new WMMA instance declarations and factory methods



using AccDataType = float;
float max_accumulated_value =
*std::max_element(wei_host.mData.begin(),wei_host.mData.end());

Copilot AI Sep 29, 2025


Missing space after comma in std::max_element call. Should be 'wei_host.mData.begin(), wei_host.mData.end()'.

Suggested change:
- *std::max_element(wei_host.mData.begin(),wei_host.mData.end());
+ *std::max_element(wei_host.mData.begin(), wei_host.mData.end());


using AccDataType = float;
float max_accumulated_value =
*std::max_element(wei_host.mData.begin(),wei_host.mData.end());

Copilot AI Sep 29, 2025


Missing space after comma in std::max_element call. Should be 'wei_host.mData.begin(), wei_host.mData.end()'.

Suggested change:
- *std::max_element(wei_host.mData.begin(),wei_host.mData.end());
+ *std::max_element(wei_host.mData.begin(), wei_host.mData.end());

ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
;

Copilot AI Sep 29, 2025


Stray semicolon after function call. This should be removed.

Suggested change:
- ;

PassThrough,
PassThrough>>>& instances)
{


Copilot AI Sep 29, 2025


Extra blank line before add_device_operation_instances call. Should be removed for consistency.

Suggested change: remove the blank line.

PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(

Copilot AI Sep 29, 2025


[nitpick] Function declaration ordering is inconsistent. Two-stage functions should be grouped together after the regular instance functions for better readability and maintainability.


PassThrough,
PassThrough,
PassThrough>>>& instances);


Copilot AI Sep 29, 2025


[nitpick] Function declaration ordering is inconsistent. Two-stage functions should be grouped together after the regular instance functions for better readability and maintainability.


Contributor

bartekxk left a comment


Review will be continued

1, // CShuffleMRepeatPerShuffle
1, // CShuffleNRepeatPerShuffle
S<1, 32, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CShuffleBlockTransferScalarPerVector_NPerBlock
Contributor


Can we avoid extending the GNHWC layout family? Please change this example to NHWGC.

ck::tensor_layout::convolution::GNDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GKXC,
ck::tensor_layout::convolution::GKYXC,
Contributor


Please change this example to NHWGC
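
For illustration, a hedged sketch of what the requested NHWGC-family selection could look like, mirroring the tuple_element_t pattern in the quoted snippet; the layout type names come from ck::tensor_layout::convolution, but the exact code in the example may differ:

// Sketch: select the NHWGC-family layouts instead of the GNHWC family,
// keeping the same NDimSpatial-indexed selection as in the quoted example.
using InLayout  = ck::tuple_element_t<NDimSpatial - 1,
                                      ck::Tuple<ck::tensor_layout::convolution::NWGC,
                                                ck::tensor_layout::convolution::NHWGC,
                                                ck::tensor_layout::convolution::NDHWGC>>;
using WeiLayout = ck::tuple_element_t<NDimSpatial - 1,
                                      ck::Tuple<ck::tensor_layout::convolution::GKXC,
                                                ck::tensor_layout::convolution::GKYXC,
                                                ck::tensor_layout::convolution::GKZYXC>>;
using OutLayout = ck::tuple_element_t<NDimSpatial - 1,
                                      ck::Tuple<ck::tensor_layout::convolution::NWGK,
                                                ck::tensor_layout::convolution::NHWGK,
                                                ck::tensor_layout::convolution::NDHWGK>>;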

ck::utils::get_absolute_threshold<InDataType, WeiDataType, AccDataType>(
max_accumulated_value / num_accums_split_k,
num_accums / num_accums_split_k);

Contributor


Please use the default check_err tolerances for split_k == 1, because the calculated threshold could hide some errors.
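
A sketch of what this could look like in the test, assuming ck::utils::check_err and the threshold helpers quoted above (get_relative_threshold is assumed to be available alongside the quoted get_absolute_threshold); the variable names such as wei_device and pass are placeholders and may differ from the PR:

// Sketch only: keep the default (tight) check_err tolerances when there is a single
// accumulation pass, and use the relaxed, computed thresholds only for split_k > 1.
bool pass = true;
if(split_k == 1)
{
    pass = ck::utils::check_err(wei_device.mData, wei_host.mData);
}
else
{
    const auto atol = ck::utils::get_absolute_threshold<InDataType, WeiDataType, AccDataType>(
        max_accumulated_value / num_accums_split_k, num_accums / num_accums_split_k);
    const auto rtol = ck::utils::get_relative_threshold<InDataType, WeiDataType, AccDataType>(
        num_accums / num_accums_split_k);
    pass = ck::utils::check_err(
        wei_device.mData, wei_host.mData, "Error: incorrect results!", rtol, atol);
}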

#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
Contributor


Doesn't such a condition cause register spills? I remember an issue like that with if constexpr inside a global (kernel) function.
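
To make the condition under discussion concrete, here is a self-contained illustration (not the PR's code) of the compile-time check: on gfx11 the kernel must not instantiate an epilogue that needs packed f16/bf16 atomic adds. The enum and type names below are local stand-ins for the CK ones.

#include <type_traits>

// Stand-ins for the CK types referenced in the quoted snippet.
enum class InMemoryDataOperationEnum { Set, AtomicAdd };
struct half_t;  // stand-in for ck::half_t
struct bhalf_t; // stand-in for ck::bhalf_t

// gfx11 lacks *_atomic_pk_add_f16/bf16, so AtomicAdd on f16/bf16 output is not viable there.
template <InMemoryDataOperationEnum Op, typename EDataType>
constexpr bool is_viable_on_gfx11()
{
    return !(Op == InMemoryDataOperationEnum::AtomicAdd &&
             (std::is_same_v<EDataType, half_t> || std::is_same_v<EDataType, bhalf_t>));
}

static_assert(is_viable_on_gfx11<InMemoryDataOperationEnum::Set, half_t>());
static_assert(!is_viable_on_gfx11<InMemoryDataOperationEnum::AtomicAdd, bhalf_t>());

Guarding the whole kernel body with such an if constexpr is what raises the register-spill question; one alternative is to reject these configurations in the host-side IsSupportedArgument check instead.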

#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
Contributor


Suggested change:
- using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
+ using EDataType = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
