ginkgo/core/solver/solver_base.hpp Source File

ginkgo/core/solver/solver_base.hpp Source File

Reference API: ginkgo/core/solver/solver_base.hpp Source File
Reference API
solver_base.hpp
1// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_SOLVER_SOLVER_BASE_HPP_
6#define GKO_PUBLIC_CORE_SOLVER_SOLVER_BASE_HPP_
7
8
9#include <memory>
10#include <type_traits>
11#include <utility>
12
13#include <ginkgo/core/base/lin_op.hpp>
14#include <ginkgo/core/base/math.hpp>
15#include <ginkgo/core/log/logger.hpp>
16#include <ginkgo/core/matrix/dense.hpp>
17#include <ginkgo/core/matrix/identity.hpp>
18#include <ginkgo/core/solver/workspace.hpp>
19#include <ginkgo/core/stop/combined.hpp>
20#include <ginkgo/core/stop/criterion.hpp>
21
22
23GKO_BEGIN_DISABLE_DEPRECATION_WARNINGS
24
25
26namespace gko {
27namespace solver {
28
29
37 zero,
41 rhs,
46};
47
48
49namespace multigrid {
50namespace detail {
51
52
// Forward declaration only. The solver base classes below contain
// `friend class multigrid::detail::MultigridState;`, which needs this name to
// be declared first. No definition of the class appears in this header.
53class MultigridState;
54
55
56} // namespace detail
57} // namespace multigrid
58
59
66protected:
67 friend class multigrid::detail::MultigridState;
68
82 virtual void apply_with_initial_guess(const LinOp* b, LinOp* x,
83 initial_guess_mode guess) const = 0;
84
86 initial_guess_mode guess) const
87 {
88 apply_with_initial_guess(b.get(), x.get(), guess);
89 }
90
103 virtual void apply_with_initial_guess(const LinOp* alpha, const LinOp* b,
104 const LinOp* beta, LinOp* x,
105 initial_guess_mode guess) const = 0;
106
107
112 initial_guess_mode guess) const
113 {
114 apply_with_initial_guess(alpha.get(), b.get(), beta.get(), x.get(),
115 guess);
116 }
117
124
133 : guess_(guess)
134 {}
135
141 void set_default_initial_guess(initial_guess_mode guess) { guess_ = guess; }
142
143private:
144 initial_guess_mode guess_;
145};
146
147
160template <typename DerivedType>
162protected:
163 friend class multigrid::detail::MultigridState;
164
167 : ApplyWithInitialGuess(guess)
168 {}
169
175 initial_guess_mode guess) const override
176 {
177 self()->template log<log::Logger::linop_apply_started>(self(), b, x);
178 auto exec = self()->get_executor();
179 GKO_ASSERT_CONFORMANT(self(), b);
180 GKO_ASSERT_EQUAL_ROWS(self(), x);
181 GKO_ASSERT_EQUAL_COLS(b, x);
183 make_temporary_clone(exec, x).get(),
184 guess);
185 self()->template log<log::Logger::linop_apply_completed>(self(), b, x);
186 }
187
192 void apply_with_initial_guess(const LinOp* alpha, const LinOp* b,
193 const LinOp* beta, LinOp* x,
194 initial_guess_mode guess) const override
195 {
196 self()->template log<log::Logger::linop_advanced_apply_started>(
197 self(), alpha, b, beta, x);
198 auto exec = self()->get_executor();
199 GKO_ASSERT_CONFORMANT(self(), b);
200 GKO_ASSERT_EQUAL_ROWS(self(), x);
201 GKO_ASSERT_EQUAL_COLS(b, x);
202 GKO_ASSERT_EQUAL_DIMENSIONS(alpha, dim<2>(1, 1));
203 GKO_ASSERT_EQUAL_DIMENSIONS(beta, dim<2>(1, 1));
205 make_temporary_clone(exec, alpha).get(),
206 make_temporary_clone(exec, b).get(),
207 make_temporary_clone(exec, beta).get(),
208 make_temporary_clone(exec, x).get(), guess);
209 self()->template log<log::Logger::linop_advanced_apply_completed>(
210 self(), alpha, b, beta, x);
211 }
212
213 // TODO: should we provide the default implementation?
219 const LinOp* b, LinOp* x, initial_guess_mode guess) const = 0;
220
226 const LinOp* alpha, const LinOp* b, const LinOp* beta, LinOp* x,
227 initial_guess_mode guess) const = 0;
228
229 GKO_ENABLE_SELF(DerivedType);
230};
231
232
237template <typename Solver>
239 // number of vectors used by this workspace
240 static int num_vectors(const Solver&) { return 0; }
241 // number of arrays used by this workspace
242 static int num_arrays(const Solver&) { return 0; }
243 // array containing the num_vectors names for the workspace vectors
244 static std::vector<std::string> op_names(const Solver&) { return {}; }
245 // array containing the num_arrays names for the workspace vectors
246 static std::vector<std::string> array_names(const Solver&) { return {}; }
247 // array containing all scalar vectors (independent of problem size)
248 static std::vector<int> scalars(const Solver&) { return {}; }
249 // array containing all vectors (dependent on problem size)
250 static std::vector<int> vectors(const Solver&) { return {}; }
251};
252
253
269template <typename DerivedType>
271public:
278 void set_preconditioner(std::shared_ptr<const LinOp> new_precond) override
279 {
280 auto exec = self()->get_executor();
281 if (new_precond) {
282 GKO_ASSERT_EQUAL_DIMENSIONS(self(), new_precond);
283 GKO_ASSERT_IS_SQUARE_MATRIX(new_precond);
284 if (new_precond->get_executor() != exec) {
285 new_precond = gko::clone(exec, new_precond);
286 }
287 }
289 }
290
296 {
297 if (&other != this) {
299 }
300 return *this;
301 }
302
309 {
310 if (&other != this) {
311 set_preconditioner(other.get_preconditioner());
312 other.set_preconditioner(nullptr);
313 }
314 return *this;
315 }
316
317 EnablePreconditionable() = default;
318
319 EnablePreconditionable(std::shared_ptr<const LinOp> preconditioner)
320 {
321 set_preconditioner(std::move(preconditioner));
322 }
323
328 {
329 *this = other;
330 }
331
337 {
338 *this = std::move(other);
339 }
340
341private:
342 DerivedType* self() { return static_cast<DerivedType*>(this); }
343
344 const DerivedType* self() const
345 {
346 return static_cast<const DerivedType*>(this);
347 }
348};
349
350
351namespace detail {
352
353
361class SolverBaseLinOp {
362public:
363 SolverBaseLinOp(std::shared_ptr<const Executor> exec)
364 : workspace_{std::move(exec)}
365 {}
366
367 virtual ~SolverBaseLinOp() = default;
368
374 std::shared_ptr<const LinOp> get_system_matrix() const
375 {
376 return system_matrix_;
377 }
378
379 const LinOp* get_workspace_op(int vector_id) const
380 {
381 return workspace_.get_op(vector_id);
382 }
383
384 virtual int get_num_workspace_ops() const { return 0; }
385
386 virtual std::vector<std::string> get_workspace_op_names() const
387 {
388 return {};
389 }
390
395 virtual std::vector<int> get_workspace_scalars() const { return {}; }
396
401 virtual std::vector<int> get_workspace_vectors() const { return {}; }
402
403protected:
404 void set_system_matrix_base(std::shared_ptr<const LinOp> system_matrix)
405 {
406 system_matrix_ = std::move(system_matrix);
407 }
408
409 void set_workspace_size(int num_operators, int num_arrays) const
410 {
411 workspace_.set_size(num_operators, num_arrays);
412 }
413
414 template <typename LinOpType>
415 LinOpType* create_workspace_op(int vector_id, gko::dim<2> size) const
416 {
417 return workspace_.template create_or_get_op<LinOpType>(
418 vector_id,
419 [&] {
420 return LinOpType::create(this->workspace_.get_executor(), size);
421 },
422 typeid(LinOpType), size, size[1]);
423 }
424
425 template <typename LinOpType>
426 LinOpType* create_workspace_op_with_config_of(int vector_id,
427 const LinOpType* vec) const
428 {
429 return workspace_.template create_or_get_op<LinOpType>(
430 vector_id, [&] { return LinOpType::create_with_config_of(vec); },
431 typeid(*vec), vec->get_size(), vec->get_stride());
432 }
433
434 template <typename LinOpType>
435 LinOpType* create_workspace_op_with_type_of(int vector_id,
436 const LinOpType* vec,
437 dim<2> size) const
438 {
439 return workspace_.template create_or_get_op<LinOpType>(
440 vector_id,
441 [&] {
442 return LinOpType::create_with_type_of(
443 vec, workspace_.get_executor(), size, size[1]);
444 },
445 typeid(*vec), size, size[1]);
446 }
447
448 template <typename LinOpType>
449 LinOpType* create_workspace_op_with_type_of(int vector_id,
450 const LinOpType* vec,
451 dim<2> global_size,
452 dim<2> local_size) const
453 {
454 return workspace_.template create_or_get_op<LinOpType>(
455 vector_id,
456 [&] {
457 return LinOpType::create_with_type_of(
458 vec, workspace_.get_executor(), global_size, local_size,
459 local_size[1]);
460 },
461 typeid(*vec), global_size, local_size[1]);
462 }
463
464 template <typename ValueType>
465 matrix::Dense<ValueType>* create_workspace_scalar(int vector_id,
466 size_type size) const
467 {
468 return workspace_.template create_or_get_op<matrix::Dense<ValueType>>(
469 vector_id,
470 [&] {
472 workspace_.get_executor(), dim<2>{1, size});
473 },
474 typeid(matrix::Dense<ValueType>), gko::dim<2>{1, size}, size);
475 }
476
477 template <typename ValueType>
478 array<ValueType>& create_workspace_array(int array_id, size_type size) const
479 {
480 return workspace_.template create_or_get_array<ValueType>(array_id,
481 size);
482 }
483
484 template <typename ValueType>
485 array<ValueType>& create_workspace_array(int array_id) const
486 {
487 return workspace_.template init_or_get_array<ValueType>(array_id);
488 }
489
490private:
491 mutable detail::workspace workspace_;
492
493 std::shared_ptr<const LinOp> system_matrix_;
494};
495
496
497} // namespace detail
498
499
/**
 * Deprecated, typed wrapper around detail::SolverBaseLinOp.
 *
 * The system matrix is stored in the base class as a
 * `std::shared_ptr<const LinOp>`. This wrapper exposes it as `MatrixType`.
 * All base-class constructors are inherited (the visible one takes the
 * executor used for the workspace).
 *
 * @tparam MatrixType  static type the system matrix is exposed as
 *                     (presumably always a LinOp subclass; TODO confirm)
 */
500template <typename MatrixType>
501class
502 // clang-format off
503 GKO_DEPRECATED("This class will be replaced by the template-less detail::SolverBaseLinOp in a future release") SolverBase
504 // clang-format on
505 : public detail::SolverBaseLinOp {
506public:
507 using detail::SolverBaseLinOp::SolverBaseLinOp;
508
    /**
     * Returns the system matrix, downcast to MatrixType.
     *
     * The result comes from std::dynamic_pointer_cast. It is therefore empty
     * when no system matrix is set, and also when the stored LinOp is not
     * actually a MatrixType.
     */
516 std::shared_ptr<const MatrixType> get_system_matrix() const
517 {
518 return std::dynamic_pointer_cast<const MatrixType>(
519 SolverBaseLinOp::get_system_matrix());
520 }
521
522protected:
    /**
     * Stores system_matrix in the base class.
     *
     * This overload hides SolverBaseLinOp::set_system_matrix_base for code
     * that calls it through SolverBase, restricting the argument to
     * MatrixType. The argument must convert implicitly to
     * std::shared_ptr<const LinOp>, i.e. MatrixType is assumed to derive
     * from LinOp.
     */
523 void set_system_matrix_base(std::shared_ptr<const MatrixType> system_matrix)
524 {
525 SolverBaseLinOp::set_system_matrix_base(std::move(system_matrix));
526 }
527};
528
529
538template <typename DerivedType, typename MatrixType = LinOp>
539class EnableSolverBase : public SolverBase<MatrixType> {
540public:
546 {
547 if (&other != this) {
548 set_system_matrix(other.get_system_matrix());
549 }
550 return *this;
551 }
552
558 {
559 if (&other != this) {
560 set_system_matrix(other.get_system_matrix());
561 other.set_system_matrix(nullptr);
562 }
563 return *this;
564 }
565
566 EnableSolverBase() : SolverBase<MatrixType>{self()->get_executor()} {}
567
568 EnableSolverBase(std::shared_ptr<const MatrixType> system_matrix)
569 : SolverBase<MatrixType>{self()->get_executor()}
570 {
571 set_system_matrix(std::move(system_matrix));
572 }
573
578 : SolverBase<MatrixType>{other.self()->get_executor()}
579 {
580 *this = other;
581 }
582
588 : SolverBase<MatrixType>{other.self()->get_executor()}
589 {
590 *this = std::move(other);
591 }
592
593 int get_num_workspace_ops() const override
594 {
595 using traits = workspace_traits<DerivedType>;
596 return traits::num_vectors(*self());
597 }
598
599 std::vector<std::string> get_workspace_op_names() const override
600 {
601 using traits = workspace_traits<DerivedType>;
602 return traits::op_names(*self());
603 }
604
609 std::vector<int> get_workspace_scalars() const override
610 {
611 using traits = workspace_traits<DerivedType>;
612 return traits::scalars(*self());
613 }
614
619 std::vector<int> get_workspace_vectors() const override
620 {
621 using traits = workspace_traits<DerivedType>;
622 return traits::vectors(*self());
623 }
624
625protected:
626 void set_system_matrix(std::shared_ptr<const MatrixType> new_system_matrix)
627 {
628 auto exec = self()->get_executor();
629 if (new_system_matrix) {
630 GKO_ASSERT_EQUAL_DIMENSIONS(self(), new_system_matrix);
631 GKO_ASSERT_IS_SQUARE_MATRIX(new_system_matrix);
632 if (new_system_matrix->get_executor() != exec) {
633 new_system_matrix = gko::clone(exec, new_system_matrix);
634 }
635 }
636 this->set_system_matrix_base(new_system_matrix);
637 }
638
639 void setup_workspace() const
640 {
641 using traits = workspace_traits<DerivedType>;
642 this->set_workspace_size(traits::num_vectors(*self()),
643 traits::num_arrays(*self()));
644 }
645
646private:
647 DerivedType* self() { return static_cast<DerivedType*>(this); }
648
649 const DerivedType* self() const
650 {
651 return static_cast<const DerivedType*>(this);
652 }
653};
654
655
662public:
668 std::shared_ptr<const stop::CriterionFactory> get_stop_criterion_factory()
669 const
670 {
671 return stop_factory_;
672 }
673
680 std::shared_ptr<const stop::CriterionFactory> new_stop_factory)
681 {
682 stop_factory_ = new_stop_factory;
683 }
684
685private:
686 std::shared_ptr<const stop::CriterionFactory> stop_factory_;
687};
688
689
698template <typename DerivedType>
700public:
706 {
707 if (&other != this) {
709 }
710 return *this;
711 }
712
719 {
720 if (&other != this) {
721 set_stop_criterion_factory(other.get_stop_criterion_factory());
722 other.set_stop_criterion_factory(nullptr);
723 }
724 return *this;
725 }
726
727 EnableIterativeBase() = default;
728
730 std::shared_ptr<const stop::CriterionFactory> stop_factory)
731 {
732 set_stop_criterion_factory(std::move(stop_factory));
733 }
734
738 EnableIterativeBase(const EnableIterativeBase& other) { *this = other; }
739
745 {
746 *this = std::move(other);
747 }
748
750 std::shared_ptr<const stop::CriterionFactory> new_stop_factory) override
751 {
752 auto exec = self()->get_executor();
753 if (new_stop_factory && new_stop_factory->get_executor() != exec) {
754 new_stop_factory = gko::clone(exec, new_stop_factory);
755 }
757 }
758
759private:
760 DerivedType* self() { return static_cast<DerivedType*>(this); }
761
762 const DerivedType* self() const
763 {
764 return static_cast<const DerivedType*>(this);
765 }
766};
767
768
778template <typename ValueType, typename DerivedType>
780 : public EnableSolverBase<DerivedType>,
781 public EnableIterativeBase<DerivedType>,
782 public EnablePreconditionable<DerivedType> {
783public:
785
787 std::shared_ptr<const LinOp> system_matrix,
788 std::shared_ptr<const stop::CriterionFactory> stop_factory,
789 std::shared_ptr<const LinOp> preconditioner)
790 : EnableSolverBase<DerivedType>(std::move(system_matrix)),
791 EnableIterativeBase<DerivedType>{std::move(stop_factory)},
792 EnablePreconditionable<DerivedType>{std::move(preconditioner)}
793 {}
794
795 template <typename FactoryParameters>
797 std::shared_ptr<const LinOp> system_matrix,
798 const FactoryParameters& params)
800 system_matrix, stop::combine(params.criteria),
801 generate_preconditioner(system_matrix, params)}
802 {}
803
804private:
805 template <typename FactoryParameters>
806 static std::shared_ptr<const LinOp> generate_preconditioner(
807 std::shared_ptr<const LinOp> system_matrix,
808 const FactoryParameters& params)
809 {
810 if (params.generated_preconditioner) {
811 return params.generated_preconditioner;
812 } else if (params.preconditioner) {
813 return params.preconditioner->generate(system_matrix);
814 } else {
816 system_matrix->get_executor(), system_matrix->get_size());
817 }
818 }
819};
820
821
822template <typename Parameters, typename Factory>
824 : enable_parameters_type<Parameters, Factory> {
828 std::vector<std::shared_ptr<const stop::CriterionFactory>>
829 GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(criteria);
830};
831
832
833template <typename Parameters, typename Factory>
835 : enable_iterative_solver_factory_parameters<Parameters, Factory> {
840 std::shared_ptr<const LinOpFactory> GKO_DEFERRED_FACTORY_PARAMETER(
842
847 std::shared_ptr<const LinOp> GKO_FACTORY_PARAMETER_SCALAR(
849};
850
851
852} // namespace solver
853} // namespace gko
854
855
856GKO_END_DISABLE_DEPRECATION_WARNINGS
857
858
859#endif // GKO_PUBLIC_CORE_SOLVER_SOLVER_BASE_HPP_
Definition lin_op.hpp:117
std::shared_ptr< const Executor > get_executor() const noexcept
Definition polymorphic_object.hpp:243
Definition lin_op.hpp:681
virtual void set_preconditioner(std::shared_ptr< const LinOp > new_precond)
Definition lin_op.hpp:701
virtual std::shared_ptr< const LinOp > get_preconditioner() const
Definition lin_op.hpp:690
Definition abstract_factory.hpp:211
size_type get_stride() const noexcept
Definition dense.hpp:869
static std::unique_ptr< Dense > create(std::shared_ptr< const Executor > exec, const dim< 2 > &size={}, size_type stride=0)
static std::unique_ptr< Identity > create(std::shared_ptr< const Executor > exec, dim< 2 > size)
Definition utils_helper.hpp:41
T * get() const
Definition utils_helper.hpp:75
Definition solver_base.hpp:65
void set_default_initial_guess(initial_guess_mode guess)
Definition solver_base.hpp:141
initial_guess_mode get_default_initial_guess() const
Definition solver_base.hpp:123
virtual void apply_with_initial_guess(const LinOp *alpha, const LinOp *b, const LinOp *beta, LinOp *x, initial_guess_mode guess) const =0
ApplyWithInitialGuess(initial_guess_mode guess=initial_guess_mode::provided)
Definition solver_base.hpp:131
virtual void apply_with_initial_guess(const LinOp *b, LinOp *x, initial_guess_mode guess) const =0
Definition solver_base.hpp:161
void apply_with_initial_guess(const LinOp *b, LinOp *x, initial_guess_mode guess) const override
Definition solver_base.hpp:174
virtual void apply_with_initial_guess_impl(const LinOp *alpha, const LinOp *b, const LinOp *beta, LinOp *x, initial_guess_mode guess) const =0
virtual void apply_with_initial_guess_impl(const LinOp *b, LinOp *x, initial_guess_mode guess) const =0
void apply_with_initial_guess(const LinOp *alpha, const LinOp *b, const LinOp *beta, LinOp *x, initial_guess_mode guess) const override
Definition solver_base.hpp:192
Definition solver_base.hpp:699
EnableIterativeBase & operator=(EnableIterativeBase &&other)
Definition solver_base.hpp:718
void set_stop_criterion_factory(std::shared_ptr< const stop::CriterionFactory > new_stop_factory) override
Definition solver_base.hpp:749
EnableIterativeBase(EnableIterativeBase &&other)
Definition solver_base.hpp:744
EnableIterativeBase(const EnableIterativeBase &other)
Definition solver_base.hpp:738
EnableIterativeBase & operator=(const EnableIterativeBase &other)
Definition solver_base.hpp:705
Definition solver_base.hpp:270
EnablePreconditionable(const EnablePreconditionable &other)
Definition solver_base.hpp:327
EnablePreconditionable & operator=(EnablePreconditionable &&other)
Definition solver_base.hpp:308
EnablePreconditionable(EnablePreconditionable &&other)
Definition solver_base.hpp:336
EnablePreconditionable & operator=(const EnablePreconditionable &other)
Definition solver_base.hpp:295
void set_preconditioner(std::shared_ptr< const LinOp > new_precond) override
Definition solver_base.hpp:278
Definition solver_base.hpp:539
EnableSolverBase(EnableSolverBase &&other)
Definition solver_base.hpp:587
std::vector< int > get_workspace_vectors() const override
Definition solver_base.hpp:619
std::vector< int > get_workspace_scalars() const override
Definition solver_base.hpp:609
EnableSolverBase(const EnableSolverBase &other)
Definition solver_base.hpp:577
EnableSolverBase & operator=(EnableSolverBase &&other)
Definition solver_base.hpp:557
EnableSolverBase & operator=(const EnableSolverBase &other)
Definition solver_base.hpp:545
Definition solver_base.hpp:661
std::shared_ptr< const stop::CriterionFactory > get_stop_criterion_factory() const
Definition solver_base.hpp:668
virtual void set_stop_criterion_factory(std::shared_ptr< const stop::CriterionFactory > new_stop_factory)
Definition solver_base.hpp:679
Definition solver_base.hpp:505
std::shared_ptr< const MatrixType > get_system_matrix() const
Definition solver_base.hpp:516
#define GKO_FACTORY_PARAMETER_SCALAR(_name, _default)
Definition abstract_factory.hpp:445
std::shared_ptr< const CriterionFactory > combine(FactoryContainer &&factories)
Definition combined.hpp:109
initial_guess_mode
Definition solver_base.hpp:33
The Ginkgo namespace.
Definition abstract_factory.hpp:20
detail::cloned_type< Pointer > clone(const Pointer &p)
Definition utils_helper.hpp:173
detail::temporary_clone< detail::pointee< Ptr > > make_temporary_clone(std::shared_ptr< const Executor > exec, Ptr &&ptr)
Definition temporary_clone.hpp:208
STL namespace.
Definition dim.hpp:26
std::vector< std::shared_ptr< const stop::CriterionFactory > > criteria
Definition solver_base.hpp:829
std::shared_ptr< const LinOp > generated_preconditioner
Definition solver_base.hpp:848
std::shared_ptr< const LinOpFactory > preconditioner
Definition solver_base.hpp:841
Definition solver_base.hpp:238