ginkgo/core/distributed/vector.hpp Source File

ginkgo/core/distributed/vector.hpp Source File#

Reference API: ginkgo/core/distributed/vector.hpp Source File
Reference API
vector.hpp
1// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_DISTRIBUTED_VECTOR_HPP_
6#define GKO_PUBLIC_CORE_DISTRIBUTED_VECTOR_HPP_
7
8
9#include <ginkgo/config.hpp>
10
11
12#if GINKGO_BUILD_MPI
13
14
15#include <ginkgo/core/base/dense_cache.hpp>
16#include <ginkgo/core/base/lin_op.hpp>
17#include <ginkgo/core/base/mpi.hpp>
18#include <ginkgo/core/distributed/base.hpp>
19#include <ginkgo/core/matrix/dense.hpp>
20
21
22namespace gko {
23namespace experimental {
24namespace distributed {
25namespace detail {
26
27
// Forward declaration only: VectorCache is not defined in this listing.
// It is named here so that Vector<ValueType> can declare
// detail::VectorCache<ValueType> as a friend.
28template <typename ValueType>
29class VectorCache;
30
31
32} // namespace detail
33
34
// Forward declaration only: in this header Partition appears solely as a
// pointer / ptr_param parameter type (e.g. in read_distributed_impl), and its
// definition is not part of this listing.
35template <typename LocalIndexType, typename GlobalIndexType>
36class Partition;
37
38
// Vector: class template (ValueType defaults to double) in
// gko::experimental::distributed. It derives from EnableLinOp, from
// ConvertibleTo the Vector of next_precision<ValueType>, from
// EnableAbsoluteComputation and from DistributedBase. Its state is a
// local_vector_type member plus two DenseCache buffers (private section).
//
// NOTE(review): this text is an extraction of a rendered source page. Every
// line carries its original line number as a prefix, the documentation
// comments appear to have been dropped (the numbering has gaps), and several
// declarations have lost lines; those are flagged below. Check the real
// header before relying on any signature shown here.
65template <typename ValueType = double>
66class Vector
67 : public EnableLinOp<Vector<ValueType>>,
68 public ConvertibleTo<Vector<next_precision<ValueType>>>,
// With GINKGO_ENABLE_HALF the class is additionally convertible to the Vector
// two next_precision steps away.
69#if GINKGO_ENABLE_HALF
70 public ConvertibleTo<Vector<next_precision<next_precision<ValueType>>>>,
71#endif
72 public EnableAbsoluteComputation<remove_complex<Vector<ValueType>>>,
73 public DistributedBase {
// Other Vector instantiations (complex, real, previous precision) and the
// VectorCache helper get access to the non-public members.
75 friend class Vector<to_complex<ValueType>>;
76 friend class Vector<remove_complex<ValueType>>;
77 friend class Vector<previous_precision<ValueType>>;
78 friend class detail::VectorCache<ValueType>;
79
80public:
// Keep the EnableLinOp overloads of convert_to/move_to visible next to the
// overloads declared in this class.
81 using EnableLinOp<Vector>::convert_to;
82 using EnableLinOp<Vector>::move_to;
85
// NOTE(review): complex_type and local_vector_type are used further down but
// their alias declarations are not visible in this listing (numbering skips
// 83-84 and 89-90).
86 using value_type = ValueType;
87 using absolute_type = remove_complex<Vector>;
88 using real_type = absolute_type;
91
// NOTE(review): truncated declaration - the parameter list of
// create_with_config_of is missing here (numbering jumps 98 -> 100).
98 static std::unique_ptr<Vector> create_with_config_of(
100
101
// Factory taking an existing vector and an executor.
113 static std::unique_ptr<Vector> create_with_type_of(
114 ptr_param<const Vector> other, std::shared_ptr<const Executor> exec);
115
// Same as above, with explicit global size, local size and stride.
128 static std::unique_ptr<Vector> create_with_type_of(
129 ptr_param<const Vector> other, std::shared_ptr<const Executor> exec,
130 const dim<2>& global_size, const dim<2>& local_size, size_type stride);
131
// NOTE(review): the next six lines are only the closing lines of six
// declarations whose opening lines are missing. Each ends with a Partition
// parameter for the index-type pair shown. Presumably these are the
// read_distributed overloads - confirm against the header.
147 ptr_param<const Partition<int64, int64>> partition);
148
150 ptr_param<const Partition<int32, int64>> partition);
151
153 ptr_param<const Partition<int32, int32>> partition);
154
165 ptr_param<const Partition<int64, int64>> partition);
166
168 ptr_param<const Partition<int32, int64>> partition);
169
171 ptr_param<const Partition<int32, int32>> partition);
172
// Conversion to the next-precision Vector (copying and moving variants).
173 void convert_to(Vector<next_precision<ValueType>>* result) const override;
174
175 void move_to(Vector<next_precision<ValueType>>* result) override;
176
// Extra friend and conversions that exist only when GINKGO_ENABLE_HALF is set.
// NOTE(review): both `using ConvertibleTo<` declarations and the move_to
// declaration below are cut off in this listing.
177#if GINKGO_ENABLE_HALF
178 friend class Vector<previous_precision<previous_precision<ValueType>>>;
179 using ConvertibleTo<
181 using ConvertibleTo<
183
184 void convert_to(Vector<next_precision<next_precision<ValueType>>>* result)
185 const override;
186
187 void move_to(
189#endif
190
// Returns a newly allocated absolute_type object (override).
191 std::unique_ptr<absolute_type> compute_absolute() const override;
192
194
// Each operation below that has both forms offers a variant returning a new
// object and a variant writing into a caller-supplied `result`.
199 std::unique_ptr<complex_type> make_complex() const;
200
207
212 std::unique_ptr<real_type> get_real() const;
213
217 void get_real(ptr_param<real_type> result) const;
218
223 std::unique_ptr<real_type> get_imag() const;
224
229 void get_imag(ptr_param<real_type> result) const;
230
236 void fill(ValueType value);
237
// NOTE(review): the number-only lines below mark declarations that were lost
// entirely; only fragments of some of them (parameter tails) survive.
248
259
270
280
291
305 array<char>& tmp) const;
306
317 ptr_param<LinOp> result) const;
318
332 array<char>& tmp) const;
333
343
356
// Reductions writing into `result`; the visible compute_mean overload with an
// extra array<char>& additionally takes a caller-provided `tmp` buffer.
365 void compute_norm2(ptr_param<LinOp> result) const;
366
379
387 void compute_norm1(ptr_param<LinOp> result) const;
388
401
410 void compute_mean(ptr_param<LinOp> result) const;
411
423 void compute_mean(ptr_param<LinOp> result, array<char>& tmp) const;
424
// Element accessors, all noexcept. The non-const overloads return a
// reference, the const overloads return by value. Two addressing forms
// exist: (row, col) and a single index.
435 value_type& at_local(size_type row, size_type col) noexcept;
436
440 value_type at_local(size_type row, size_type col) const noexcept;
441
456 ValueType& at_local(size_type idx) noexcept;
457
461 ValueType at_local(size_type idx) const noexcept;
462
// Raw pointer access to the values (mutable and const variants).
468 value_type* get_local_values();
469
477 const value_type* get_const_local_values() const;
478
485
// create_real_view: const overload yields a pointer-to-const real_type, the
// non-const overload a mutable one.
493 std::unique_ptr<const real_type> create_real_view() const;
494
498 std::unique_ptr<real_type> create_real_view();
499
// Forwards to the stride reported by the local_ member.
500 size_type get_stride() const noexcept { return local_.get_stride(); }
501
// Public factories.
// NOTE(review): in the first two create() declarations one parameter line is
// missing (numbering jumps 513 -> 515 and 529 -> 531).
513 static std::unique_ptr<Vector> create(std::shared_ptr<const Executor> exec,
515 dim<2> global_size, dim<2> local_size,
516 size_type stride);
517
529 static std::unique_ptr<Vector> create(std::shared_ptr<const Executor> exec,
531 dim<2> global_size = {},
532 dim<2> local_size = {});
533
// Factories that take ownership of an existing local vector, with or without
// an explicit global size.
551 static std::unique_ptr<Vector> create(
552 std::shared_ptr<const Executor> exec, mpi::communicator comm,
553 dim<2> global_size, std::unique_ptr<local_vector_type> local_vector);
554
573 static std::unique_ptr<Vector> create(
574 std::shared_ptr<const Executor> exec, mpi::communicator comm,
575 std::unique_ptr<local_vector_type> local_vector);
576
// Const counterparts: accept a const local vector and return a const Vector.
589 static std::unique_ptr<const Vector> create_const(
590 std::shared_ptr<const Executor> exec, mpi::communicator comm,
591 dim<2> global_size,
592 std::unique_ptr<const local_vector_type> local_vector);
593
606 static std::unique_ptr<const Vector> create_const(
607 std::shared_ptr<const Executor> exec, mpi::communicator comm,
608 std::unique_ptr<const local_vector_type> local_vector);
609
610protected:
// Constructors are protected; presumably objects are meant to be built
// through the static create()/create_const() functions above - confirm.
611 Vector(std::shared_ptr<const Executor> exec, mpi::communicator comm,
612 dim<2> global_size, dim<2> local_size, size_type stride);
613
614 explicit Vector(std::shared_ptr<const Executor> exec,
615 mpi::communicator comm, dim<2> global_size = {},
616 dim<2> local_size = {});
617
618 Vector(std::shared_ptr<const Executor> exec, mpi::communicator comm,
619 dim<2> global_size, std::unique_ptr<local_vector_type> local_vector);
620
621 Vector(std::shared_ptr<const Executor> exec, mpi::communicator comm,
622 std::unique_ptr<local_vector_type> local_vector);
623
624 void resize(dim<2> global_size, dim<2> local_size);
625
// Index-type-generic helper taking a raw Partition pointer.
// NOTE(review): the first parameter line is missing (numbering jumps
// 627 -> 629).
626 template <typename LocalIndexType, typename GlobalIndexType>
627 void read_distributed_impl(
629 const Partition<LocalIndexType, GlobalIndexType>* partition);
630
// Overrides of the two LinOp apply hooks (two-operand and four-operand).
631 void apply_impl(const LinOp*, LinOp*) const override;
632
633 void apply_impl(const LinOp*, const LinOp*, const LinOp*,
634 LinOp*) const override;
635
// Virtual creation hooks that derived classes can override.
642 virtual std::unique_ptr<Vector> create_with_same_config() const;
643
656 virtual std::unique_ptr<Vector> create_with_type_of_impl(
657 std::shared_ptr<const Executor> exec, const dim<2>& global_size,
658 const dim<2>& local_size, size_type stride) const;
659
660private:
// local_ is the member that get_stride() forwards to.
661 local_vector_type local_;
// Two DenseCache members, one over ValueType and one over its real type.
// NOTE(review): the names suggest host-side scratch space for reductions
// and norms - confirm in the implementation file.
662 ::gko::detail::DenseCache<ValueType> host_reduction_buffer_;
663 ::gko::detail::DenseCache<remove_complex<ValueType>> host_norm_buffer_;
664};
665
666
667} // namespace distributed
668} // namespace experimental
669
670
671namespace detail {
672
673
// Primary template: declared here without a definition. This header only
// provides the partial specialization for experimental::distributed::Vector.
674template <typename TargetType>
675struct conversion_target_helper;
676
677
// Partial specialization of conversion_target_helper for
// experimental::distributed::Vector<ValueType>. Every create_empty overload
// does the same thing: it calls target_type::create() with only the
// executor and the communicator taken from the given source object.
//
// NOTE(review): the `using target_type = ...` line and the right-hand side of
// `using source_type =` are missing from this listing (numbering jumps
// 688 -> 690 -> 692). Presumably target_type is Vector<ValueType> and
// source_type is the previous_precision Vector - confirm against the header.
687template <typename ValueType>
688struct conversion_target_helper<experimental::distributed::Vector<ValueType>> {
690 using source_type =
692
// Create an empty target from a source_type object.
693 static std::unique_ptr<target_type> create_empty(const source_type* source)
694 {
695 return target_type::create(source->get_executor(),
696 source->get_communicator());
697 }
698
699 // Allow create_empty to be called with a source of the target type itself.
700 // In the distributed case, next<next<V>> will be V in the candidate list.
701 // TODO: decide whether to keep this overload or to add a condition to the list
702 static std::unique_ptr<target_type> create_empty(const target_type* source)
703 {
704 return target_type::create(source->get_executor(),
705 source->get_communicator());
706 }
707
// With GINKGO_ENABLE_HALF, also accept the Vector two previous_precision
// steps away as a source.
708#if GINKGO_ENABLE_HALF
709 using snd_source_type = experimental::distributed::Vector<
710 previous_precision<previous_precision<ValueType>>>;
711
712 static std::unique_ptr<target_type> create_empty(
713 const snd_source_type* source)
714 {
715 return target_type::create(source->get_executor(),
716 source->get_communicator());
717 }
718#endif
719};
720
721
722} // namespace detail
723} // namespace gko
724
725
726#endif // GINKGO_BUILD_MPI
727
728
729#endif // GKO_PUBLIC_CORE_DISTRIBUTED_VECTOR_HPP_
Definition polymorphic_object.hpp:479
Definition lin_op.hpp:793
Definition lin_op.hpp:878
Definition polymorphic_object.hpp:668
Definition lin_op.hpp:117
Definition array.hpp:166
Definition device_matrix_data.hpp:36
value_type at_local(size_type row, size_type col) const noexcept
void compute_mean(ptr_param< LinOp > result) const
static std::unique_ptr< Vector > create(std::shared_ptr< const Executor > exec, mpi::communicator comm, dim< 2 > global_size, dim< 2 > local_size, size_type stride)
void compute_norm2(ptr_param< LinOp > result) const
virtual std::unique_ptr< Vector > create_with_type_of_impl(std::shared_ptr< const Executor > exec, const dim< 2 > &global_size, const dim< 2 > &local_size, size_type stride) const
void read_distributed(const matrix_data< ValueType, int64 > &data, ptr_param< const Partition< int64, int64 > > partition)
void make_complex(ptr_param< complex_type > result) const
std::unique_ptr< real_type > create_real_view()
void compute_squared_norm2(ptr_param< LinOp > result, array< char > &tmp) const
static std::unique_ptr< Vector > create(std::shared_ptr< const Executor > exec, mpi::communicator comm, dim< 2 > global_size={}, dim< 2 > local_size={})
std::unique_ptr< real_type > get_real() const
static std::unique_ptr< const Vector > create_const(std::shared_ptr< const Executor > exec, mpi::communicator comm, dim< 2 > global_size, std::unique_ptr< const local_vector_type > local_vector)
std::unique_ptr< const real_type > create_real_view() const
static std::unique_ptr< Vector > create_with_type_of(ptr_param< const Vector > other, std::shared_ptr< const Executor > exec, const dim< 2 > &global_size, const dim< 2 > &local_size, size_type stride)
static std::unique_ptr< Vector > create_with_config_of(ptr_param< const Vector > other)
value_type & at_local(size_type row, size_type col) noexcept
void compute_norm2(ptr_param< LinOp > result, array< char > &tmp) const
const value_type * get_const_local_values() const
void compute_conj_dot(ptr_param< const LinOp > b, ptr_param< LinOp > result) const
virtual std::unique_ptr< Vector > create_with_same_config() const
void get_real(ptr_param< real_type > result) const
void compute_squared_norm2(ptr_param< LinOp > result) const
void sub_scaled(ptr_param< const LinOp > alpha, ptr_param< const LinOp > b)
void compute_dot(ptr_param< const LinOp > b, ptr_param< LinOp > result) const
static std::unique_ptr< const Vector > create_const(std::shared_ptr< const Executor > exec, mpi::communicator comm, std::unique_ptr< const local_vector_type > local_vector)
void compute_mean(ptr_param< LinOp > result, array< char > &tmp) const
std::unique_ptr< complex_type > make_complex() const
void compute_norm1(ptr_param< LinOp > result) const
static std::unique_ptr< Vector > create(std::shared_ptr< const Executor > exec, mpi::communicator comm, dim< 2 > global_size, std::unique_ptr< local_vector_type > local_vector)
void get_imag(ptr_param< real_type > result) const
std::unique_ptr< absolute_type > compute_absolute() const override
void compute_dot(ptr_param< const LinOp > b, ptr_param< LinOp > result, array< char > &tmp) const
std::unique_ptr< real_type > get_imag() const
void inv_scale(ptr_param< const LinOp > alpha)
void add_scaled(ptr_param< const LinOp > alpha, ptr_param< const LinOp > b)
void scale(ptr_param< const LinOp > alpha)
void read_distributed(const device_matrix_data< ValueType, int64 > &data, ptr_param< const Partition< int64, int64 > > partition)
const local_vector_type * get_local_vector() const
static std::unique_ptr< Vector > create_with_type_of(ptr_param< const Vector > other, std::shared_ptr< const Executor > exec)
ValueType & at_local(size_type idx) noexcept
void compute_norm1(ptr_param< LinOp > result, array< char > &tmp) const
ValueType at_local(size_type idx) const noexcept
static std::unique_ptr< Vector > create(std::shared_ptr< const Executor > exec, mpi::communicator comm, std::unique_ptr< local_vector_type > local_vector)
void compute_conj_dot(ptr_param< const LinOp > b, ptr_param< LinOp > result, array< char > &tmp) const
size_type get_stride() const noexcept
Definition dense.hpp:869
Definition utils_helper.hpp:41
The Ginkgo namespace.
Definition abstract_factory.hpp:20
typename detail::next_precision_impl< T >::type next_precision
Definition math.hpp:438
std::size_t size_type
Definition types.hpp:89
typename detail::to_complex_s< T >::type to_complex
Definition math.hpp:279
typename detail::remove_complex_s< T >::type remove_complex
Definition math.hpp:260
Definition dim.hpp:26
Definition matrix_data.hpp:126