Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/common/snippets/include/snippets/utils/loop_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@

#include "snippets/lowered/expression_port.hpp"
#include "snippets/lowered/loop_info.hpp"
#include "snippets/lowered/loop_manager.hpp"

namespace ov::snippets::utils {
/**
* @brief Updates ptr_increments and finalization offsets of the provided "loop_info" based on current work amount
*/
void update_data_pointer_shifts(const ov::snippets::lowered::UnifiedLoopInfoPtr& loop_info);
void update_data_pointer_shifts(const ov::snippets::lowered::LoopManagerPtr& loop_manager,
const ov::snippets::lowered::UnifiedLoopInfoPtr& loop_info);
/**
* @brief Updates work amount and updates data pointer shifts of the provided "loop_info"
*/
void update_runtime_parameters(const ov::snippets::lowered::UnifiedLoopInfoPtr& loop_info);
void update_runtime_parameters(const ov::snippets::lowered::LoopManagerPtr& loop_manager,
const ov::snippets::lowered::UnifiedLoopInfoPtr& loop_info);
/**
* @brief Check if the passed expression port should be port of the Loop with ID `loop_id`:
* the target expression port should be connected to an expression from another Loop (missed in the loop with ID
Expand Down
7 changes: 3 additions & 4 deletions src/common/snippets/src/lowered/pass/init_loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,12 @@ bool InitLoops::run(LinearIR& linear_ir) {
return false;
}

const auto& loops = linear_ir.get_loop_manager()->get_map();
for (const auto& loop : loops) {
const auto& loop_manager = linear_ir.get_loop_manager();
for (const auto& loop : loop_manager->get_map()) {
const auto& loop_info = ov::as_type_ptr<UnifiedLoopInfo>(loop.second);
update_compile_parameters(loop_info);
ov::snippets::utils::update_runtime_parameters(loop_info);
ov::snippets::utils::update_runtime_parameters(loop_manager, loop_info);
}

return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ MHAParallelWAOptimizer::MHAParallelWAOptimizer(const lowered::LinearIRCPtr& line
}
}

bool MHAParallelWAOptimizer::run(const lowered::LinearIR& /*linear_ir*/) {
bool MHAParallelWAOptimizer::run(const lowered::LinearIR& linear_ir) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::MHAParallelWAOptimizer")
const auto& config = m_configurator->get_config();
size_t new_batch_dim = 0, new_kernel_dim = 0;
Expand All @@ -74,13 +74,14 @@ bool MHAParallelWAOptimizer::run(const lowered::LinearIR& /*linear_ir*/) {
m_configurator->update_tensor_rank(master_shape);

RuntimeConfigurator::LoopInfoRuntimeParamsMap initialized_info;
const auto& loop_manager = linear_ir.get_loop_manager();
auto updater = [&](const lowered::LoopInfoPtr& loop_info) {
if (const auto unified_loop_info = ov::as_type_ptr<lowered::UnifiedLoopInfo>(loop_info)) {
if (initialized_info.count(unified_loop_info) == 0) {
if (!ov::is_type<lowered::InnerSplittedUnifiedLoopInfo>(unified_loop_info)) {
unified_loop_info->set_work_amount(new_kernel_dim);
}
snippets::utils::update_data_pointer_shifts(unified_loop_info);
snippets::utils::update_data_pointer_shifts(loop_manager, unified_loop_info);
initialized_info[unified_loop_info] = RuntimeConfigurator::get_loop_runtime_params(unified_loop_info);
}
} else if (const auto expanded_loop_info = ov::as_type_ptr<lowered::ExpandedLoopInfo>(loop_info)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ bool SetDynamicWAToOuterMostLoop::run(LinearIR& linear_ir) {
for (const auto& loop : affected_loops) {
if (!utils::is_dynamic_value(loop->get_work_amount())) {
loop->set_work_amount(utils::get_dynamic_value<size_t>());
ov::snippets::utils::update_data_pointer_shifts(loop);
ov::snippets::utils::update_data_pointer_shifts(loop_manager, loop);
modified = true;
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/common/snippets/src/lowered/pass/split_loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ void SplitLoops::split(LinearIR& linear_ir, size_t loop_to_split_id, size_t oute

namespace {
InnerSplittedUnifiedLoopInfoPtr make_own_inner_splitted_unified_loop_info(
const LoopManagerPtr& loop_manager,
const ExpandedLoopInfoPtr& inner_expanded,
const ExpandedLoopInfoPtr& outer_expanded,
const InnerSplittedUnifiedLoopInfoPtr& existing_inner_unified) {
Expand All @@ -138,7 +139,7 @@ InnerSplittedUnifiedLoopInfoPtr make_own_inner_splitted_unified_loop_info(
existing_inner_unified->get_output_port_descs(),
existing_inner_unified->get_handlers(),
outer_expanded);
ov::snippets::utils::update_runtime_parameters(loop_info);
ov::snippets::utils::update_runtime_parameters(loop_manager, loop_info);
return loop_info;
}
ExpandedLoopInfoPtr make_own_inner_splitted_expanded_loop_info(const ExpandedLoopInfoPtr& inner_expanded,
Expand Down Expand Up @@ -196,7 +197,8 @@ bool SplitLoops::TransformInnerSplitLoop::run(LinearIR& linear_ir,
// We have to make a new UnifiedLoopInfo to distinguish it from other unified loops in other specific iterations
// of outer loop.
const auto inner_splitted_unified_loop_info =
make_own_inner_splitted_unified_loop_info(inner_expanded_loop_info,
make_own_inner_splitted_unified_loop_info(loop_manager,
inner_expanded_loop_info,
outer_loop_info,
inner_unified_loop_info);

Expand Down
3 changes: 2 additions & 1 deletion src/common/snippets/src/runtime_configurator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,11 @@ void RuntimeConfigurator::update_expanded_loop_info(const lowered::ExpandedLoopI

void RuntimeConfigurator::update_loop_info(const lowered::LinearIRCPtr& linear_ir) {
LoopInfoRuntimeParamsMap initialized_info;
const auto& loop_manager = linear_ir->get_loop_manager();
auto updater = [&](const lowered::LoopInfoPtr& loop_info) {
if (const auto unified_loop_info = ov::as_type_ptr<lowered::UnifiedLoopInfo>(loop_info)) {
if (initialized_info.count(unified_loop_info) == 0) {
utils::update_runtime_parameters(unified_loop_info);
utils::update_runtime_parameters(loop_manager, unified_loop_info);
initialized_info[unified_loop_info] = get_loop_runtime_params(unified_loop_info);
}
} else if (const auto expanded_loop_info = ov::as_type_ptr<lowered::ExpandedLoopInfo>(loop_info)) {
Expand Down
96 changes: 77 additions & 19 deletions src/common/snippets/src/utils/loop_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,35 +12,73 @@
#include "openvino/core/type.hpp"
#include "snippets/lowered/expression_port.hpp"
#include "snippets/lowered/loop_info.hpp"
#include "snippets/lowered/loop_manager.hpp"
#include "snippets/lowered/loop_port.hpp"
#include "snippets/utils/utils.hpp"

namespace ov::snippets::utils {

using namespace ov::snippets::lowered;
namespace {
inline int64_t get_ptr_increment(const LoopPort& loop_port, size_t work_amount, size_t port_count) {
inline int64_t get_ptr_increment(const LoopInfoPtr& outer_split_info_of_nested_loop,
const LoopPort& loop_port,
size_t work_amount,
size_t port_count) {
if (!loop_port.is_incremented()) {
return 0;
}
const auto& layout = loop_port.get_expr_port()->get_descriptor_ptr()->get_layout();
const auto port_type = loop_port.get_expr_port()->get_type();
auto get_port_dim_idx = [&layout, &port_type](size_t dim_idx) {
if (port_type == ExpressionPort::Input) {
return get_input_dim_idx(layout, dim_idx);
}
if (port_type == ExpressionPort::Output) {
return get_output_dim_idx(layout, dim_idx);
}
OPENVINO_THROW("Unsupported expression port type!");
};

const auto& expr_port = loop_port.get_expr_port();
const auto& layout = expr_port->get_descriptor_ptr()->get_layout();
const auto& shape = expr_port->get_descriptor_ptr()->get_shape();
size_t dim = 0;
if (expr_port->get_type() == ExpressionPort::Input) {
dim = get_input_dim_idx(layout, loop_port.get_dim_idx());
} else if (expr_port->get_type() == ExpressionPort::Output) {
dim = get_output_dim_idx(layout, loop_port.get_dim_idx());
} else {
OPENVINO_THROW("Unsupported expression port type!");
}
const auto& original_shape = expr_port->get_descriptor_ptr()->get_shape();
const auto dim_idx = get_port_dim_idx(loop_port.get_dim_idx());
// When we cannot say about broadcasting
if (is_dynamic_value(shape[dim]) && port_count > 1) {
if (is_dynamic_value(original_shape[dim_idx]) && port_count > 1) {
return get_dynamic_value<int64_t>();
}
if (shape[dim] != 1 || work_amount == 1) {
return get_stride(dim, shape);
if (original_shape[dim_idx] != 1 || work_amount == 1) {
auto shape_for_stride_calculation = original_shape;
// Note: in case of outer split loop, we may need to stride not the whole dimension, but only block size.
// Example:
// <other loops>
// | LoopBegin (outer_split): wa = n, inc = n_blk, dim_idx = 0
// | | ...
// | | Buffer_in [m_blk x n_blk]
// | | LoopBegin (cur_loop): wa = m_blk, inc = 1, dim_idx = 1
// | | | Load
// | | | ...
// | | | Store
// | | LoopEnd (cur_loop)
// | LoopEnd (outer_split)
// | Buffer_out [m_blk x n]
// <other loops>
// -------------
// In this case, cur_loop ptr increments must be the following:
// - Load from Buffer_in: ptr_increment = n_blk, since this loop port is inside outer_split loop
// - Store to Buffer_out: ptr_increment = n, since this loop port is outside outer_split loop
if (outer_split_info_of_nested_loop != nullptr) {
const auto& ports = port_type == ExpressionPort::Input
? outer_split_info_of_nested_loop->get_input_ports()
: outer_split_info_of_nested_loop->get_output_ports();
auto it = std::find_if(ports.cbegin(), ports.cend(), [&expr_port](const LoopPort& lp) {
return *lp.get_expr_port() == *expr_port;
});
if (it == ports.cend()) {
const auto shape_dim_idx = get_port_dim_idx(outer_split_info_of_nested_loop->get_dim_idx());
shape_for_stride_calculation[shape_dim_idx] = outer_split_info_of_nested_loop->get_increment();
}
}
return get_stride(dim_idx, shape_for_stride_calculation);
}
return 0;
}
Expand Down Expand Up @@ -73,15 +111,35 @@ inline void init_work_amount(const LoopInfoPtr& loop_info) {
}
} // namespace

void update_data_pointer_shifts(const UnifiedLoopInfoPtr& loop_info) {
void update_data_pointer_shifts(const LoopManagerPtr& loop_manager, const UnifiedLoopInfoPtr& loop_info) {
OPENVINO_ASSERT(loop_info != nullptr, "UnifiedLoopInfo is nullptr, nothing to update");
const auto work_amount = loop_info->get_work_amount();
const auto input_count = loop_info->get_input_count();
const auto output_count = loop_info->get_output_count();

auto update_shifts = [&work_amount, &input_count, &output_count](LoopPort& loop_port,
UnifiedLoopInfo::LoopPortDesc& ptr_shifts_params) {
// WA: to find outer split loop whose dim_idx is less than cur_dim_idx,
// we use the knowledge that such outer loop is connected with the inner split loop
// which is nested inside the current loop
// TODO: this logic must be reworked, and WA should be removed, when blocking shapes are supported
// Ticket: 155651
LoopInfoPtr outer_split_info_of_nested_loop = nullptr;
if (auto cur_dim_idx = loop_info->get_dim_idx(); cur_dim_idx != LoopPort::UNDEFINED_DIM_IDX) {
auto fst_port_expr = loop_info->get_input_ports().front().get_expr_port()->get_expr();
for (const auto loop_idx : fst_port_expr->get_loop_ids()) {
const auto loop_info = loop_manager->get_loop_info(loop_idx);
if (const auto inner_split_loop = ov::as_type_ptr<InnerSplittedUnifiedLoopInfo>(loop_info)) {
if (inner_split_loop->get_dim_idx() < cur_dim_idx) {
OPENVINO_ASSERT(outer_split_info_of_nested_loop == nullptr,
"only 1 nested inner split loop is supported");
outer_split_info_of_nested_loop = inner_split_loop->get_outer_splitted_loop_info();
}
}
}
}

auto update_shifts = [&](LoopPort& loop_port, UnifiedLoopInfo::LoopPortDesc& ptr_shifts_params) {
ptr_shifts_params.ptr_increment = get_ptr_increment(
outer_split_info_of_nested_loop,
loop_port,
work_amount,
loop_port.get_expr_port()->get_type() == ExpressionPort::Input ? input_count : output_count);
Expand All @@ -90,12 +148,12 @@ void update_data_pointer_shifts(const UnifiedLoopInfoPtr& loop_info) {
loop_info->iterate_through_infos(update_shifts);
}

void update_runtime_parameters(const UnifiedLoopInfoPtr& loop_info) {
void update_runtime_parameters(const LoopManagerPtr& loop_manager, const UnifiedLoopInfoPtr& loop_info) {
OPENVINO_ASSERT(loop_info != nullptr, "UnifiedLoopInfo is nullptr, nothing to update");
if (!ov::is_type<InnerSplittedUnifiedLoopInfo>(loop_info)) {
init_work_amount(loop_info);
}
update_data_pointer_shifts(loop_info);
update_data_pointer_shifts(loop_manager, loop_info);
}

bool should_be_loop_port(const ov::snippets::lowered::ExpressionPort& port, size_t loop_id) {
Expand Down
6 changes: 5 additions & 1 deletion src/common/snippets/tests/include/lir_test_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ void init_expr_descriptors(const ov::snippets::lowered::ExpressionPtr& expr,
const std::vector<ov::snippets::VectorDims>& subtensors = {},
const std::vector<ov::snippets::VectorDims>& layouts = {});

using IOLoopPortDescs = std::pair<std::vector<ov::snippets::lowered::UnifiedLoopInfo::LoopPortDesc>,
std::vector<ov::snippets::lowered::UnifiedLoopInfo::LoopPortDesc>>;
/**
* @brief Creates an InnerSplittedUnifiedLoopInfo which represents an inner loop that appears
* after SplitLoops optimizations.
Expand All @@ -77,14 +79,16 @@ void init_expr_descriptors(const ov::snippets::lowered::ExpressionPtr& expr,
* @param entries Vector of LoopPort objects representing loop entry points (input ports)
* @param exits Vector of LoopPort objects representing loop exit points (output ports)
* @param outer_split_loop_info Pointer to the outer split loop info that will contain this inner loop
* @param io_descs Optional parameter containing input and output port descriptors
* @return Shared pointer to the created InnerSplittedUnifiedLoopInfo
*/
ov::snippets::lowered::InnerSplittedUnifiedLoopInfoPtr make_inner_split_loop_info(
size_t work_amount,
size_t increment,
const std::vector<ov::snippets::lowered::LoopPort>& entries,
const std::vector<ov::snippets::lowered::LoopPort>& exits,
const ov::snippets::lowered::UnifiedLoopInfoPtr& outer_split_loop_info);
const ov::snippets::lowered::UnifiedLoopInfoPtr& outer_split_loop_info,
const std::optional<IOLoopPortDescs>& io_descs = std::nullopt);

} // namespace snippets
} // namespace test
Expand Down
15 changes: 12 additions & 3 deletions src/common/snippets/tests/src/lir_test_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,23 @@ InnerSplittedUnifiedLoopInfoPtr make_inner_split_loop_info(size_t work_amount,
size_t increment,
const std::vector<LoopPort>& entries,
const std::vector<LoopPort>& exits,
const UnifiedLoopInfoPtr& outer_split_loop_info) {
const UnifiedLoopInfoPtr& outer_split_loop_info,
const std::optional<IOLoopPortDescs>& io_descs) {
outer_split_loop_info
->register_pass_to_handler<SpecificLoopIterType::MAIN_BODY, SplitLoops::TransformInnerSplitLoop>();
outer_split_loop_info
->register_pass_to_handler<SpecificLoopIterType::LAST_ITER, SplitLoops::TransformInnerSplitLoop>();
// Note: this temporary loop is needed to easily create InnerSplittedUnifiedLoopInfo:
// we extract all automatically calculated parameters from it such as LoopPortDesc and SpecificIterationHandlers
const auto tmp_unified_loop = std::make_shared<UnifiedLoopInfo>(work_amount, increment, entries, exits, false);
// we extract all automatically calculated parameters from it such as SpecificIterationHandlers
const auto tmp_unified_loop =
io_descs.has_value() ? std::make_shared<UnifiedLoopInfo>(work_amount,
increment,
entries,
exits,
io_descs.value().first,
io_descs.value().second,
false)
: std::make_shared<UnifiedLoopInfo>(work_amount, increment, entries, exits, false);
return std::make_shared<InnerSplittedUnifiedLoopInfo>(tmp_unified_loop->get_increment(),
tmp_unified_loop->get_input_ports(),
tmp_unified_loop->get_output_ports(),
Expand Down
Loading
Loading