ginkgo/core/matrix/hybrid.hpp Source File

ginkgo/core/matrix/hybrid.hpp Source File

Reference API: ginkgo/core/matrix/hybrid.hpp Source File
Reference API
hybrid.hpp
1// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_MATRIX_HYBRID_HPP_
6#define GKO_PUBLIC_CORE_MATRIX_HYBRID_HPP_
7
8
9#include <algorithm>
10
11#include <ginkgo/core/base/array.hpp>
12#include <ginkgo/core/base/lin_op.hpp>
13#include <ginkgo/core/matrix/coo.hpp>
14#include <ginkgo/core/matrix/csr.hpp>
15#include <ginkgo/core/matrix/ell.hpp>
16
17
18namespace gko {
19namespace matrix {
20
21
22template <typename ValueType>
23class Dense;
24
25template <typename ValueType, typename IndexType>
26class Csr;
27
28
40template <typename ValueType = default_precision, typename IndexType = int32>
41class Hybrid
42 : public EnableLinOp<Hybrid<ValueType, IndexType>>,
43 public ConvertibleTo<Hybrid<next_precision<ValueType>, IndexType>>,
44#if GINKGO_ENABLE_HALF
45 public ConvertibleTo<
46 Hybrid<next_precision<next_precision<ValueType>>, IndexType>>,
47#endif
48 public ConvertibleTo<Dense<ValueType>>,
49 public ConvertibleTo<Csr<ValueType, IndexType>>,
50 public DiagonalExtractable<ValueType>,
51 public ReadableFromMatrixData<ValueType, IndexType>,
52 public WritableToMatrixData<ValueType, IndexType>,
54 remove_complex<Hybrid<ValueType, IndexType>>> {
56 friend class Dense<ValueType>;
57 friend class Csr<ValueType, IndexType>;
58 friend class Hybrid<to_complex<ValueType>, IndexType>;
59
60
61public:
62 using EnableLinOp<Hybrid>::convert_to;
63 using EnableLinOp<Hybrid>::move_to;
64 using ConvertibleTo<
65 Hybrid<next_precision<ValueType>, IndexType>>::convert_to;
66 using ConvertibleTo<Hybrid<next_precision<ValueType>, IndexType>>::move_to;
67 using ConvertibleTo<Dense<ValueType>>::convert_to;
68 using ConvertibleTo<Dense<ValueType>>::move_to;
71 using ReadableFromMatrixData<ValueType, IndexType>::read;
72
73 using value_type = ValueType;
74 using index_type = IndexType;
79 using absolute_type = remove_complex<Hybrid>;
80
81
92 public:
97 : ell_num_stored_elements_per_row_(zero<size_type>()),
98 coo_nnz_(zero<size_type>())
99 {}
100
114 size_type* ell_num_stored_elements_per_row,
115 size_type* coo_nnz)
116 {
117 array<size_type> ref_row_nnz(row_nnz.get_executor()->get_master(),
118 row_nnz.get_size());
119 ref_row_nnz = row_nnz;
120 ell_num_stored_elements_per_row_ =
121 this->compute_ell_num_stored_elements_per_row(&ref_row_nnz);
122 coo_nnz_ = this->compute_coo_nnz(ref_row_nnz);
123 *ell_num_stored_elements_per_row = ell_num_stored_elements_per_row_;
124 *coo_nnz = coo_nnz_;
125 }
126
133 {
134 return ell_num_stored_elements_per_row_;
135 }
136
142 size_type get_coo_nnz() const noexcept { return coo_nnz_; }
143
152 array<size_type>* row_nnz) const = 0;
153
154 protected:
164 {
165 size_type coo_nnz = 0;
166 auto row_nnz_val = row_nnz.get_const_data();
167 for (size_type i = 0; i < row_nnz.get_size(); i++) {
168 if (row_nnz_val[i] > ell_num_stored_elements_per_row_) {
169 coo_nnz +=
170 row_nnz_val[i] - ell_num_stored_elements_per_row_;
171 }
172 }
173 return coo_nnz;
174 }
175
176 private:
177 size_type ell_num_stored_elements_per_row_;
178 size_type coo_nnz_;
179 };
180
186 public:
192 explicit column_limit(size_type num_column = 0)
193 : num_columns_(num_column)
194 {}
195
197 array<size_type>* row_nnz) const override
198 {
199 return num_columns_;
200 }
201
207 auto get_num_columns() const { return num_columns_; }
208
209 private:
210 size_type num_columns_;
211 };
212
221 public:
228 explicit imbalance_limit(double percent = 0.8) : percent_(percent)
229 {
230 percent_ = std::min(percent_, 1.0);
231 percent_ = std::max(percent_, 0.0);
232 }
233
235 array<size_type>* row_nnz) const override
236 {
237 auto row_nnz_val = row_nnz->get_data();
238 auto num_rows = row_nnz->get_size();
239 if (num_rows == 0) {
240 return 0;
241 }
242 std::sort(row_nnz_val, row_nnz_val + num_rows);
243 if (percent_ < 1) {
244 auto percent_pos = static_cast<size_type>(num_rows * percent_);
245 return row_nnz_val[percent_pos];
246 } else {
247 return row_nnz_val[num_rows - 1];
248 }
249 }
250
256 auto get_percentage() const { return percent_; }
257
258 private:
259 double percent_;
260 };
261
268 public:
272 imbalance_bounded_limit(double percent = 0.8, double ratio = 0.0001)
273 : strategy_(imbalance_limit(percent)), ratio_(ratio)
274 {}
275
277 array<size_type>* row_nnz) const override
278 {
279 auto num_rows = row_nnz->get_size();
280 auto ell_cols =
282 return std::min(ell_cols,
283 static_cast<size_type>(num_rows * ratio_));
284 }
285
291 auto get_percentage() const { return strategy_.get_percentage(); }
292
298 auto get_ratio() const { return ratio_; }
299
300 private:
301 imbalance_limit strategy_;
302 double ratio_;
303 };
304
305
312 public:
317 : strategy_(
318 imbalance_limit(static_cast<double>(sizeof(IndexType)) /
319 (sizeof(ValueType) + 2 * sizeof(IndexType))))
320 {}
321
323 array<size_type>* row_nnz) const override
324 {
325 return strategy_.compute_ell_num_stored_elements_per_row(row_nnz);
326 }
327
333 auto get_percentage() const { return strategy_.get_percentage(); }
334
335 private:
336 imbalance_limit strategy_;
337 };
338
339
344 class automatic : public strategy_type {
345 public:
349 automatic() : strategy_(imbalance_bounded_limit(1.0 / 3.0, 0.001)) {}
350
352 array<size_type>* row_nnz) const override
353 {
354 return strategy_.compute_ell_num_stored_elements_per_row(row_nnz);
355 }
356
357 private:
358 imbalance_bounded_limit strategy_;
359 };
360
361 friend class Hybrid<previous_precision<ValueType>, IndexType>;
362
363 void convert_to(
364 Hybrid<next_precision<ValueType>, IndexType>* result) const override;
365
366 void move_to(Hybrid<next_precision<ValueType>, IndexType>* result) override;
367
368#if GINKGO_ENABLE_HALF
369 friend class Hybrid<previous_precision<previous_precision<ValueType>>,
370 IndexType>;
372 IndexType>>::convert_to;
373 using ConvertibleTo<
375
377 IndexType>* result) const override;
378
379 void move_to(Hybrid<next_precision<next_precision<ValueType>>, IndexType>*
380 result) override;
381#endif
382
383 void convert_to(Dense<ValueType>* other) const override;
384
385 void move_to(Dense<ValueType>* other) override;
386
387 void convert_to(Csr<ValueType, IndexType>* other) const override;
388
389 void move_to(Csr<ValueType, IndexType>* other) override;
390
391 void read(const mat_data& data) override;
392
393 void read(const device_mat_data& data) override;
394
395 void read(device_mat_data&& data) override;
396
397 void write(mat_data& data) const override;
398
399 std::unique_ptr<Diagonal<ValueType>> extract_diagonal() const override;
400
401 std::unique_ptr<absolute_type> compute_absolute() const override;
402
404
410 value_type* get_ell_values() noexcept { return ell_->get_values(); }
411
419 const value_type* get_const_ell_values() const noexcept
420 {
421 return ell_->get_const_values();
422 }
423
429 index_type* get_ell_col_idxs() noexcept { return ell_->get_col_idxs(); }
430
438 const index_type* get_const_ell_col_idxs() const noexcept
439 {
440 return ell_->get_const_col_idxs();
441 }
442
449 {
450 return ell_->get_num_stored_elements_per_row();
451 }
452
458 size_type get_ell_stride() const noexcept { return ell_->get_stride(); }
459
466 {
467 return ell_->get_num_stored_elements();
468 }
469
481 value_type& ell_val_at(size_type row, size_type idx) noexcept
482 {
483 return ell_->val_at(row, idx);
484 }
485
489 value_type ell_val_at(size_type row, size_type idx) const noexcept
490 {
491 return ell_->val_at(row, idx);
492 }
493
504 index_type& ell_col_at(size_type row, size_type idx) noexcept
505 {
506 return ell_->col_at(row, idx);
507 }
508
512 index_type ell_col_at(size_type row, size_type idx) const noexcept
513 {
514 return ell_->col_at(row, idx);
515 }
516
522 const ell_type* get_ell() const noexcept { return ell_.get(); }
523
529 value_type* get_coo_values() noexcept { return coo_->get_values(); }
530
538 const value_type* get_const_coo_values() const noexcept
539 {
540 return coo_->get_const_values();
541 }
542
548 index_type* get_coo_col_idxs() noexcept { return coo_->get_col_idxs(); }
549
557 const index_type* get_const_coo_col_idxs() const noexcept
558 {
559 return coo_->get_const_col_idxs();
560 }
561
567 index_type* get_coo_row_idxs() noexcept { return coo_->get_row_idxs(); }
568
576 const index_type* get_const_coo_row_idxs() const noexcept
577 {
578 return coo_->get_const_row_idxs();
579 }
580
587 {
588 return coo_->get_num_stored_elements();
589 }
590
596 const coo_type* get_coo() const noexcept { return coo_.get(); }
597
604 {
605 return coo_->get_num_stored_elements() +
606 ell_->get_num_stored_elements();
607 }
608
614 std::shared_ptr<strategy_type> get_strategy() const noexcept
615 {
616 return strategy_;
617 }
618
626 template <typename HybType>
627 std::shared_ptr<typename HybType::strategy_type> get_strategy() const;
628
639 static std::unique_ptr<Hybrid> create(
640 std::shared_ptr<const Executor> exec,
641 std::shared_ptr<strategy_type> strategy =
642 std::make_shared<automatic>());
643
655 static std::unique_ptr<Hybrid> create(
656 std::shared_ptr<const Executor> exec, const dim<2>& size,
657 std::shared_ptr<strategy_type> strategy =
658 std::make_shared<automatic>());
659
672 static std::unique_ptr<Hybrid> create(
673 std::shared_ptr<const Executor> exec, const dim<2>& size,
674 size_type num_stored_elements_per_row,
675 std::shared_ptr<strategy_type> strategy =
676 std::make_shared<automatic>());
677
690 static std::unique_ptr<Hybrid> create(
691 std::shared_ptr<const Executor> exec, const dim<2>& size,
692 size_type num_stored_elements_per_row, size_type stride,
693 std::shared_ptr<strategy_type> strategy);
694
708 static std::unique_ptr<Hybrid> create(
709 std::shared_ptr<const Executor> exec, const dim<2>& size,
710 size_type num_stored_elements_per_row, size_type stride,
711 size_type num_nonzeros = {},
712 std::shared_ptr<strategy_type> strategy =
713 std::make_shared<automatic>());
714
720
727
732 Hybrid(const Hybrid&);
733
740
741protected:
742 Hybrid(std::shared_ptr<const Executor> exec, const dim<2>& size = {},
743 size_type num_stored_elements_per_row = 0, size_type stride = 0,
744 size_type num_nonzeros = 0,
745 std::shared_ptr<strategy_type> strategy =
746 std::make_shared<automatic>());
747
758 void resize(dim<2> new_size, size_type ell_row_nnz, size_type coo_nnz);
759
760 void apply_impl(const LinOp* b, LinOp* x) const override;
761
762 void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
763 LinOp* x) const override;
764
765private:
766 std::unique_ptr<ell_type> ell_;
767 std::unique_ptr<coo_type> coo_;
768 std::shared_ptr<strategy_type> strategy_;
769};
770
771
772template <typename ValueType, typename IndexType>
773template <typename HybType>
774std::shared_ptr<typename HybType::strategy_type>
776{
777 static_assert(
778 std::is_same<HybType, Hybrid<typename HybType::value_type,
779 typename HybType::index_type>>::value,
780 "The given `HybType` type must be of type `matrix::Hybrid`!");
781
782 std::shared_ptr<typename HybType::strategy_type> strategy;
783 if (std::dynamic_pointer_cast<automatic>(strategy_)) {
784 strategy = std::make_shared<typename HybType::automatic>();
785 } else if (auto temp = std::dynamic_pointer_cast<minimal_storage_limit>(
786 strategy_)) {
787 // minimal_storage_limit is related to ValueType and IndexType size.
788 if (sizeof(value_type) == sizeof(typename HybType::value_type) &&
789 sizeof(index_type) == sizeof(typename HybType::index_type)) {
790 strategy =
791 std::make_shared<typename HybType::minimal_storage_limit>();
792 } else {
793 strategy = std::make_shared<typename HybType::imbalance_limit>(
794 temp->get_percentage());
795 }
796 } else if (auto temp = std::dynamic_pointer_cast<imbalance_bounded_limit>(
797 strategy_)) {
798 strategy = std::make_shared<typename HybType::imbalance_bounded_limit>(
799 temp->get_percentage(), temp->get_ratio());
800 } else if (auto temp =
801 std::dynamic_pointer_cast<imbalance_limit>(strategy_)) {
802 strategy = std::make_shared<typename HybType::imbalance_limit>(
803 temp->get_percentage());
804 } else if (auto temp = std::dynamic_pointer_cast<column_limit>(strategy_)) {
805 strategy = std::make_shared<typename HybType::column_limit>(
806 temp->get_num_columns());
807 } else {
808 GKO_NOT_SUPPORTED(strategy_);
809 }
810 return strategy;
811}
812
813
814} // namespace matrix
815} // namespace gko
816
817
818#endif // GKO_PUBLIC_CORE_MATRIX_HYBRID_HPP_
Definition polymorphic_object.hpp:479
Definition lin_op.hpp:742
Definition lin_op.hpp:793
Definition lin_op.hpp:878
Definition polymorphic_object.hpp:668
Definition lin_op.hpp:117
Definition lin_op.hpp:605
Definition lin_op.hpp:660
Definition array.hpp:166
value_type * get_data() noexcept
Definition array.hpp:673
std::shared_ptr< const Executor > get_executor() const noexcept
Definition array.hpp:689
const value_type * get_const_data() const noexcept
Definition array.hpp:682
size_type get_size() const noexcept
Definition array.hpp:656
Definition device_matrix_data.hpp:36
Definition coo.hpp:61
Definition csr.hpp:123
Definition dense.hpp:116
Definition ell.hpp:63
Definition hybrid.hpp:344
size_type compute_ell_num_stored_elements_per_row(array< size_type > *row_nnz) const override
Definition hybrid.hpp:351
automatic()
Definition hybrid.hpp:349
Definition hybrid.hpp:185
size_type compute_ell_num_stored_elements_per_row(array< size_type > *row_nnz) const override
Definition hybrid.hpp:196
column_limit(size_type num_column=0)
Definition hybrid.hpp:192
auto get_num_columns() const
Definition hybrid.hpp:207
auto get_percentage() const
Definition hybrid.hpp:291
size_type compute_ell_num_stored_elements_per_row(array< size_type > *row_nnz) const override
Definition hybrid.hpp:276
imbalance_bounded_limit(double percent=0.8, double ratio=0.0001)
Definition hybrid.hpp:272
auto get_ratio() const
Definition hybrid.hpp:298
Definition hybrid.hpp:220
size_type compute_ell_num_stored_elements_per_row(array< size_type > *row_nnz) const override
Definition hybrid.hpp:234
auto get_percentage() const
Definition hybrid.hpp:256
imbalance_limit(double percent=0.8)
Definition hybrid.hpp:228
size_type compute_ell_num_stored_elements_per_row(array< size_type > *row_nnz) const override
Definition hybrid.hpp:322
auto get_percentage() const
Definition hybrid.hpp:333
minimal_storage_limit()
Definition hybrid.hpp:316
Definition hybrid.hpp:91
virtual size_type compute_ell_num_stored_elements_per_row(array< size_type > *row_nnz) const =0
strategy_type()
Definition hybrid.hpp:96
size_type compute_coo_nnz(const array< size_type > &row_nnz) const
Definition hybrid.hpp:163
size_type get_ell_num_stored_elements_per_row() const noexcept
Definition hybrid.hpp:132
void compute_hybrid_config(const array< size_type > &row_nnz, size_type *ell_num_stored_elements_per_row, size_type *coo_nnz)
Definition hybrid.hpp:113
size_type get_coo_nnz() const noexcept
Definition hybrid.hpp:142
Definition hybrid.hpp:54
size_type get_num_stored_elements() const noexcept
Definition hybrid.hpp:603
static std::unique_ptr< Hybrid > create(std::shared_ptr< const Executor > exec, const dim< 2 > &size, size_type num_stored_elements_per_row, size_type stride, std::shared_ptr< strategy_type > strategy)
index_type * get_coo_row_idxs() noexcept
Definition hybrid.hpp:567
value_type & ell_val_at(size_type row, size_type idx) noexcept
Definition hybrid.hpp:481
static std::unique_ptr< Hybrid > create(std::shared_ptr< const Executor > exec, const dim< 2 > &size, size_type num_stored_elements_per_row, std::shared_ptr< strategy_type > strategy=std::make_shared< automatic >())
size_type get_ell_stride() const noexcept
Definition hybrid.hpp:458
void resize(dim< 2 > new_size, size_type ell_row_nnz, size_type coo_nnz)
size_type get_coo_num_stored_elements() const noexcept
Definition hybrid.hpp:586
std::unique_ptr< absolute_type > compute_absolute() const override
index_type ell_col_at(size_type row, size_type idx) const noexcept
Definition hybrid.hpp:512
index_type * get_coo_col_idxs() noexcept
Definition hybrid.hpp:548
value_type ell_val_at(size_type row, size_type idx) const noexcept
Definition hybrid.hpp:489
const index_type * get_const_coo_row_idxs() const noexcept
Definition hybrid.hpp:576
void write(mat_data &data) const override
const value_type * get_const_ell_values() const noexcept
Definition hybrid.hpp:419
const ell_type * get_ell() const noexcept
Definition hybrid.hpp:522
std::unique_ptr< Diagonal< ValueType > > extract_diagonal() const override
value_type * get_ell_values() noexcept
Definition hybrid.hpp:410
static std::unique_ptr< Hybrid > create(std::shared_ptr< const Executor > exec, std::shared_ptr< strategy_type > strategy=std::make_shared< automatic >())
void read(const mat_data &data) override
size_type get_ell_num_stored_elements() const noexcept
Definition hybrid.hpp:465
size_type get_ell_num_stored_elements_per_row() const noexcept
Definition hybrid.hpp:448
const index_type * get_const_coo_col_idxs() const noexcept
Definition hybrid.hpp:557
Hybrid(const Hybrid &)
void read(const device_mat_data &data) override
index_type * get_ell_col_idxs() noexcept
Definition hybrid.hpp:429
void compute_absolute_inplace() override
const index_type * get_const_ell_col_idxs() const noexcept
Definition hybrid.hpp:438
std::shared_ptr< strategy_type > get_strategy() const noexcept
Definition hybrid.hpp:614
const coo_type * get_coo() const noexcept
Definition hybrid.hpp:596
static std::unique_ptr< Hybrid > create(std::shared_ptr< const Executor > exec, const dim< 2 > &size, size_type num_stored_elements_per_row, size_type stride, size_type num_nonzeros={}, std::shared_ptr< strategy_type > strategy=std::make_shared< automatic >())
value_type * get_coo_values() noexcept
Definition hybrid.hpp:529
static std::unique_ptr< Hybrid > create(std::shared_ptr< const Executor > exec, const dim< 2 > &size, std::shared_ptr< strategy_type > strategy=std::make_shared< automatic >())
Hybrid & operator=(const Hybrid &)
index_type & ell_col_at(size_type row, size_type idx) noexcept
Definition hybrid.hpp:504
const value_type * get_const_coo_values() const noexcept
Definition hybrid.hpp:538
void read(device_mat_data &&data) override
Hybrid & operator=(Hybrid &&)
The Ginkgo namespace.
Definition abstract_factory.hpp:20
typename detail::next_precision_impl< T >::type next_precision
Definition math.hpp:438
std::size_t size_type
Definition types.hpp:89
constexpr T zero()
Definition math.hpp:602
typename detail::to_complex_s< T >::type to_complex
Definition math.hpp:279
typename detail::remove_complex_s< T >::type remove_complex
Definition math.hpp:260
Definition dim.hpp:26
Definition matrix_data.hpp:126