5#ifndef GKO_PUBLIC_CORE_BASE_MPI_HPP_
6#define GKO_PUBLIC_CORE_BASE_MPI_HPP_
13#include <ginkgo/config.hpp>
14#include <ginkgo/core/base/exception.hpp>
15#include <ginkgo/core/base/exception_helpers.hpp>
16#include <ginkgo/core/base/executor.hpp>
17#include <ginkgo/core/base/half.hpp>
18#include <ginkgo/core/base/types.hpp>
19#include <ginkgo/core/base/utils_helper.hpp>
29namespace experimental {
44#if GINKGO_HAVE_GPU_AWARE_MPI
62#define GKO_REGISTER_MPI_TYPE(input_type, mpi_type) \
64 struct type_impl<input_type> { \
65 static MPI_Datatype get_type() { return mpi_type; } \
// Register the MPI datatype handle for each supported C++ type. Every
// invocation expands (see the GKO_REGISTER_MPI_TYPE definition above) to a
// type_impl<input_type> specialization whose static get_type() returns the
// given mpi_type.
//
// Character and integer types.
80GKO_REGISTER_MPI_TYPE(
char, MPI_CHAR);
81GKO_REGISTER_MPI_TYPE(
unsigned char, MPI_UNSIGNED_CHAR);
82GKO_REGISTER_MPI_TYPE(
unsigned, MPI_UNSIGNED);
83GKO_REGISTER_MPI_TYPE(
int, MPI_INT);
84GKO_REGISTER_MPI_TYPE(
unsigned short, MPI_UNSIGNED_SHORT);
85GKO_REGISTER_MPI_TYPE(
unsigned long, MPI_UNSIGNED_LONG);
86GKO_REGISTER_MPI_TYPE(
long, MPI_LONG);
87GKO_REGISTER_MPI_TYPE(
long long, MPI_LONG_LONG_INT);
88GKO_REGISTER_MPI_TYPE(
unsigned long long, MPI_UNSIGNED_LONG_LONG);
// Real floating-point types.
89GKO_REGISTER_MPI_TYPE(
float, MPI_FLOAT);
90GKO_REGISTER_MPI_TYPE(
double, MPI_DOUBLE);
91GKO_REGISTER_MPI_TYPE(
long double, MPI_LONG_DOUBLE);
// NOTE(review): half and std::complex<half> are registered under
// MPI_UNSIGNED_SHORT and MPI_FLOAT respectively, i.e. not under a dedicated
// half-precision MPI datatype. Presumably these serve as same-size stand-ins
// for plain data transfer; confirm the sizes match and that they are not used
// with arithmetic reduction operations.
96GKO_REGISTER_MPI_TYPE(
half, MPI_UNSIGNED_SHORT);
97GKO_REGISTER_MPI_TYPE(std::complex<half>, MPI_FLOAT);
// Standard complex types.
99GKO_REGISTER_MPI_TYPE(std::complex<float>, MPI_C_FLOAT_COMPLEX);
100GKO_REGISTER_MPI_TYPE(std::complex<double>, MPI_C_DOUBLE_COMPLEX);
119 GKO_ASSERT_NO_MPI_ERRORS(MPI_Type_contiguous(count, old_type, &type_));
120 GKO_ASSERT_NO_MPI_ERRORS(MPI_Type_commit(&type_));
145 *
this = std::move(other);
157 if (
this != &other) {
158 this->type_ = std::exchange(other.type_, MPI_DATATYPE_NULL);
168 if (type_ != MPI_DATATYPE_NULL) {
169 MPI_Type_free(&type_);
178 MPI_Datatype
get()
const {
return type_; }
190 serialized = MPI_THREAD_SERIALIZED,
191 funneled = MPI_THREAD_FUNNELED,
192 single = MPI_THREAD_SINGLE,
193 multiple = MPI_THREAD_MULTIPLE
208 static bool is_finalized()
211 GKO_ASSERT_NO_MPI_ERRORS(MPI_Finalized(&flag));
215 static bool is_initialized()
218 GKO_ASSERT_NO_MPI_ERRORS(MPI_Initialized(&flag));
238 const thread_type thread_t = thread_type::serialized)
240 this->required_thread_support_ =
static_cast<int>(thread_t);
241 GKO_ASSERT_NO_MPI_ERRORS(
242 MPI_Init_thread(&argc, &argv, this->required_thread_support_,
243 &(this->provided_thread_support_)));
257 int required_thread_support_;
258 int provided_thread_support_;
271 using pointer = MPI_Comm*;
272 void operator()(pointer comm)
const
274 GKO_ASSERT(*comm != MPI_COMM_NULL);
275 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_free(comm));
298 MPI_Status*
get() {
return &this->status_; }
310 template <
typename T>
344 this->req_ = std::exchange(o.req_, MPI_REQUEST_NULL);
351 if (req_ != MPI_REQUEST_NULL) {
352 if (MPI_Request_free(&req_) != MPI_SUCCESS) {
364 MPI_Request*
get() {
return &this->req_; }
375 GKO_ASSERT_NO_MPI_ERRORS(MPI_Wait(&req_,
status.
get()));
392inline std::vector<status>
wait_all(std::vector<request>& req)
394 std::vector<status> stat;
395 for (std::size_t i = 0; i < req.size(); ++i) {
396 stat.emplace_back(req[i].wait());
429 : comm_(), force_host_buffer_(force_host_buffer)
431 this->comm_.reset(
new MPI_Comm(comm));
445 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_split(comm, color, key, &comm_out));
446 this->comm_.reset(
new MPI_Comm(comm_out), comm_deleter{});
460 GKO_ASSERT_NO_MPI_ERRORS(
461 MPI_Comm_split(comm.
get(), color, key, &comm_out));
462 this->comm_.reset(
new MPI_Comm(comm_out), comm_deleter{});
475 bool force_host_buffer =
false)
477 communicator comm_out(MPI_COMM_NULL, force_host_buffer);
478 comm_out.comm_.reset(
new MPI_Comm(comm), comm_deleter{});
507 if (
this != &other) {
508 comm_ = std::exchange(other.comm_,
509 std::make_shared<MPI_Comm>(MPI_COMM_NULL));
510 force_host_buffer_ = other.force_host_buffer_;
520 const MPI_Comm&
get()
const {
return *(this->comm_.get()); }
522 bool force_host_buffer()
const {
return force_host_buffer_; }
529 int size()
const {
return get_num_ranks(); }
536 int rank()
const {
return get_my_rank(); };
570 if (
get() == MPI_COMM_NULL || rhs.get() == MPI_COMM_NULL) {
571 return get() == rhs.get();
574 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_compare(
get(), rhs.get(), &flag));
575 return flag == MPI_IDENT;
592 if (
get() == MPI_COMM_NULL || rhs.get() == MPI_COMM_NULL) {
593 return get() == rhs.get();
596 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_compare(
get(), rhs.get(), &flag));
597 return flag == MPI_CONGRUENT;
606 GKO_ASSERT_NO_MPI_ERRORS(MPI_Barrier(this->
get()));
622 template <
typename SendType>
623 void send(std::shared_ptr<const Executor> exec,
const SendType* send_buffer,
624 const int send_count,
const int destination_rank,
625 const int send_tag)
const
627 auto guard = exec->get_scoped_device_id_guard();
628 GKO_ASSERT_NO_MPI_ERRORS(
630 destination_rank, send_tag, this->
get()));
649 template <
typename SendType>
651 const SendType* send_buffer,
const int send_count,
652 const int destination_rank,
const int send_tag)
const
654 auto guard = exec->get_scoped_device_id_guard();
656 GKO_ASSERT_NO_MPI_ERRORS(
658 destination_rank, send_tag, this->
get(), req.
get()));
677 template <
typename RecvType>
678 status recv(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
679 const int recv_count,
const int source_rank,
680 const int recv_tag)
const
682 auto guard = exec->get_scoped_device_id_guard();
684 GKO_ASSERT_NO_MPI_ERRORS(
686 source_rank, recv_tag, this->
get(), st.
get()));
705 template <
typename RecvType>
707 const int recv_count,
const int source_rank,
708 const int recv_tag)
const
710 auto guard = exec->get_scoped_device_id_guard();
712 GKO_ASSERT_NO_MPI_ERRORS(
714 source_rank, recv_tag, this->
get(), req.
get()));
730 template <
typename BroadcastType>
731 void broadcast(std::shared_ptr<const Executor> exec, BroadcastType* buffer,
732 int count,
int root_rank)
const
734 auto guard = exec->get_scoped_device_id_guard();
735 GKO_ASSERT_NO_MPI_ERRORS(MPI_Bcast(buffer, count,
737 root_rank, this->
get()));
755 template <
typename BroadcastType>
757 BroadcastType* buffer,
int count,
int root_rank)
const
759 auto guard = exec->get_scoped_device_id_guard();
761 GKO_ASSERT_NO_MPI_ERRORS(
763 root_rank, this->
get(), req.
get()));
781 template <
typename ReduceType>
782 void reduce(std::shared_ptr<const Executor> exec,
783 const ReduceType* send_buffer, ReduceType* recv_buffer,
784 int count, MPI_Op operation,
int root_rank)
const
786 auto guard = exec->get_scoped_device_id_guard();
787 GKO_ASSERT_NO_MPI_ERRORS(MPI_Reduce(send_buffer, recv_buffer, count,
789 operation, root_rank, this->
get()));
808 template <
typename ReduceType>
810 const ReduceType* send_buffer, ReduceType* recv_buffer,
811 int count, MPI_Op operation,
int root_rank)
const
813 auto guard = exec->get_scoped_device_id_guard();
815 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ireduce(
817 operation, root_rank, this->
get(), req.
get()));
834 template <
typename ReduceType>
836 ReduceType* recv_buffer,
int count, MPI_Op operation)
const
838 auto guard = exec->get_scoped_device_id_guard();
839 GKO_ASSERT_NO_MPI_ERRORS(MPI_Allreduce(
841 operation, this->
get()));
859 template <
typename ReduceType>
861 ReduceType* recv_buffer,
int count,
862 MPI_Op operation)
const
864 auto guard = exec->get_scoped_device_id_guard();
866 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallreduce(
868 operation, this->
get(), req.
get()));
886 template <
typename ReduceType>
888 const ReduceType* send_buffer, ReduceType* recv_buffer,
889 int count, MPI_Op operation)
const
891 auto guard = exec->get_scoped_device_id_guard();
892 GKO_ASSERT_NO_MPI_ERRORS(MPI_Allreduce(
894 operation, this->
get()));
913 template <
typename ReduceType>
915 const ReduceType* send_buffer, ReduceType* recv_buffer,
916 int count, MPI_Op operation)
const
918 auto guard = exec->get_scoped_device_id_guard();
920 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallreduce(
922 operation, this->
get(), req.
get()));
942 template <
typename SendType,
typename RecvType>
943 void gather(std::shared_ptr<const Executor> exec,
944 const SendType* send_buffer,
const int send_count,
945 RecvType* recv_buffer,
const int recv_count,
948 auto guard = exec->get_scoped_device_id_guard();
949 GKO_ASSERT_NO_MPI_ERRORS(
952 root_rank, this->
get()));
974 template <
typename SendType,
typename RecvType>
976 const SendType* send_buffer,
const int send_count,
977 RecvType* recv_buffer,
const int recv_count,
980 auto guard = exec->get_scoped_device_id_guard();
982 GKO_ASSERT_NO_MPI_ERRORS(MPI_Igather(
1007 template <
typename SendType,
typename RecvType>
1009 const SendType* send_buffer,
const int send_count,
1010 RecvType* recv_buffer,
const int* recv_counts,
1011 const int* displacements,
int root_rank)
const
1013 auto guard = exec->get_scoped_device_id_guard();
1014 GKO_ASSERT_NO_MPI_ERRORS(MPI_Gatherv(
1016 recv_buffer, recv_counts, displacements,
1040 template <
typename SendType,
typename RecvType>
1042 const SendType* send_buffer,
const int send_count,
1043 RecvType* recv_buffer,
const int* recv_counts,
1044 const int* displacements,
int root_rank)
const
1046 auto guard = exec->get_scoped_device_id_guard();
1048 GKO_ASSERT_NO_MPI_ERRORS(MPI_Igatherv(
1050 recv_buffer, recv_counts, displacements,
1071 template <
typename SendType,
typename RecvType>
1073 const SendType* send_buffer,
const int send_count,
1074 RecvType* recv_buffer,
const int recv_count)
const
1076 auto guard = exec->get_scoped_device_id_guard();
1077 GKO_ASSERT_NO_MPI_ERRORS(MPI_Allgather(
1101 template <
typename SendType,
typename RecvType>
1103 const SendType* send_buffer,
const int send_count,
1104 RecvType* recv_buffer,
const int recv_count)
const
1106 auto guard = exec->get_scoped_device_id_guard();
1108 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallgather(
1111 this->
get(), req.
get()));
1130 template <
typename SendType,
typename RecvType>
1131 void scatter(std::shared_ptr<const Executor> exec,
1132 const SendType* send_buffer,
const int send_count,
1133 RecvType* recv_buffer,
const int recv_count,
1134 int root_rank)
const
1136 auto guard = exec->get_scoped_device_id_guard();
1137 GKO_ASSERT_NO_MPI_ERRORS(MPI_Scatter(
1161 template <
typename SendType,
typename RecvType>
1163 const SendType* send_buffer,
const int send_count,
1164 RecvType* recv_buffer,
const int recv_count,
1165 int root_rank)
const
1167 auto guard = exec->get_scoped_device_id_guard();
1169 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iscatter(
1172 this->
get(), req.
get()));
1194 template <
typename SendType,
typename RecvType>
1196 const SendType* send_buffer,
const int* send_counts,
1197 const int* displacements, RecvType* recv_buffer,
1198 const int recv_count,
int root_rank)
const
1200 auto guard = exec->get_scoped_device_id_guard();
1201 GKO_ASSERT_NO_MPI_ERRORS(MPI_Scatterv(
1202 send_buffer, send_counts, displacements,
1227 template <
typename SendType,
typename RecvType>
1229 const SendType* send_buffer,
const int* send_counts,
1230 const int* displacements, RecvType* recv_buffer,
1231 const int recv_count,
int root_rank)
const
1233 auto guard = exec->get_scoped_device_id_guard();
1235 GKO_ASSERT_NO_MPI_ERRORS(
1236 MPI_Iscatterv(send_buffer, send_counts, displacements,
1239 root_rank, this->
get(), req.
get()));
1259 template <
typename RecvType>
1260 void all_to_all(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
1261 const int recv_count)
const
1263 auto guard = exec->get_scoped_device_id_guard();
1264 GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoall(
1288 template <
typename RecvType>
1290 RecvType* recv_buffer,
const int recv_count)
const
1292 auto guard = exec->get_scoped_device_id_guard();
1294 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoall(
1297 this->
get(), req.
get()));
1317 template <
typename SendType,
typename RecvType>
1319 const SendType* send_buffer,
const int send_count,
1320 RecvType* recv_buffer,
const int recv_count)
const
1322 auto guard = exec->get_scoped_device_id_guard();
1323 GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoall(
1347 template <
typename SendType,
typename RecvType>
1349 const SendType* send_buffer,
const int send_count,
1350 RecvType* recv_buffer,
const int recv_count)
const
1352 auto guard = exec->get_scoped_device_id_guard();
1354 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoall(
1357 this->
get(), req.
get()));
1380 template <
typename SendType,
typename RecvType>
1382 const SendType* send_buffer,
const int* send_counts,
1383 const int* send_offsets, RecvType* recv_buffer,
1384 const int* recv_counts,
const int* recv_offsets)
const
1386 this->
all_to_all_v(std::move(exec), send_buffer, send_counts,
1388 recv_buffer, recv_counts, recv_offsets,
1408 const void* send_buffer,
const int* send_counts,
1409 const int* send_offsets, MPI_Datatype send_type,
1410 void* recv_buffer,
const int* recv_counts,
1411 const int* recv_offsets, MPI_Datatype recv_type)
const
1413 auto guard = exec->get_scoped_device_id_guard();
1414 GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoallv(
1415 send_buffer, send_counts, send_offsets, send_type, recv_buffer,
1416 recv_counts, recv_offsets, recv_type, this->
get()));
1439 const void* send_buffer,
const int* send_counts,
1440 const int* send_offsets, MPI_Datatype send_type,
1441 void* recv_buffer,
const int* recv_counts,
1442 const int* recv_offsets,
1443 MPI_Datatype recv_type)
const
1445 auto guard = exec->get_scoped_device_id_guard();
1447 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoallv(
1448 send_buffer, send_counts, send_offsets, send_type, recv_buffer,
1449 recv_counts, recv_offsets, recv_type, this->
get(), req.
get()));
1473 template <
typename SendType,
typename RecvType>
1475 const SendType* send_buffer,
const int* send_counts,
1476 const int* send_offsets, RecvType* recv_buffer,
1477 const int* recv_counts,
1478 const int* recv_offsets)
const
1481 std::move(exec), send_buffer, send_counts, send_offsets,
1500 template <
typename ScanType>
1501 void scan(std::shared_ptr<const Executor> exec,
const ScanType* send_buffer,
1502 ScanType* recv_buffer,
int count, MPI_Op operation)
const
1504 auto guard = exec->get_scoped_device_id_guard();
1505 GKO_ASSERT_NO_MPI_ERRORS(MPI_Scan(send_buffer, recv_buffer, count,
1507 operation, this->
get()));
1526 template <
typename ScanType>
1528 const ScanType* send_buffer, ScanType* recv_buffer,
1529 int count, MPI_Op operation)
const
1531 auto guard = exec->get_scoped_device_id_guard();
1533 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iscan(send_buffer, recv_buffer, count,
1535 operation, this->
get(), req.
get()));
1540 std::shared_ptr<MPI_Comm> comm_;
1541 bool force_host_buffer_;
1543 int get_my_rank()
const
1546 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_rank(
get(), &my_rank));
1550 int get_node_local_rank()
const
1552 MPI_Comm local_comm;
1554 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_split_type(
1555 this->
get(), MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &local_comm));
1556 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_rank(local_comm, &
rank));
1557 MPI_Comm_free(&local_comm);
1561 int get_num_ranks()
const
1564 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_size(this->
get(), &size));
1594template <
typename ValueType>
/**
 * Selects which MPI call the window constructor uses to create the window.
 */
enum class create_type {
    /// The window is created with MPI_Win_allocate.
    allocate = 1,
    /// The window is created with MPI_Win_create.
    create = 2,
    /// The window is created with MPI_Win_create_dynamic.
    dynamic_create = 3
};
1633 window_ = std::exchange(other.window_, MPI_WIN_NULL);
1648 window(std::shared_ptr<const Executor> exec, ValueType* base,
int num_elems,
1649 const communicator& comm,
const int disp_unit =
sizeof(ValueType),
1650 MPI_Info input_info = MPI_INFO_NULL,
1653 auto guard = exec->get_scoped_device_id_guard();
1654 unsigned size = num_elems *
sizeof(ValueType);
1655 if (c_type == create_type::create) {
1656 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_create(
1657 base, size, disp_unit, input_info, comm.
get(), &this->window_));
1658 }
else if (c_type == create_type::dynamic_create) {
1659 GKO_ASSERT_NO_MPI_ERRORS(
1660 MPI_Win_create_dynamic(input_info, comm.
get(), &this->window_));
1661 }
else if (c_type == create_type::allocate) {
1662 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_allocate(
1663 size, disp_unit, input_info, comm.
get(), base, &this->window_));
1665 GKO_NOT_IMPLEMENTED;
1684 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_fence(assert, this->window_));
1696 int assert = 0)
const
1698 if (lock_t == lock_type::shared) {
1699 GKO_ASSERT_NO_MPI_ERRORS(
1700 MPI_Win_lock(MPI_LOCK_SHARED, rank, assert, this->window_));
1701 }
else if (lock_t == lock_type::exclusive) {
1702 GKO_ASSERT_NO_MPI_ERRORS(
1703 MPI_Win_lock(MPI_LOCK_EXCLUSIVE, rank, assert, this->window_));
1705 GKO_NOT_IMPLEMENTED;
1717 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_unlock(rank, this->window_));
1728 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_lock_all(assert, this->window_));
1737 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_unlock_all(this->window_));
1748 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush(rank, this->window_));
1759 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_local(rank, this->window_));
1768 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_all(this->window_));
1777 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_local_all(this->window_));
1783 void sync()
const { GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_sync(this->window_)); }
1790 if (this->window_ && this->window_ != MPI_WIN_NULL) {
1791 MPI_Win_free(&this->window_);
1805 template <
typename PutType>
1806 void put(std::shared_ptr<const Executor> exec,
const PutType* origin_buffer,
1807 const int origin_count,
const int target_rank,
1808 const unsigned int target_disp,
const int target_count)
const
1810 auto guard = exec->get_scoped_device_id_guard();
1811 GKO_ASSERT_NO_MPI_ERRORS(
1813 target_rank, target_disp, target_count,
1829 template <
typename PutType>
1831 const PutType* origin_buffer,
const int origin_count,
1832 const int target_rank,
const unsigned int target_disp,
1833 const int target_count)
const
1835 auto guard = exec->get_scoped_device_id_guard();
1837 GKO_ASSERT_NO_MPI_ERRORS(MPI_Rput(
1839 target_rank, target_disp, target_count,
1855 template <
typename PutType>
1857 const PutType* origin_buffer,
const int origin_count,
1858 const int target_rank,
const unsigned int target_disp,
1859 const int target_count, MPI_Op operation)
const
1861 auto guard = exec->get_scoped_device_id_guard();
1862 GKO_ASSERT_NO_MPI_ERRORS(MPI_Accumulate(
1864 target_rank, target_disp, target_count,
1881 template <
typename PutType>
1883 const PutType* origin_buffer,
const int origin_count,
1884 const int target_rank,
const unsigned int target_disp,
1885 const int target_count, MPI_Op operation)
const
1887 auto guard = exec->get_scoped_device_id_guard();
1889 GKO_ASSERT_NO_MPI_ERRORS(MPI_Raccumulate(
1891 target_rank, target_disp, target_count,
1907 template <
typename GetType>
1908 void get(std::shared_ptr<const Executor> exec, GetType* origin_buffer,
1909 const int origin_count,
const int target_rank,
1910 const unsigned int target_disp,
const int target_count)
const
1912 auto guard = exec->get_scoped_device_id_guard();
1913 GKO_ASSERT_NO_MPI_ERRORS(
1915 target_rank, target_disp, target_count,
1931 template <
typename GetType>
1932 request r_get(std::shared_ptr<const Executor> exec, GetType* origin_buffer,
1933 const int origin_count,
const int target_rank,
1934 const unsigned int target_disp,
const int target_count)
const
1936 auto guard = exec->get_scoped_device_id_guard();
1938 GKO_ASSERT_NO_MPI_ERRORS(MPI_Rget(
1940 target_rank, target_disp, target_count,
1958 template <
typename GetType>
1960 GetType* origin_buffer,
const int origin_count,
1961 GetType* result_buffer,
const int result_count,
1962 const int target_rank,
const unsigned int target_disp,
1963 const int target_count, MPI_Op operation)
const
1965 auto guard = exec->get_scoped_device_id_guard();
1966 GKO_ASSERT_NO_MPI_ERRORS(MPI_Get_accumulate(
1969 target_rank, target_disp, target_count,
1988 template <
typename GetType>
1990 GetType* origin_buffer,
const int origin_count,
1991 GetType* result_buffer,
const int result_count,
1992 const int target_rank,
1993 const unsigned int target_disp,
1994 const int target_count, MPI_Op operation)
const
1996 auto guard = exec->get_scoped_device_id_guard();
1998 GKO_ASSERT_NO_MPI_ERRORS(MPI_Rget_accumulate(
2001 target_rank, target_disp, target_count,
2017 template <
typename GetType>
2019 GetType* origin_buffer, GetType* result_buffer,
2020 const int target_rank,
const unsigned int target_disp,
2021 MPI_Op operation)
const
2023 auto guard = exec->get_scoped_device_id_guard();
2024 GKO_ASSERT_NO_MPI_ERRORS(MPI_Fetch_and_op(
2026 target_rank, target_disp, operation, this->
get_window()));
status recv(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count, const int source_rank, const int recv_tag) const
Definition mpi.hpp:678
void scatter_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *displacements, RecvType *recv_buffer, const int recv_count, int root_rank) const
Definition mpi.hpp:1195
communicator(const communicator &other)=default
request i_broadcast(std::shared_ptr< const Executor > exec, BroadcastType *buffer, int count, int root_rank) const
Definition mpi.hpp:756
void gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
Definition mpi.hpp:943
request i_recv(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count, const int source_rank, const int recv_tag) const
Definition mpi.hpp:706
request i_scatter_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *displacements, RecvType *recv_buffer, const int recv_count, int root_rank) const
Definition mpi.hpp:1228
void all_to_all(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
Definition mpi.hpp:1318
bool is_identical(const communicator &rhs) const
Definition mpi.hpp:568
request i_all_to_all(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
Definition mpi.hpp:1348
request i_all_to_all_v(std::shared_ptr< const Executor > exec, const void *send_buffer, const int *send_counts, const int *send_offsets, MPI_Datatype send_type, void *recv_buffer, const int *recv_counts, const int *recv_offsets, MPI_Datatype recv_type) const
Definition mpi.hpp:1438
bool operator!=(const communicator &rhs) const
Definition mpi.hpp:557
void scatter(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
Definition mpi.hpp:1131
void synchronize() const
Definition mpi.hpp:604
int rank() const
Definition mpi.hpp:536
request i_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation, int root_rank) const
Definition mpi.hpp:809
int size() const
Definition mpi.hpp:529
void send(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, const int destination_rank, const int send_tag) const
Definition mpi.hpp:623
request i_all_to_all_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *send_offsets, RecvType *recv_buffer, const int *recv_counts, const int *recv_offsets) const
Definition mpi.hpp:1474
request i_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
Definition mpi.hpp:975
void all_to_all(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count) const
Definition mpi.hpp:1260
void all_to_all_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *send_offsets, RecvType *recv_buffer, const int *recv_counts, const int *recv_offsets) const
Definition mpi.hpp:1381
request i_all_reduce(std::shared_ptr< const Executor > exec, ReduceType *recv_buffer, int count, MPI_Op operation) const
Definition mpi.hpp:860
request i_all_to_all(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count) const
Definition mpi.hpp:1289
void all_to_all_v(std::shared_ptr< const Executor > exec, const void *send_buffer, const int *send_counts, const int *send_offsets, MPI_Datatype send_type, void *recv_buffer, const int *recv_counts, const int *recv_offsets, MPI_Datatype recv_type) const
Definition mpi.hpp:1407
int node_local_rank() const
Definition mpi.hpp:543
void broadcast(std::shared_ptr< const Executor > exec, BroadcastType *buffer, int count, int root_rank) const
Definition mpi.hpp:731
static communicator create_owning(const MPI_Comm &comm, bool force_host_buffer=false)
Definition mpi.hpp:474
const MPI_Comm & get() const
Definition mpi.hpp:520
communicator(const MPI_Comm &comm, int color, int key)
Definition mpi.hpp:442
void all_reduce(std::shared_ptr< const Executor > exec, ReduceType *recv_buffer, int count, MPI_Op operation) const
Definition mpi.hpp:835
communicator & operator=(const communicator &other)=default
void all_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
Definition mpi.hpp:1072
communicator & operator=(communicator &&other)
Definition mpi.hpp:505
request i_all_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
Definition mpi.hpp:1102
bool operator==(const communicator &rhs) const
Definition mpi.hpp:550
bool is_congruent(const communicator &rhs) const
Definition mpi.hpp:590
void all_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation) const
Definition mpi.hpp:887
request i_gather_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int *recv_counts, const int *displacements, int root_rank) const
Definition mpi.hpp:1041
request i_all_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation) const
Definition mpi.hpp:914
communicator(const MPI_Comm &comm, bool force_host_buffer=false)
Definition mpi.hpp:428
request i_scan(std::shared_ptr< const Executor > exec, const ScanType *send_buffer, ScanType *recv_buffer, int count, MPI_Op operation) const
Definition mpi.hpp:1527
void reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation, int root_rank) const
Definition mpi.hpp:782
request i_scatter(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
Definition mpi.hpp:1162
void scan(std::shared_ptr< const Executor > exec, const ScanType *send_buffer, ScanType *recv_buffer, int count, MPI_Op operation) const
Definition mpi.hpp:1501
void gather_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int *recv_counts, const int *displacements, int root_rank) const
Definition mpi.hpp:1008
request i_send(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, const int destination_rank, const int send_tag) const
Definition mpi.hpp:650
communicator(communicator &&other)
Definition mpi.hpp:495
communicator(const communicator &comm, int color, int key)
Definition mpi.hpp:457
MPI_Datatype get() const
Definition mpi.hpp:178
contiguous_type(int count, MPI_Datatype old_type)
Definition mpi.hpp:117
contiguous_type()
Definition mpi.hpp:126
contiguous_type(const contiguous_type &)=delete
contiguous_type(contiguous_type &&other) noexcept
Definition mpi.hpp:143
contiguous_type & operator=(contiguous_type &&other) noexcept
Definition mpi.hpp:155
contiguous_type & operator=(const contiguous_type &)=delete
~contiguous_type()
Definition mpi.hpp:166
~environment()
Definition mpi.hpp:249
int get_provided_thread_support() const
Definition mpi.hpp:227
environment(int &argc, char **&argv, const thread_type thread_t=thread_type::serialized)
Definition mpi.hpp:237
request()
Definition mpi.hpp:333
MPI_Request * get()
Definition mpi.hpp:364
status wait()
Definition mpi.hpp:372
void get(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Definition mpi.hpp:1908
request r_put(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Definition mpi.hpp:1830
window()
Definition mpi.hpp:1610
void get_accumulate(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, GetType *result_buffer, const int result_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
Definition mpi.hpp:1959
void put(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Definition mpi.hpp:1806
~window()
Definition mpi.hpp:1788
lock_type
Definition mpi.hpp:1605
window & operator=(window &&other)
Definition mpi.hpp:1631
request r_accumulate(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
Definition mpi.hpp:1882
request r_get_accumulate(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, GetType *result_buffer, const int result_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
Definition mpi.hpp:1989
void fetch_and_op(std::shared_ptr< const Executor > exec, GetType *origin_buffer, GetType *result_buffer, const int target_rank, const unsigned int target_disp, MPI_Op operation) const
Definition mpi.hpp:2018
void sync() const
Definition mpi.hpp:1783
void unlock(int rank) const
Definition mpi.hpp:1715
void fence(int assert=0) const
Definition mpi.hpp:1682
void flush(int rank) const
Definition mpi.hpp:1746
void unlock_all() const
Definition mpi.hpp:1735
create_type
Definition mpi.hpp:1600
window(std::shared_ptr< const Executor > exec, ValueType *base, int num_elems, const communicator &comm, const int disp_unit=sizeof(ValueType), MPI_Info input_info=MPI_INFO_NULL, create_type c_type=create_type::create)
Definition mpi.hpp:1648
void accumulate(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
Definition mpi.hpp:1856
void lock_all(int assert=0) const
Definition mpi.hpp:1726
void lock(int rank, lock_type lock_t=lock_type::shared, int assert=0) const
Definition mpi.hpp:1695
void flush_all_local() const
Definition mpi.hpp:1775
window(window &&other)
Definition mpi.hpp:1622
void flush_local(int rank) const
Definition mpi.hpp:1757
MPI_Win get_window() const
Definition mpi.hpp:1674
request r_get(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Definition mpi.hpp:1932
void flush_all() const
Definition mpi.hpp:1766
int map_rank_to_device_id(MPI_Comm comm, int num_devices)
bool requires_host_buffer(const std::shared_ptr< const Executor > &exec, const communicator &comm)
double get_walltime()
Definition mpi.hpp:1583
constexpr bool is_gpu_aware()
Definition mpi.hpp:42
thread_type
Definition mpi.hpp:189
std::vector< status > wait_all(std::vector< request > &req)
Definition mpi.hpp:392
The Ginkgo namespace.
Definition abstract_factory.hpp:20
int get_count(const T *data) const
Definition mpi.hpp:311
status()
Definition mpi.hpp:291
MPI_Status * get()
Definition mpi.hpp:298