ginkgo/core/preconditioner/ic.hpp Source File

ginkgo/core/preconditioner/ic.hpp Source File

Reference API: ginkgo/core/preconditioner/ic.hpp Source File
Reference API
ic.hpp
1// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_PRECONDITIONER_IC_HPP_
6#define GKO_PUBLIC_CORE_PRECONDITIONER_IC_HPP_
7
8
9#include <memory>
10#include <type_traits>
11
12#include <ginkgo/core/base/abstract_factory.hpp>
13#include <ginkgo/core/base/composition.hpp>
14#include <ginkgo/core/base/exception.hpp>
15#include <ginkgo/core/base/exception_helpers.hpp>
16#include <ginkgo/core/base/lin_op.hpp>
17#include <ginkgo/core/base/precision_dispatch.hpp>
18#include <ginkgo/core/config/config.hpp>
19#include <ginkgo/core/config/registry.hpp>
20#include <ginkgo/core/factorization/par_ic.hpp>
21#include <ginkgo/core/matrix/dense.hpp>
22#include <ginkgo/core/preconditioner/isai.hpp>
23#include <ginkgo/core/preconditioner/utils.hpp>
24#include <ginkgo/core/solver/gmres.hpp>
25#include <ginkgo/core/solver/ir.hpp>
26#include <ginkgo/core/solver/solver_traits.hpp>
27#include <ginkgo/core/solver/triangular.hpp>
28#include <ginkgo/core/stop/combined.hpp>
29#include <ginkgo/core/stop/iteration.hpp>
30#include <ginkgo/core/stop/residual_norm.hpp>
31
32
33namespace gko {
34namespace preconditioner {
35namespace detail {
36
37
/**
 * Compile-time flag that is true when `Type` is an instantiation of one of
 * the templates solver::LowerTrs, solver::Ir, solver::Gmres or
 * preconditioner::LowerIsai, and false otherwise.
 *
 * The two ic_parse overloads in this namespace use it (applied to
 * Ic::l_solver_type) to decide which of them takes part in overload
 * resolution.
 *
 * @tparam Type  the type to test
 */
38template <typename Type>
39constexpr bool support_ic_parse =
40 is_instantiation_of<Type, solver::LowerTrs>::value ||
41 is_instantiation_of<Type, solver::Ir>::value ||
42 is_instantiation_of<Type, solver::Gmres>::value ||
43 is_instantiation_of<Type, preconditioner::LowerIsai>::value;
44
45
/**
 * Fallback overload of ic_parse, enabled only when
 * support_ic_parse<typename Ic::l_solver_type> is false.
 *
 * None of the arguments are read; the body consists solely of
 * GKO_INVALID_STATE with a fixed message. The function has no return
 * statement, so the macro is expected not to return normally (presumably it
 * throws - confirm against exception_helpers.hpp).
 *
 * @tparam Ic  the Ic instantiation whose parameters_type would be produced
 * @param config  unused
 * @param context  unused
 * @param td_for_child  unused
 */
46template <
47 typename Ic,
48 std::enable_if_t<!support_ic_parse<typename Ic::l_solver_type>>* = nullptr>
49typename Ic::parameters_type ic_parse(
50 const config::pnode& config, const config::registry& context,
51 const config::type_descriptor& td_for_child)
52{
53 GKO_INVALID_STATE(
54 "preconditioner::Ic only supports limited type for parse.");
55}
56
/**
 * Overload of ic_parse that is enabled only when
 * support_ic_parse<typename Ic::l_solver_type> is true.
 *
 * This is a declaration only; no definition is visible in this header
 * listing.
 *
 * @tparam Ic  the Ic instantiation whose parameters_type is produced
 * @param config  the property-tree node to read from
 * @param context  the registry passed along with the configuration
 * @param td_for_child  the type descriptor passed along with the
 *                      configuration
 *
 * @return an Ic::parameters_type object
 */
57template <
58 typename Ic,
59 std::enable_if_t<support_ic_parse<typename Ic::l_solver_type>>* = nullptr>
60typename Ic::parameters_type ic_parse(
61 const config::pnode& config, const config::registry& context,
62 const config::type_descriptor& td_for_child);
63
64
65} // namespace detail
66
111template <typename LSolverType = solver::LowerTrs<>, typename IndexType = int32>
112class Ic : public EnableLinOp<Ic<LSolverType, IndexType>>, public Transposable {
113 friend class EnableLinOp<Ic>;
114 friend class EnablePolymorphicObject<Ic, LinOp>;
115
116public:
117 static_assert(
118 std::is_same<typename LSolverType::transposed_type::transposed_type,
119 LSolverType>::value,
120 "LSolverType::transposed_type must be symmetric");
121 using value_type = typename LSolverType::value_type;
122 using l_solver_type = LSolverType;
123 using lh_solver_type = typename LSolverType::transposed_type;
124 using index_type = IndexType;
126
127 class Factory;
128
130 : public enable_parameters_type<parameters_type, Factory> {
134 std::shared_ptr<const typename l_solver_type::Factory>
136
140 std::shared_ptr<const LinOpFactory> factorization_factory{};
141
/**
 * Deprecated spelling of with_l_solver.
 *
 * The argument is moved straight into with_l_solver and that call's result
 * is returned unchanged.
 *
 * @param solver  the deferred factory for the L solver
 *
 * @return whatever with_l_solver returns (a reference to a parameters_type)
 */
142 GKO_DEPRECATED("use with_l_solver instead")
143 parameters_type& with_l_solver_factory(
144 deferred_factory_parameter<const typename l_solver_type::Factory>
145 solver)
146 {
147 return with_l_solver(std::move(solver));
148 }
149
150 parameters_type& with_l_solver(
152 solver)
153 {
154 this->l_solver_generator = std::move(solver);
155 this->deferred_factories["l_solver"] = [](const auto& exec,
156 auto& params) {
157 if (!params.l_solver_generator.is_empty()) {
158 params.l_solver_factory =
159 params.l_solver_generator.on(exec);
160 }
161 };
162 return *this;
163 }
164
/**
 * Deprecated spelling of with_factorization.
 *
 * The argument is moved straight into with_factorization and that call's
 * result is returned unchanged.
 *
 * @param factorization  the deferred factory for the factorization
 *
 * @return whatever with_factorization returns (a reference to a
 *         parameters_type)
 */
165 GKO_DEPRECATED("use with_factorization instead")
166 parameters_type& with_factorization_factory(
167 deferred_factory_parameter<const LinOpFactory> factorization)
168 {
169 return with_factorization(std::move(factorization));
170 }
171
172 parameters_type& with_factorization(
173 deferred_factory_parameter<const LinOpFactory> factorization)
174 {
175 this->factorization_generator = std::move(factorization);
176 this->deferred_factories["factorization"] = [](const auto& exec,
177 auto& params) {
178 if (!params.factorization_generator.is_empty()) {
179 params.factorization_factory =
180 params.factorization_generator.on(exec);
181 }
182 };
183 return *this;
184 }
185
186 private:
187 deferred_factory_parameter<const typename l_solver_type::Factory>
188 l_solver_generator;
189
190 deferred_factory_parameter<const LinOpFactory> factorization_generator;
191 };
192
195
213 const config::pnode& config, const config::registry& context,
214 const config::type_descriptor& td_for_child =
215 config::make_type_descriptor<value_type, index_type>())
216 {
217 return detail::ic_parse<Ic>(config, context, td_for_child);
218 }
219
/**
 * Returns the solver stored in l_solver_, i.e. the solver that apply_impl
 * runs first.
 *
 * @return a shared pointer to the (const) l_solver_type object; it is null
 *         if no solver has been set
 */
225 std::shared_ptr<const l_solver_type> get_l_solver() const
226 {
227 return l_solver_;
228 }
229
/**
 * Returns the solver stored in lh_solver_, i.e. the solver that apply_impl
 * runs second.
 *
 * @return a shared pointer to the (const) lh_solver_type object; it is null
 *         if no solver has been set
 */
235 std::shared_ptr<const lh_solver_type> get_lh_solver() const
236 {
237 return lh_solver_;
238 }
239
240 std::unique_ptr<LinOp> transpose() const override
241 {
242 std::unique_ptr<transposed_type> transposed{
243 new transposed_type{this->get_executor()}};
244 transposed->set_size(gko::transpose(this->get_size()));
245 transposed->l_solver_ =
246 share(as<typename lh_solver_type::transposed_type>(
247 this->get_lh_solver()->transpose()));
248 transposed->lh_solver_ =
249 share(as<typename l_solver_type::transposed_type>(
250 this->get_l_solver()->transpose()));
251
252 return std::move(transposed);
253 }
254
255 std::unique_ptr<LinOp> conj_transpose() const override
256 {
257 std::unique_ptr<transposed_type> transposed{
258 new transposed_type{this->get_executor()}};
259 transposed->set_size(gko::transpose(this->get_size()));
260 transposed->l_solver_ =
261 share(as<typename lh_solver_type::transposed_type>(
262 this->get_lh_solver()->conj_transpose()));
263 transposed->lh_solver_ =
264 share(as<typename l_solver_type::transposed_type>(
265 this->get_l_solver()->conj_transpose()));
266
267 return std::move(transposed);
268 }
269
275 Ic& operator=(const Ic& other)
276 {
277 if (&other != this) {
279 auto exec = this->get_executor();
280 l_solver_ = other.l_solver_;
281 lh_solver_ = other.lh_solver_;
282 parameters_ = other.parameters_;
283 if (other.get_executor() != exec) {
284 l_solver_ = gko::clone(exec, l_solver_);
285 lh_solver_ = gko::clone(exec, lh_solver_);
286 }
287 }
288 return *this;
289 }
290
297 Ic& operator=(Ic&& other)
298 {
299 if (&other != this) {
301 auto exec = this->get_executor();
302 l_solver_ = std::move(other.l_solver_);
303 lh_solver_ = std::move(other.lh_solver_);
304 parameters_ = std::exchange(other.parameters_, parameters_type{});
305 if (other.get_executor() != exec) {
306 l_solver_ = gko::clone(exec, l_solver_);
307 lh_solver_ = gko::clone(exec, lh_solver_);
308 }
309 }
310 return *this;
311 }
312
/**
 * Copy constructor.
 *
 * Delegates to the executor-only constructor with other's executor, then
 * fills the new object through the copy assignment operator.
 *
 * @param other  the Ic object to copy from
 */
317 Ic(const Ic& other) : Ic{other.get_executor()} { *this = other; }
318
/**
 * Move constructor.
 *
 * Delegates to the executor-only constructor with other's executor, then
 * fills the new object through the move assignment operator.
 *
 * @param other  the Ic object to move from
 */
324 Ic(Ic&& other) : Ic{other.get_executor()} { *this = std::move(other); }
325
326protected:
327 void apply_impl(const LinOp* b, LinOp* x) const override
328 {
329 // take care of real-to-complex apply
330 precision_dispatch_real_complex<value_type>(
331 [&](auto dense_b, auto dense_x) {
332 this->set_cache_to(dense_b);
333 l_solver_->apply(dense_b, cache_.intermediate);
334 if (lh_solver_->apply_uses_initial_guess()) {
335 dense_x->copy_from(cache_.intermediate);
336 }
337 lh_solver_->apply(cache_.intermediate, dense_x);
338 },
339 b, x);
340 }
341
342 void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
343 LinOp* x) const override
344 {
345 precision_dispatch_real_complex<value_type>(
346 [&](auto dense_alpha, auto dense_b, auto dense_beta, auto dense_x) {
347 this->set_cache_to(dense_b);
348 l_solver_->apply(dense_b, cache_.intermediate);
349 lh_solver_->apply(dense_alpha, cache_.intermediate, dense_beta,
350 dense_x);
351 },
352 alpha, b, beta, x);
353 }
354
/**
 * Creates an Ic object on the given executor without generating any
 * solvers.
 *
 * Only the EnableLinOp base is initialized here; l_solver_ and lh_solver_
 * keep their default (null) member initializers.
 *
 * @param exec  the executor the object is created on
 */
355 explicit Ic(std::shared_ptr<const Executor> exec)
356 : EnableLinOp<Ic>(std::move(exec))
357 {}
358
359 explicit Ic(const Factory* factory, std::shared_ptr<const LinOp> lin_op)
360 : EnableLinOp<Ic>(factory->get_executor(), lin_op->get_size()),
361 parameters_{factory->get_parameters()}
362 {
363 auto comp =
364 std::dynamic_pointer_cast<const Composition<value_type>>(lin_op);
365 std::shared_ptr<const LinOp> l_factor;
366
367 // build factorization if we weren't passed a composition
368 if (!comp) {
369 auto exec = lin_op->get_executor();
370 if (!parameters_.factorization_factory) {
371 parameters_.factorization_factory =
372 factorization::ParIc<value_type, index_type>::build()
373 .with_both_factors(false)
374 .on(exec);
375 }
376 auto fact = std::shared_ptr<const LinOp>(
377 parameters_.factorization_factory->generate(lin_op));
378 // ensure that the result is a composition
379 comp =
380 std::dynamic_pointer_cast<const Composition<value_type>>(fact);
381 if (!comp) {
382 GKO_NOT_SUPPORTED(comp);
383 }
384 }
385 // comp must contain one or two factors
386 if (comp->get_operators().size() > 2 || comp->get_operators().empty()) {
387 GKO_NOT_SUPPORTED(comp);
388 }
389 l_factor = comp->get_operators()[0];
390 GKO_ASSERT_IS_SQUARE_MATRIX(l_factor);
391
392 auto exec = this->get_executor();
393
394 // If no factories are provided, generate default ones
395 if (!parameters_.l_solver_factory) {
396 l_solver_ = generate_default_solver<l_solver_type>(exec, l_factor);
397 // If comp contains both factors: use the transposed factor to avoid
398 // transposing twice
399 if (comp->get_operators().size() == 2) {
400 auto lh_factor = comp->get_operators()[1];
401 GKO_ASSERT_EQUAL_DIMENSIONS(l_factor, lh_factor);
402 lh_solver_ = as<lh_solver_type>(l_solver_->conj_transpose());
403 } else {
404 lh_solver_ = as<lh_solver_type>(l_solver_->conj_transpose());
405 }
406 } else {
407 l_solver_ = parameters_.l_solver_factory->generate(l_factor);
408 lh_solver_ = as<lh_solver_type>(l_solver_->conj_transpose());
409 }
410 }
411
419 void set_cache_to(const LinOp* b) const
420 {
421 if (cache_.intermediate == nullptr) {
422 cache_.intermediate =
424 }
425 // Use b as the initial guess for the first triangular solve
426 cache_.intermediate->copy_from(b);
427 }
428
429
437 template <typename SolverType>
438 static std::enable_if_t<solver::has_with_criteria<SolverType>::value,
439 std::unique_ptr<SolverType>>
440 generate_default_solver(const std::shared_ptr<const Executor>& exec,
441 const std::shared_ptr<const LinOp>& mtx)
442 {
443 const gko::remove_complex<value_type> default_reduce_residual{1e-4};
444 const unsigned int default_max_iters{
445 static_cast<unsigned int>(mtx->get_size()[0])};
446
447 return SolverType::build()
448 .with_criteria(
449 gko::stop::Iteration::build().with_max_iters(default_max_iters),
451 .with_reduction_factor(default_reduce_residual))
452 .on(exec)
453 ->generate(mtx);
454 }
455
/**
 * Overload of generate_default_solver that takes part in overload
 * resolution only when solver::has_with_criteria<SolverType>::value is
 * false.
 *
 * The solver factory is built with its default parameters on the given
 * executor and then used to generate a solver for mtx.
 *
 * @tparam SolverType  the solver type to generate
 * @param exec  the executor the solver factory is created on
 * @param mtx  the operator the solver is generated for
 *
 * @return the generated solver
 */
459 template <typename SolverType>
460 static std::enable_if_t<!solver::has_with_criteria<SolverType>::value,
461 std::unique_ptr<SolverType>>
462 generate_default_solver(const std::shared_ptr<const Executor>& exec,
463 const std::shared_ptr<const LinOp>& mtx)
464 {
465 return SolverType::build().on(exec)->generate(mtx);
466 }
467
468private:
469 std::shared_ptr<const l_solver_type> l_solver_{};
470 std::shared_ptr<const lh_solver_type> lh_solver_{};
/**
 * Scratch storage for the intermediate result between the two solves in
 * apply_impl. The member is declared mutable so that the const apply path
 * can create and overwrite it.
 *
 * The copy/move constructors and assignment operators are deliberately
 * empty: copying or moving the surrounding object never copies or transfers
 * the cached operator. A newly constructed cache starts with a null
 * `intermediate`, and an assigned-to cache keeps whatever it already held.
 */
481 mutable struct cache_struct {
482 cache_struct() = default;
483 ~cache_struct() = default;
// Construction from another cache leaves `intermediate` null.
484 cache_struct(const cache_struct&) {}
485 cache_struct(cache_struct&&) {}
// Assignment from another cache leaves `intermediate` unchanged.
486 cache_struct& operator=(const cache_struct&) { return *this; }
487 cache_struct& operator=(cache_struct&&) { return *this; }
// The cached operator that receives the result of the first solve.
488 std::unique_ptr<LinOp> intermediate{};
489 } cache_;
490};
491
492
493} // namespace preconditioner
494} // namespace gko
495
496
497#endif // GKO_PUBLIC_CORE_PRECONDITIONER_IC_HPP_
Definition lin_op.hpp:878
Definition polymorphic_object.hpp:668
Definition lin_op.hpp:117
std::shared_ptr< const Executor > get_executor() const noexcept
Definition polymorphic_object.hpp:243
Definition lin_op.hpp:433
Definition property_tree.hpp:28
Definition registry.hpp:167
Definition type_descriptor.hpp:39
Definition abstract_factory.hpp:309
Definition abstract_factory.hpp:211
std::unordered_map< std::string, std::function< void(std::shared_ptr< const Executor > exec, parameters_type &)> > deferred_factories
Definition abstract_factory.hpp:263
static std::unique_ptr< Dense > create(std::shared_ptr< const Executor > exec, const dim< 2 > &size={}, size_type stride=0)
Definition ic.hpp:112
static std::enable_if_t< solver::has_with_criteria< SolverType >::value, std::unique_ptr< SolverType > > generate_default_solver(const std::shared_ptr< const Executor > &exec, const std::shared_ptr< const LinOp > &mtx)
Definition ic.hpp:440
std::shared_ptr< const lh_solver_type > get_lh_solver() const
Definition ic.hpp:235
std::unique_ptr< LinOp > transpose() const override
Definition ic.hpp:240
Ic(const Ic &other)
Definition ic.hpp:317
Ic & operator=(Ic &&other)
Definition ic.hpp:297
Ic(Ic &&other)
Definition ic.hpp:324
static parameters_type parse(const config::pnode &config, const config::registry &context, const config::type_descriptor &td_for_child=config::make_type_descriptor< value_type, index_type >())
Definition ic.hpp:212
void set_cache_to(const LinOp *b) const
Definition ic.hpp:419
std::shared_ptr< const l_solver_type > get_l_solver() const
Definition ic.hpp:225
static std::enable_if_t<!solver::has_with_criteria< SolverType >::value, std::unique_ptr< SolverType > > generate_default_solver(const std::shared_ptr< const Executor > &exec, const std::shared_ptr< const LinOp > &mtx)
Definition ic.hpp:462
Ic & operator=(const Ic &other)
Definition ic.hpp:275
std::unique_ptr< LinOp > conj_transpose() const override
Definition ic.hpp:255
Definition residual_norm.hpp:113
#define GKO_ENABLE_BUILD_METHOD(_factory_name)
Definition abstract_factory.hpp:394
#define GKO_ENABLE_LIN_OP_FACTORY(_lin_op, _parameters_name, _factory_name)
Definition lin_op.hpp:1016
The Ginkgo namespace.
Definition abstract_factory.hpp:20
detail::cloned_type< Pointer > clone(const Pointer &p)
Definition utils_helper.hpp:173
batch_dim< 2, DimensionType > transpose(const batch_dim< 2, DimensionType > &input)
Definition batch_dim.hpp:119
detail::shared_type< OwningPointer > share(OwningPointer &&p)
Definition utils_helper.hpp:224
typename detail::remove_complex_s< T >::type remove_complex
Definition math.hpp:260
STL namespace.
std::shared_ptr< const typename l_solver_type::Factory > l_solver_factory
Definition ic.hpp:135
std::shared_ptr< const LinOpFactory > factorization_factory
Definition ic.hpp:140