ginkgo/core/distributed/matrix.hpp Source File

ginkgo/core/distributed/matrix.hpp Source File

Reference API: ginkgo/core/distributed/matrix.hpp Source File
Reference API
matrix.hpp
1// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_DISTRIBUTED_MATRIX_HPP_
6#define GKO_PUBLIC_CORE_DISTRIBUTED_MATRIX_HPP_
7
8
9#include <ginkgo/config.hpp>
10
11
12#if GINKGO_BUILD_MPI
13
14
15#include <ginkgo/core/base/dense_cache.hpp>
16#include <ginkgo/core/base/lin_op.hpp>
17#include <ginkgo/core/base/mpi.hpp>
18#include <ginkgo/core/base/std_extensions.hpp>
19#include <ginkgo/core/distributed/base.hpp>
20#include <ginkgo/core/distributed/index_map.hpp>
21
22
23namespace gko {
24namespace matrix {
25
26
27template <typename ValueType, typename IndexType>
28class Csr;
29
30
31}
32
33
34namespace multigrid {
35
36
37template <typename ValueType, typename IndexType>
38class Pgm;
39
40
41}
42
43
44namespace detail {
45
46
51template <typename Builder, typename ValueType, typename IndexType,
52 typename = void>
53struct is_matrix_type_builder : std::false_type {};
54
55
56template <typename Builder, typename ValueType, typename IndexType>
57struct is_matrix_type_builder<
58 Builder, ValueType, IndexType,
59 xstd::void_t<
60 decltype(std::declval<Builder>().template create<ValueType, IndexType>(
61 std::declval<std::shared_ptr<const Executor>>()))>>
62 : std::true_type {};
63
64
65template <template <typename, typename> class MatrixType,
66 typename... CreateArgs>
67struct MatrixTypeBuilderFromValueAndIndex {
68 template <typename ValueType, typename IndexType, std::size_t... I>
69 auto create_impl(std::shared_ptr<const Executor> exec,
70 std::index_sequence<I...>)
71 {
72 return MatrixType<ValueType, IndexType>::create(
73 exec, std::get<I>(create_args)...);
74 }
75
76
77 template <typename ValueType, typename IndexType>
78 auto create(std::shared_ptr<const Executor> exec)
79 {
80 // with c++17 we could use std::apply
81 static constexpr auto size = sizeof...(CreateArgs);
82 return create_impl<ValueType, IndexType>(
83 std::move(exec), std::make_index_sequence<size>{});
84 }
85
86 std::tuple<CreateArgs...> create_args;
87};
88
89
90} // namespace detail
91
92
124template <template <typename, typename> class MatrixType, typename... Args>
125auto with_matrix_type(Args&&... create_args)
126{
127 return detail::MatrixTypeBuilderFromValueAndIndex<MatrixType, Args...>{
128 std::forward_as_tuple(create_args...)};
129}
130
131
132namespace experimental {
133namespace distributed {
134
135
/**
 * Assembly strategy for reading a distributed matrix; it is the type of the
 * `assembly_type` argument of `read_distributed`, which defaults to
 * `local_only`.
 */
enum class assembly_mode {
    // NOTE(review): presumably entries are exchanged between ranks via
    // communication before assembly — confirm against read_distributed docs.
    communicate,
    // NOTE(review): presumably only locally available entries are used, with
    // no communication — confirm against read_distributed docs.
    local_only
};
146
147
148template <typename LocalIndexType, typename GlobalIndexType>
149class Partition;
150template <typename ValueType>
151class Vector;
152
153
260template <typename ValueType = default_precision,
261 typename LocalIndexType = int32, typename GlobalIndexType = int64>
263 : public EnableLinOp<Matrix<ValueType, LocalIndexType, GlobalIndexType>>,
264 public ConvertibleTo<
265 Matrix<next_precision<ValueType>, LocalIndexType, GlobalIndexType>>,
266#if GINKGO_ENABLE_HALF
267 public ConvertibleTo<Matrix<next_precision<next_precision<ValueType>>,
268 LocalIndexType, GlobalIndexType>>,
269#endif
270 public DistributedBase {
271 friend class EnablePolymorphicObject<Matrix, LinOp>;
272 friend class Matrix<previous_precision<ValueType>, LocalIndexType,
273 GlobalIndexType>;
274 friend class multigrid::Pgm<ValueType, LocalIndexType>;
275
276
277public:
278 using value_type = ValueType;
279 using index_type = GlobalIndexType;
280 using local_index_type = LocalIndexType;
281 using global_index_type = GlobalIndexType;
282 using global_vector_type =
284 using local_vector_type = typename global_vector_type::local_vector_type;
285
286 using EnableLinOp<Matrix>::convert_to;
287 using EnableLinOp<Matrix>::move_to;
289 GlobalIndexType>>::convert_to;
291 GlobalIndexType>>::move_to;
292
293 void convert_to(Matrix<next_precision<value_type>, local_index_type,
294 global_index_type>* result) const override;
295
296 void move_to(Matrix<next_precision<value_type>, local_index_type,
297 global_index_type>* result) override;
298#if GINKGO_ENABLE_HALF
299 friend class Matrix<previous_precision<previous_precision<ValueType>>,
300 LocalIndexType, GlobalIndexType>;
301 using ConvertibleTo<
303 global_index_type>>::convert_to;
305 local_index_type, global_index_type>>::move_to;
306
307 void convert_to(
309 global_index_type>* result) const override;
310
312 local_index_type, global_index_type>* result) override;
313
314#endif
335 partition,
336 assembly_mode assembly_type = assembly_mode::local_only);
337
350 partition,
351 assembly_mode assembly_type = assembly_mode::local_only);
352
374 row_partition,
376 col_partition,
377 assembly_mode assembly_type = assembly_mode::local_only);
378
391 row_partition,
393 col_partition,
394 assembly_mode assembly_type = assembly_mode::local_only);
395
401 std::shared_ptr<const LinOp> get_local_matrix() const { return local_mtx_; }
402
408 std::shared_ptr<const LinOp> get_non_local_matrix() const
409 {
410 return non_local_mtx_;
411 }
412
418 Matrix(const Matrix& other);
419
425 Matrix(Matrix&& other) noexcept;
426
435 Matrix& operator=(const Matrix& other);
436
446
456 static std::unique_ptr<Matrix> create(std::shared_ptr<const Executor> exec,
457 mpi::communicator comm);
458
479 template <typename MatrixType,
480 typename = std::enable_if_t<gko::detail::is_matrix_type_builder<
481 MatrixType, ValueType, LocalIndexType>::value>>
482 static std::unique_ptr<Matrix> create(std::shared_ptr<const Executor> exec,
484 MatrixType matrix_template)
485 {
486 return create(
487 exec, comm,
488 matrix_template.template create<ValueType, LocalIndexType>(exec));
489 }
490
519 template <typename LocalMatrixType, typename NonLocalMatrixType,
520 typename = std::enable_if_t<
521 gko::detail::is_matrix_type_builder<
522 LocalMatrixType, ValueType, LocalIndexType>::value &&
523 gko::detail::is_matrix_type_builder<
524 NonLocalMatrixType, ValueType, LocalIndexType>::value>>
525 static std::unique_ptr<Matrix> create(
526 std::shared_ptr<const Executor> exec, mpi::communicator comm,
527 LocalMatrixType local_matrix_template,
528 NonLocalMatrixType non_local_matrix_template)
529 {
530 return create(
531 exec, comm,
532 local_matrix_template.template create<ValueType, LocalIndexType>(
533 exec),
534 non_local_matrix_template
535 .template create<ValueType, LocalIndexType>(exec));
536 }
537
552 static std::unique_ptr<Matrix> create(
553 std::shared_ptr<const Executor> exec, mpi::communicator comm,
554 ptr_param<const LinOp> matrix_template);
555
572 static std::unique_ptr<Matrix> create(
573 std::shared_ptr<const Executor> exec, mpi::communicator comm,
574 ptr_param<const LinOp> local_matrix_template,
575 ptr_param<const LinOp> non_local_matrix_template);
576
589 static std::unique_ptr<Matrix> create(std::shared_ptr<const Executor> exec,
590 mpi::communicator comm, dim<2> size,
591 std::shared_ptr<LinOp> local_linop);
592
611 static std::unique_ptr<Matrix> create(
612 std::shared_ptr<const Executor> exec, mpi::communicator comm,
613 dim<2> size, std::shared_ptr<LinOp> local_linop,
614 std::shared_ptr<LinOp> non_local_linop,
615 std::vector<comm_index_type> recv_sizes,
616 std::vector<comm_index_type> recv_offsets,
617 array<local_index_type> recv_gather_idxs);
618
627
636
637protected:
638 explicit Matrix(std::shared_ptr<const Executor> exec,
639 mpi::communicator comm);
640
641 explicit Matrix(std::shared_ptr<const Executor> exec,
643 ptr_param<const LinOp> local_matrix_template,
644 ptr_param<const LinOp> non_local_matrix_template);
645
646 explicit Matrix(std::shared_ptr<const Executor> exec,
647 mpi::communicator comm, dim<2> size,
648 std::shared_ptr<LinOp> local_linop);
649
650 explicit Matrix(std::shared_ptr<const Executor> exec,
651 mpi::communicator comm, dim<2> size,
652 std::shared_ptr<LinOp> local_linop,
653 std::shared_ptr<LinOp> non_local_linop,
654 std::vector<comm_index_type> recv_sizes,
655 std::vector<comm_index_type> recv_offsets,
656 array<local_index_type> recv_gather_idxs);
657
666 mpi::request communicate(const local_vector_type* local_b) const;
667
668 void apply_impl(const LinOp* b, LinOp* x) const override;
669
670 void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
671 LinOp* x) const override;
672
673private:
674 std::vector<comm_index_type> send_offsets_;
675 std::vector<comm_index_type> send_sizes_;
676 std::vector<comm_index_type> recv_offsets_;
677 std::vector<comm_index_type> recv_sizes_;
678 array<local_index_type> gather_idxs_;
679 array<global_index_type> non_local_to_global_;
680 gko::detail::DenseCache<value_type> one_scalar_;
681 gko::detail::DenseCache<value_type> host_send_buffer_;
682 gko::detail::DenseCache<value_type> host_recv_buffer_;
683 gko::detail::DenseCache<value_type> send_buffer_;
684 gko::detail::DenseCache<value_type> recv_buffer_;
685 std::shared_ptr<LinOp> local_mtx_;
686 std::shared_ptr<LinOp> non_local_mtx_;
687};
688
689
690} // namespace distributed
691} // namespace experimental
692} // namespace gko
693
694
695#endif
696
697
698#endif // GKO_PUBLIC_CORE_DISTRIBUTED_MATRIX_HPP_
Definition polymorphic_object.hpp:479
Definition lin_op.hpp:878
Definition polymorphic_object.hpp:668
Definition lin_op.hpp:117
Definition array.hpp:166
Definition device_matrix_data.hpp:36
std::shared_ptr< const LinOp > get_non_local_matrix() const
Definition matrix.hpp:408
Matrix(Matrix &&other) noexcept
void read_distributed(const matrix_data< value_type, global_index_type > &data, std::shared_ptr< const Partition< local_index_type, global_index_type > > row_partition, std::shared_ptr< const Partition< local_index_type, global_index_type > > col_partition, assembly_mode assembly_type=assembly_mode::local_only)
static std::unique_ptr< Matrix > create(std::shared_ptr< const Executor > exec, mpi::communicator comm, ptr_param< const LinOp > local_matrix_template, ptr_param< const LinOp > non_local_matrix_template)
void col_scale(ptr_param< const global_vector_type > scaling_factors)
std::shared_ptr< const LinOp > get_local_matrix() const
Definition matrix.hpp:401
void read_distributed(const device_matrix_data< value_type, global_index_type > &data, std::shared_ptr< const Partition< local_index_type, global_index_type > > row_partition, std::shared_ptr< const Partition< local_index_type, global_index_type > > col_partition, assembly_mode assembly_type=assembly_mode::local_only)
static std::unique_ptr< Matrix > create(std::shared_ptr< const Executor > exec, mpi::communicator comm)
void read_distributed(const matrix_data< value_type, global_index_type > &data, std::shared_ptr< const Partition< local_index_type, global_index_type > > partition, assembly_mode assembly_type=assembly_mode::local_only)
void row_scale(ptr_param< const global_vector_type > scaling_factors)
mpi::request communicate(const local_vector_type *local_b) const
static std::unique_ptr< Matrix > create(std::shared_ptr< const Executor > exec, mpi::communicator comm, dim< 2 > size, std::shared_ptr< LinOp > local_linop)
static std::unique_ptr< Matrix > create(std::shared_ptr< const Executor > exec, mpi::communicator comm, dim< 2 > size, std::shared_ptr< LinOp > local_linop, std::shared_ptr< LinOp > non_local_linop, std::vector< comm_index_type > recv_sizes, std::vector< comm_index_type > recv_offsets, array< local_index_type > recv_gather_idxs)
void read_distributed(const device_matrix_data< value_type, global_index_type > &data, std::shared_ptr< const Partition< local_index_type, global_index_type > > partition, assembly_mode assembly_type=assembly_mode::local_only)
static std::unique_ptr< Matrix > create(std::shared_ptr< const Executor > exec, mpi::communicator comm, LocalMatrixType local_matrix_template, NonLocalMatrixType non_local_matrix_template)
Definition matrix.hpp:525
static std::unique_ptr< Matrix > create(std::shared_ptr< const Executor > exec, mpi::communicator comm, MatrixType matrix_template)
Definition matrix.hpp:482
Matrix & operator=(Matrix &&other)
Matrix & operator=(const Matrix &other)
static std::unique_ptr< Matrix > create(std::shared_ptr< const Executor > exec, mpi::communicator comm, ptr_param< const LinOp > matrix_template)
Definition mpi.hpp:327
Definition pgm.hpp:52
Definition utils_helper.hpp:41
assembly_mode
Definition matrix.hpp:145
The Ginkgo namespace.
Definition abstract_factory.hpp:20
double default_precision
Definition types.hpp:171
std::int32_t int32
Definition types.hpp:106
auto with_matrix_type(Args &&... create_args)
Definition matrix.hpp:125
typename detail::next_precision_impl< T >::type next_precision
Definition math.hpp:438
std::int64_t int64
Definition types.hpp:112
Definition dim.hpp:26
Definition matrix_data.hpp:126