5#ifndef GKO_PUBLIC_CORE_SOLVER_SOLVER_BASE_HPP_
6#define GKO_PUBLIC_CORE_SOLVER_SOLVER_BASE_HPP_
13#include <ginkgo/core/base/lin_op.hpp>
14#include <ginkgo/core/base/math.hpp>
15#include <ginkgo/core/log/logger.hpp>
16#include <ginkgo/core/matrix/dense.hpp>
17#include <ginkgo/core/matrix/identity.hpp>
18#include <ginkgo/core/solver/workspace.hpp>
19#include <ginkgo/core/stop/combined.hpp>
20#include <ginkgo/core/stop/criterion.hpp>
23GKO_BEGIN_DISABLE_DEPRECATION_WARNINGS
67 friend class multigrid::detail::MultigridState;
160template <
typename DerivedType>
163 friend class multigrid::detail::MultigridState;
177 self()->template log<log::Logger::linop_apply_started>(self(), b, x);
178 auto exec = self()->get_executor();
179 GKO_ASSERT_CONFORMANT(self(), b);
180 GKO_ASSERT_EQUAL_ROWS(self(), x);
181 GKO_ASSERT_EQUAL_COLS(b, x);
185 self()->template log<log::Logger::linop_apply_completed>(self(), b, x);
196 self()->template log<log::Logger::linop_advanced_apply_started>(
197 self(), alpha, b, beta, x);
198 auto exec = self()->get_executor();
199 GKO_ASSERT_CONFORMANT(self(), b);
200 GKO_ASSERT_EQUAL_ROWS(self(), x);
201 GKO_ASSERT_EQUAL_COLS(b, x);
202 GKO_ASSERT_EQUAL_DIMENSIONS(alpha,
dim<2>(1, 1));
203 GKO_ASSERT_EQUAL_DIMENSIONS(beta,
dim<2>(1, 1));
209 self()->template log<log::Logger::linop_advanced_apply_completed>(
210 self(), alpha, b, beta, x);
229 GKO_ENABLE_SELF(DerivedType);
237template <
typename Solver>
240 static int num_vectors(
const Solver&) {
return 0; }
242 static int num_arrays(
const Solver&) {
return 0; }
244 static std::vector<std::string> op_names(
const Solver&) {
return {}; }
246 static std::vector<std::string> array_names(
const Solver&) {
return {}; }
248 static std::vector<int> scalars(
const Solver&) {
return {}; }
250 static std::vector<int> vectors(
const Solver&) {
return {}; }
269template <
typename DerivedType>
280 auto exec = self()->get_executor();
282 GKO_ASSERT_EQUAL_DIMENSIONS(self(), new_precond);
283 GKO_ASSERT_IS_SQUARE_MATRIX(new_precond);
284 if (new_precond->get_executor() != exec) {
297 if (&other !=
this) {
310 if (&other !=
this) {
312 other.set_preconditioner(
nullptr);
338 *
this = std::move(other);
342 DerivedType* self() {
return static_cast<DerivedType*
>(
this); }
344 const DerivedType* self()
const
346 return static_cast<const DerivedType*
>(
this);
361class SolverBaseLinOp {
363 SolverBaseLinOp(std::shared_ptr<const Executor> exec)
364 : workspace_{
std::move(exec)}
367 virtual ~SolverBaseLinOp() =
default;
374 std::shared_ptr<const LinOp> get_system_matrix()
const
376 return system_matrix_;
379 const LinOp* get_workspace_op(
int vector_id)
const
381 return workspace_.get_op(vector_id);
384 virtual int get_num_workspace_ops()
const {
return 0; }
386 virtual std::vector<std::string> get_workspace_op_names()
const
395 virtual std::vector<int> get_workspace_scalars()
const {
return {}; }
401 virtual std::vector<int> get_workspace_vectors()
const {
return {}; }
404 void set_system_matrix_base(std::shared_ptr<const LinOp> system_matrix)
406 system_matrix_ = std::move(system_matrix);
409 void set_workspace_size(
int num_operators,
int num_arrays)
const
411 workspace_.set_size(num_operators, num_arrays);
414 template <
typename LinOpType>
415 LinOpType* create_workspace_op(
int vector_id,
gko::dim<2> size)
const
417 return workspace_.template create_or_get_op<LinOpType>(
420 return LinOpType::create(this->workspace_.get_executor(), size);
422 typeid(LinOpType), size, size[1]);
425 template <
typename LinOpType>
426 LinOpType* create_workspace_op_with_config_of(
int vector_id,
427 const LinOpType*
vec)
const
429 return workspace_.template create_or_get_op<LinOpType>(
430 vector_id, [&] {
return LinOpType::create_with_config_of(
vec); },
434 template <
typename LinOpType>
435 LinOpType* create_workspace_op_with_type_of(
int vector_id,
436 const LinOpType*
vec,
439 return workspace_.template create_or_get_op<LinOpType>(
442 return LinOpType::create_with_type_of(
445 typeid(*vec), size, size[1]);
448 template <
typename LinOpType>
449 LinOpType* create_workspace_op_with_type_of(
int vector_id,
450 const LinOpType*
vec,
452 dim<2> local_size)
const
454 return workspace_.template create_or_get_op<LinOpType>(
457 return LinOpType::create_with_type_of(
461 typeid(*vec), global_size, local_size[1]);
464 template <
typename ValueType>
465 matrix::Dense<ValueType>* create_workspace_scalar(
int vector_id,
466 size_type size)
const
468 return workspace_.template create_or_get_op<matrix::Dense<ValueType>>(
472 workspace_.get_executor(), dim<2>{1, size});
474 typeid(matrix::Dense<ValueType>),
gko::dim<2>{1, size}, size);
477 template <
typename ValueType>
478 array<ValueType>& create_workspace_array(
int array_id, size_type size)
const
480 return workspace_.template create_or_get_array<ValueType>(array_id,
484 template <
typename ValueType>
485 array<ValueType>& create_workspace_array(
int array_id)
const
487 return workspace_.template init_or_get_array<ValueType>(array_id);
491 mutable detail::workspace workspace_;
493 std::shared_ptr<const LinOp> system_matrix_;
500template <
typename MatrixType>
503 GKO_DEPRECATED(
"This class will be replaced by the template-less detail::SolverBaseLinOp in a future release")
SolverBase
505 : public detail::SolverBaseLinOp {
507 using detail::SolverBaseLinOp::SolverBaseLinOp;
518 return std::dynamic_pointer_cast<const MatrixType>(
519 SolverBaseLinOp::get_system_matrix());
523 void set_system_matrix_base(std::shared_ptr<const MatrixType> system_matrix)
525 SolverBaseLinOp::set_system_matrix_base(std::move(system_matrix));
538template <
typename DerivedType,
typename MatrixType = LinOp>
547 if (&other !=
this) {
559 if (&other !=
this) {
560 set_system_matrix(other.get_system_matrix());
561 other.set_system_matrix(
nullptr);
568 EnableSolverBase(std::shared_ptr<const MatrixType> system_matrix)
569 : SolverBase<MatrixType>{self()->get_executor()}
571 set_system_matrix(std::move(system_matrix));
578 :
SolverBase<MatrixType>{other.self()->get_executor()}
588 :
SolverBase<MatrixType>{other.self()->get_executor()}
590 *
this = std::move(other);
593 int get_num_workspace_ops()
const override
595 using traits = workspace_traits<DerivedType>;
596 return traits::num_vectors(*self());
599 std::vector<std::string> get_workspace_op_names()
const override
601 using traits = workspace_traits<DerivedType>;
602 return traits::op_names(*self());
612 return traits::scalars(*self());
622 return traits::vectors(*self());
626 void set_system_matrix(std::shared_ptr<const MatrixType> new_system_matrix)
628 auto exec = self()->get_executor();
629 if (new_system_matrix) {
630 GKO_ASSERT_EQUAL_DIMENSIONS(self(), new_system_matrix);
631 GKO_ASSERT_IS_SQUARE_MATRIX(new_system_matrix);
632 if (new_system_matrix->get_executor() != exec) {
633 new_system_matrix =
gko::clone(exec, new_system_matrix);
636 this->set_system_matrix_base(new_system_matrix);
639 void setup_workspace()
const
641 using traits = workspace_traits<DerivedType>;
642 this->set_workspace_size(traits::num_vectors(*self()),
643 traits::num_arrays(*self()));
647 DerivedType* self() {
return static_cast<DerivedType*
>(
this); }
649 const DerivedType* self()
const
651 return static_cast<const DerivedType*
>(
this);
671 return stop_factory_;
680 std::shared_ptr<const stop::CriterionFactory> new_stop_factory)
682 stop_factory_ = new_stop_factory;
686 std::shared_ptr<const stop::CriterionFactory> stop_factory_;
698template <
typename DerivedType>
707 if (&other !=
this) {
720 if (&other !=
this) {
722 other.set_stop_criterion_factory(
nullptr);
730 std::shared_ptr<const stop::CriterionFactory> stop_factory)
746 *
this = std::move(other);
750 std::shared_ptr<const stop::CriterionFactory> new_stop_factory)
override
752 auto exec = self()->get_executor();
753 if (new_stop_factory && new_stop_factory->get_executor() != exec) {
754 new_stop_factory =
gko::clone(exec, new_stop_factory);
760 DerivedType* self() {
return static_cast<DerivedType*
>(
this); }
762 const DerivedType* self()
const
764 return static_cast<const DerivedType*
>(
this);
778template <
typename ValueType,
typename DerivedType>
787 std::shared_ptr<const LinOp> system_matrix,
788 std::shared_ptr<const stop::CriterionFactory> stop_factory,
789 std::shared_ptr<const LinOp> preconditioner)
795 template <
typename FactoryParameters>
797 std::shared_ptr<const LinOp> system_matrix,
798 const FactoryParameters& params)
801 generate_preconditioner(system_matrix, params)}
805 template <
typename FactoryParameters>
806 static std::shared_ptr<const LinOp> generate_preconditioner(
807 std::shared_ptr<const LinOp> system_matrix,
808 const FactoryParameters& params)
810 if (params.generated_preconditioner) {
811 return params.generated_preconditioner;
812 }
else if (params.preconditioner) {
813 return params.preconditioner->generate(system_matrix);
816 system_matrix->get_executor(), system_matrix->get_size());
822template <
typename Parameters,
typename Factory>
828 std::vector<std::shared_ptr<const stop::CriterionFactory>>
833template <
typename Parameters,
typename Factory>
840 std::shared_ptr<const LinOpFactory> GKO_DEFERRED_FACTORY_PARAMETER(
856GKO_END_DISABLE_DEPRECATION_WARNINGS
Definition lin_op.hpp:117
std::shared_ptr< const Executor > get_executor() const noexcept
Definition polymorphic_object.hpp:243
Definition lin_op.hpp:681
virtual void set_preconditioner(std::shared_ptr< const LinOp > new_precond)
Definition lin_op.hpp:701
virtual std::shared_ptr< const LinOp > get_preconditioner() const
Definition lin_op.hpp:690
Definition abstract_factory.hpp:211
size_type get_stride() const noexcept
Definition dense.hpp:869
static std::unique_ptr< Dense > create(std::shared_ptr< const Executor > exec, const dim< 2 > &size={}, size_type stride=0)
static std::unique_ptr< Identity > create(std::shared_ptr< const Executor > exec, dim< 2 > size)
Definition utils_helper.hpp:41
T * get() const
Definition utils_helper.hpp:75
Definition solver_base.hpp:65
void set_default_initial_guess(initial_guess_mode guess)
Definition solver_base.hpp:141
initial_guess_mode get_default_initial_guess() const
Definition solver_base.hpp:123
virtual void apply_with_initial_guess(const LinOp *alpha, const LinOp *b, const LinOp *beta, LinOp *x, initial_guess_mode guess) const =0
ApplyWithInitialGuess(initial_guess_mode guess=initial_guess_mode::provided)
Definition solver_base.hpp:131
virtual void apply_with_initial_guess(const LinOp *b, LinOp *x, initial_guess_mode guess) const =0
Definition solver_base.hpp:161
void apply_with_initial_guess(const LinOp *b, LinOp *x, initial_guess_mode guess) const override
Definition solver_base.hpp:174
virtual void apply_with_initial_guess_impl(const LinOp *alpha, const LinOp *b, const LinOp *beta, LinOp *x, initial_guess_mode guess) const =0
virtual void apply_with_initial_guess_impl(const LinOp *b, LinOp *x, initial_guess_mode guess) const =0
void apply_with_initial_guess(const LinOp *alpha, const LinOp *b, const LinOp *beta, LinOp *x, initial_guess_mode guess) const override
Definition solver_base.hpp:192
Definition solver_base.hpp:699
EnableIterativeBase & operator=(EnableIterativeBase &&other)
Definition solver_base.hpp:718
void set_stop_criterion_factory(std::shared_ptr< const stop::CriterionFactory > new_stop_factory) override
Definition solver_base.hpp:749
EnableIterativeBase(EnableIterativeBase &&other)
Definition solver_base.hpp:744
EnableIterativeBase(const EnableIterativeBase &other)
Definition solver_base.hpp:738
EnableIterativeBase & operator=(const EnableIterativeBase &other)
Definition solver_base.hpp:705
Definition solver_base.hpp:270
EnablePreconditionable(const EnablePreconditionable &other)
Definition solver_base.hpp:327
EnablePreconditionable & operator=(EnablePreconditionable &&other)
Definition solver_base.hpp:308
EnablePreconditionable(EnablePreconditionable &&other)
Definition solver_base.hpp:336
EnablePreconditionable & operator=(const EnablePreconditionable &other)
Definition solver_base.hpp:295
void set_preconditioner(std::shared_ptr< const LinOp > new_precond) override
Definition solver_base.hpp:278
Definition solver_base.hpp:782
Definition solver_base.hpp:539
EnableSolverBase(EnableSolverBase &&other)
Definition solver_base.hpp:587
std::vector< int > get_workspace_vectors() const override
Definition solver_base.hpp:619
std::vector< int > get_workspace_scalars() const override
Definition solver_base.hpp:609
EnableSolverBase(const EnableSolverBase &other)
Definition solver_base.hpp:577
EnableSolverBase & operator=(EnableSolverBase &&other)
Definition solver_base.hpp:557
EnableSolverBase & operator=(const EnableSolverBase &other)
Definition solver_base.hpp:545
Definition solver_base.hpp:661
std::shared_ptr< const stop::CriterionFactory > get_stop_criterion_factory() const
Definition solver_base.hpp:668
virtual void set_stop_criterion_factory(std::shared_ptr< const stop::CriterionFactory > new_stop_factory)
Definition solver_base.hpp:679
Definition solver_base.hpp:505
std::shared_ptr< const MatrixType > get_system_matrix() const
Definition solver_base.hpp:516
#define GKO_FACTORY_PARAMETER_SCALAR(_name, _default)
Definition abstract_factory.hpp:445
std::shared_ptr< const CriterionFactory > combine(FactoryContainer &&factories)
Definition combined.hpp:109
initial_guess_mode
Definition solver_base.hpp:33
The Ginkgo namespace.
Definition abstract_factory.hpp:20
detail::cloned_type< Pointer > clone(const Pointer &p)
Definition utils_helper.hpp:173
detail::temporary_clone< detail::pointee< Ptr > > make_temporary_clone(std::shared_ptr< const Executor > exec, Ptr &&ptr)
Definition temporary_clone.hpp:208
Definition solver_base.hpp:824
std::vector< std::shared_ptr< const stop::CriterionFactory > > criteria
Definition solver_base.hpp:829
Definition solver_base.hpp:835
std::shared_ptr< const LinOp > generated_preconditioner
Definition solver_base.hpp:848
std::shared_ptr< const LinOpFactory > preconditioner
Definition solver_base.hpp:841
Definition solver_base.hpp:238