ginkgo/core/preconditioner/ilu.hpp Source File

ginkgo/core/preconditioner/ilu.hpp Source File

Reference API: ginkgo/core/preconditioner/ilu.hpp Source File
Reference API
ilu.hpp
1// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_PRECONDITIONER_ILU_HPP_
6#define GKO_PUBLIC_CORE_PRECONDITIONER_ILU_HPP_
7
8
9#include <memory>
10#include <type_traits>
11
12#include <ginkgo/core/base/abstract_factory.hpp>
13#include <ginkgo/core/base/composition.hpp>
14#include <ginkgo/core/base/exception.hpp>
15#include <ginkgo/core/base/exception_helpers.hpp>
16#include <ginkgo/core/base/lin_op.hpp>
17#include <ginkgo/core/base/precision_dispatch.hpp>
18#include <ginkgo/core/config/config.hpp>
19#include <ginkgo/core/config/registry.hpp>
20#include <ginkgo/core/factorization/par_ilu.hpp>
21#include <ginkgo/core/matrix/dense.hpp>
22#include <ginkgo/core/preconditioner/isai.hpp>
23#include <ginkgo/core/preconditioner/utils.hpp>
24#include <ginkgo/core/solver/gmres.hpp>
25#include <ginkgo/core/solver/ir.hpp>
26#include <ginkgo/core/solver/solver_traits.hpp>
27#include <ginkgo/core/solver/triangular.hpp>
28#include <ginkgo/core/stop/combined.hpp>
29#include <ginkgo/core/stop/iteration.hpp>
30#include <ginkgo/core/stop/residual_norm.hpp>
31
32
33namespace gko {
34namespace preconditioner {
35namespace detail {
36
37
38template <typename LSolverType, typename USolverType>
39constexpr bool support_ilu_parse =
40 std::is_same<typename USolverType::transposed_type, LSolverType>::value &&
41 (is_instantiation_of<LSolverType, solver::LowerTrs>::value ||
42 is_instantiation_of<LSolverType, solver::Ir>::value ||
43 is_instantiation_of<LSolverType, solver::Gmres>::value ||
44 is_instantiation_of<LSolverType, preconditioner::LowerIsai>::value);
45
46
47template <typename Ilu,
48 std::enable_if_t<!support_ilu_parse<typename Ilu::l_solver_type,
49 typename Ilu::u_solver_type>>* =
50 nullptr>
51typename Ilu::parameters_type ilu_parse(
52 const config::pnode& config, const config::registry& context,
53 const config::type_descriptor& td_for_child)
54{
55 GKO_INVALID_STATE(
56 "preconditioner::Ilu only supports limited type for parse.");
57}
58
59template <
60 typename Ilu,
61 std::enable_if_t<support_ilu_parse<typename Ilu::l_solver_type,
62 typename Ilu::u_solver_type>>* = nullptr>
63typename Ilu::parameters_type ilu_parse(
64 const config::pnode& config, const config::registry& context,
65 const config::type_descriptor& td_for_child);
66
67} // namespace detail
68
69
118template <typename LSolverType = solver::LowerTrs<>,
119 typename USolverType = solver::UpperTrs<>, bool ReverseApply = false,
120 typename IndexType = int32>
121class Ilu : public EnableLinOp<
122 Ilu<LSolverType, USolverType, ReverseApply, IndexType>>,
123 public Transposable {
124 friend class EnableLinOp<Ilu>;
125 friend class EnablePolymorphicObject<Ilu, LinOp>;
126
127public:
128 static_assert(
129 std::is_same<typename LSolverType::value_type,
130 typename USolverType::value_type>::value,
131 "Both the L- and the U-solver must use the same `value_type`!");
132 using value_type = typename LSolverType::value_type;
133 using l_solver_type = LSolverType;
134 using u_solver_type = USolverType;
135 static constexpr bool performs_reverse_apply = ReverseApply;
136 using index_type = IndexType;
137 using transposed_type =
138 Ilu<typename USolverType::transposed_type,
139 typename LSolverType::transposed_type, ReverseApply, IndexType>;
140
141 class Factory;
142
144 : public enable_parameters_type<parameters_type, Factory> {
148 std::shared_ptr<const typename l_solver_type::Factory>
150
154 std::shared_ptr<const typename u_solver_type::Factory>
156
160 std::shared_ptr<const LinOpFactory> factorization_factory{};
161
162 GKO_DEPRECATED("use with_l_solver instead")
163 parameters_type& with_l_solver_factory(
164 deferred_factory_parameter<const typename l_solver_type::Factory>
165 solver)
166 {
167 return with_l_solver(std::move(solver));
168 }
169
170 parameters_type& with_l_solver(
172 solver)
173 {
174 this->l_solver_generator = std::move(solver);
175 this->deferred_factories["l_solver"] = [](const auto& exec,
176 auto& params) {
177 if (!params.l_solver_generator.is_empty()) {
178 params.l_solver_factory =
179 params.l_solver_generator.on(exec);
180 }
181 };
182 return *this;
183 }
184
185 GKO_DEPRECATED("use with_u_solver instead")
186 parameters_type& with_u_solver_factory(
187 deferred_factory_parameter<const typename u_solver_type::Factory>
188 solver)
189 {
190 return with_u_solver(std::move(solver));
191 }
192
193 parameters_type& with_u_solver(
194 deferred_factory_parameter<const typename u_solver_type::Factory>
195 solver)
196 {
197 this->u_solver_generator = std::move(solver);
198 this->deferred_factories["u_solver"] = [](const auto& exec,
199 auto& params) {
200 if (!params.u_solver_generator.is_empty()) {
201 params.u_solver_factory =
202 params.u_solver_generator.on(exec);
203 }
204 };
205 return *this;
206 }
207
208 GKO_DEPRECATED("use with_factorization instead")
209 parameters_type& with_factorization_factory(
210 deferred_factory_parameter<const LinOpFactory> factorization)
211 {
212 return with_factorization(std::move(factorization));
213 }
214
215 parameters_type& with_factorization(
216 deferred_factory_parameter<const LinOpFactory> factorization)
217 {
218 this->factorization_generator = std::move(factorization);
219 this->deferred_factories["factorization"] = [](const auto& exec,
220 auto& params) {
221 if (!params.factorization_generator.is_empty()) {
222 params.factorization_factory =
223 params.factorization_generator.on(exec);
224 }
225 };
226 return *this;
227 }
228
229 private:
230 deferred_factory_parameter<const typename l_solver_type::Factory>
231 l_solver_generator;
232
233 deferred_factory_parameter<const typename u_solver_type::Factory>
234 u_solver_generator;
235
236 deferred_factory_parameter<const LinOpFactory> factorization_generator;
237 };
238
241
260 const config::pnode& config, const config::registry& context,
261 const config::type_descriptor& td_for_child =
262 config::make_type_descriptor<value_type, index_type>())
263 {
264 return detail::ilu_parse<Ilu>(config, context, td_for_child);
265 }
266
272 std::shared_ptr<const l_solver_type> get_l_solver() const
273 {
274 return l_solver_;
275 }
276
282 std::shared_ptr<const u_solver_type> get_u_solver() const
283 {
284 return u_solver_;
285 }
286
287 std::unique_ptr<LinOp> transpose() const override
288 {
289 std::unique_ptr<transposed_type> transposed{
290 new transposed_type{this->get_executor()}};
291 transposed->set_size(gko::transpose(this->get_size()));
292 transposed->l_solver_ =
293 share(as<typename u_solver_type::transposed_type>(
294 this->get_u_solver()->transpose()));
295 transposed->u_solver_ =
296 share(as<typename l_solver_type::transposed_type>(
297 this->get_l_solver()->transpose()));
298
299 return std::move(transposed);
300 }
301
302 std::unique_ptr<LinOp> conj_transpose() const override
303 {
304 std::unique_ptr<transposed_type> transposed{
305 new transposed_type{this->get_executor()}};
306 transposed->set_size(gko::transpose(this->get_size()));
307 transposed->l_solver_ =
308 share(as<typename u_solver_type::transposed_type>(
309 this->get_u_solver()->conj_transpose()));
310 transposed->u_solver_ =
311 share(as<typename l_solver_type::transposed_type>(
312 this->get_l_solver()->conj_transpose()));
313
314 return std::move(transposed);
315 }
316
322 Ilu& operator=(const Ilu& other)
323 {
324 if (&other != this) {
326 auto exec = this->get_executor();
327 l_solver_ = other.l_solver_;
328 u_solver_ = other.u_solver_;
329 parameters_ = other.parameters_;
330 if (other.get_executor() != exec) {
331 l_solver_ = gko::clone(exec, l_solver_);
332 u_solver_ = gko::clone(exec, u_solver_);
333 }
334 }
335 return *this;
336 }
337
344 Ilu& operator=(Ilu&& other)
345 {
346 if (&other != this) {
348 auto exec = this->get_executor();
349 l_solver_ = std::move(other.l_solver_);
350 u_solver_ = std::move(other.u_solver_);
351 parameters_ = std::exchange(other.parameters_, parameters_type{});
352 if (other.get_executor() != exec) {
353 l_solver_ = gko::clone(exec, l_solver_);
354 u_solver_ = gko::clone(exec, u_solver_);
355 }
356 }
357 return *this;
358 }
359
364 Ilu(const Ilu& other) : Ilu{other.get_executor()} { *this = other; }
365
371 Ilu(Ilu&& other) : Ilu{other.get_executor()} { *this = std::move(other); }
372
373protected:
374 void apply_impl(const LinOp* b, LinOp* x) const override
375 {
376 // take care of real-to-complex apply
377 precision_dispatch_real_complex<value_type>(
378 [&](auto dense_b, auto dense_x) {
379 this->set_cache_to(dense_b);
380 if (!ReverseApply) {
381 l_solver_->apply(dense_b, cache_.intermediate);
382 if (u_solver_->apply_uses_initial_guess()) {
383 dense_x->copy_from(cache_.intermediate);
384 }
385 u_solver_->apply(cache_.intermediate, dense_x);
386 } else {
387 u_solver_->apply(dense_b, cache_.intermediate);
388 if (l_solver_->apply_uses_initial_guess()) {
389 dense_x->copy_from(cache_.intermediate);
390 }
391 l_solver_->apply(cache_.intermediate, dense_x);
392 }
393 },
394 b, x);
395 }
396
397 void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
398 LinOp* x) const override
399 {
400 precision_dispatch_real_complex<value_type>(
401 [&](auto dense_alpha, auto dense_b, auto dense_beta, auto dense_x) {
402 this->set_cache_to(dense_b);
403 if (!ReverseApply) {
404 l_solver_->apply(dense_b, cache_.intermediate);
405 u_solver_->apply(dense_alpha, cache_.intermediate,
406 dense_beta, dense_x);
407 } else {
408 u_solver_->apply(dense_b, cache_.intermediate);
409 l_solver_->apply(dense_alpha, cache_.intermediate,
410 dense_beta, dense_x);
411 }
412 },
413 alpha, b, beta, x);
414 }
415
416 explicit Ilu(std::shared_ptr<const Executor> exec)
417 : EnableLinOp<Ilu>(std::move(exec))
418 {}
419
420 explicit Ilu(const Factory* factory, std::shared_ptr<const LinOp> lin_op)
421 : EnableLinOp<Ilu>(factory->get_executor(), lin_op->get_size()),
422 parameters_{factory->get_parameters()}
423 {
424 auto comp =
425 std::dynamic_pointer_cast<const Composition<value_type>>(lin_op);
426 std::shared_ptr<const LinOp> l_factor;
427 std::shared_ptr<const LinOp> u_factor;
428
429 // build factorization if we weren't passed a composition
430 if (!comp) {
431 auto exec = lin_op->get_executor();
432 if (!parameters_.factorization_factory) {
433 parameters_.factorization_factory =
434 factorization::ParIlu<value_type, index_type>::build().on(
435 exec);
436 }
437 auto fact = std::shared_ptr<const LinOp>(
438 parameters_.factorization_factory->generate(lin_op));
439 // ensure that the result is a composition
440 comp =
441 std::dynamic_pointer_cast<const Composition<value_type>>(fact);
442 if (!comp) {
443 GKO_NOT_SUPPORTED(comp);
444 }
445 }
446 if (comp->get_operators().size() == 2) {
447 l_factor = comp->get_operators()[0];
448 u_factor = comp->get_operators()[1];
449 } else {
450 GKO_NOT_SUPPORTED(comp);
451 }
452 GKO_ASSERT_EQUAL_DIMENSIONS(l_factor, u_factor);
453
454 auto exec = this->get_executor();
455
456 // If no factories are provided, generate default ones
457 if (!parameters_.l_solver_factory) {
458 l_solver_ = generate_default_solver<l_solver_type>(exec, l_factor);
459 } else {
460 l_solver_ = parameters_.l_solver_factory->generate(l_factor);
461 }
462 if (!parameters_.u_solver_factory) {
463 u_solver_ = generate_default_solver<u_solver_type>(exec, u_factor);
464 } else {
465 u_solver_ = parameters_.u_solver_factory->generate(u_factor);
466 }
467 }
468
476 void set_cache_to(const LinOp* b) const
477 {
478 if (cache_.intermediate == nullptr) {
479 cache_.intermediate =
481 }
482 // Use b as the initial guess for the first triangular solve
483 cache_.intermediate->copy_from(b);
484 }
485
486
494 template <typename SolverType>
495 static std::enable_if_t<solver::has_with_criteria<SolverType>::value,
496 std::unique_ptr<SolverType>>
497 generate_default_solver(const std::shared_ptr<const Executor>& exec,
498 const std::shared_ptr<const LinOp>& mtx)
499 {
500 // half can not use constexpr constructor
501 const gko::remove_complex<value_type> default_reduce_residual{1e-4};
502 const unsigned int default_max_iters{
503 static_cast<unsigned int>(mtx->get_size()[0])};
504
505 return SolverType::build()
506 .with_criteria(
507 gko::stop::Iteration::build().with_max_iters(default_max_iters),
509 .with_reduction_factor(default_reduce_residual))
510 .on(exec)
511 ->generate(mtx);
512 }
513
517 template <typename SolverType>
518 static std::enable_if_t<!solver::has_with_criteria<SolverType>::value,
519 std::unique_ptr<SolverType>>
520 generate_default_solver(const std::shared_ptr<const Executor>& exec,
521 const std::shared_ptr<const LinOp>& mtx)
522 {
523 return SolverType::build().on(exec)->generate(mtx);
524 }
525
526private:
527 std::shared_ptr<const l_solver_type> l_solver_{};
528 std::shared_ptr<const u_solver_type> u_solver_{};
539 mutable struct cache_struct {
540 cache_struct() = default;
541 ~cache_struct() = default;
542 cache_struct(const cache_struct&) {}
543 cache_struct(cache_struct&&) {}
544 cache_struct& operator=(const cache_struct&) { return *this; }
545 cache_struct& operator=(cache_struct&&) { return *this; }
546 std::unique_ptr<LinOp> intermediate{};
547 } cache_;
548};
549
550
551} // namespace preconditioner
552} // namespace gko
553
554
555#endif // GKO_PUBLIC_CORE_PRECONDITIONER_ILU_HPP_
Definition lin_op.hpp:878
Definition polymorphic_object.hpp:668
Definition lin_op.hpp:117
std::shared_ptr< const Executor > get_executor() const noexcept
Definition polymorphic_object.hpp:243
Definition lin_op.hpp:433
Definition property_tree.hpp:28
Definition registry.hpp:167
Definition type_descriptor.hpp:39
Definition abstract_factory.hpp:309
Definition abstract_factory.hpp:211
std::unordered_map< std::string, std::function< void(std::shared_ptr< const Executor > exec, parameters_type &)> > deferred_factories
Definition abstract_factory.hpp:263
static std::unique_ptr< Dense > create(std::shared_ptr< const Executor > exec, const dim< 2 > &size={}, size_type stride=0)
Definition ilu.hpp:239
Definition ilu.hpp:123
Ilu(Ilu &&other)
Definition ilu.hpp:371
static std::enable_if_t<!solver::has_with_criteria< SolverType >::value, std::unique_ptr< SolverType > > generate_default_solver(const std::shared_ptr< const Executor > &exec, const std::shared_ptr< const LinOp > &mtx)
Definition ilu.hpp:520
Ilu & operator=(Ilu &&other)
Definition ilu.hpp:344
static std::enable_if_t< solver::has_with_criteria< SolverType >::value, std::unique_ptr< SolverType > > generate_default_solver(const std::shared_ptr< const Executor > &exec, const std::shared_ptr< const LinOp > &mtx)
Definition ilu.hpp:497
std::shared_ptr< const l_solver_type > get_l_solver() const
Definition ilu.hpp:272
static parameters_type parse(const config::pnode &config, const config::registry &context, const config::type_descriptor &td_for_child=config::make_type_descriptor< value_type, index_type >())
Definition ilu.hpp:259
std::unique_ptr< LinOp > conj_transpose() const override
Definition ilu.hpp:302
Ilu(const Ilu &other)
Definition ilu.hpp:364
void set_cache_to(const LinOp *b) const
Definition ilu.hpp:476
std::shared_ptr< const u_solver_type > get_u_solver() const
Definition ilu.hpp:282
std::unique_ptr< LinOp > transpose() const override
Definition ilu.hpp:287
Ilu & operator=(const Ilu &other)
Definition ilu.hpp:322
Definition residual_norm.hpp:113
#define GKO_ENABLE_BUILD_METHOD(_factory_name)
Definition abstract_factory.hpp:394
#define GKO_ENABLE_LIN_OP_FACTORY(_lin_op, _parameters_name, _factory_name)
Definition lin_op.hpp:1016
The Ginkgo namespace.
Definition abstract_factory.hpp:20
detail::cloned_type< Pointer > clone(const Pointer &p)
Definition utils_helper.hpp:173
batch_dim< 2, DimensionType > transpose(const batch_dim< 2, DimensionType > &input)
Definition batch_dim.hpp:119
detail::shared_type< OwningPointer > share(OwningPointer &&p)
Definition utils_helper.hpp:224
typename detail::remove_complex_s< T >::type remove_complex
Definition math.hpp:260
STL namespace.
std::shared_ptr< const LinOpFactory > factorization_factory
Definition ilu.hpp:160
std::shared_ptr< const typename u_solver_type::Factory > u_solver_factory
Definition ilu.hpp:155
std::shared_ptr< const typename l_solver_type::Factory > l_solver_factory
Definition ilu.hpp:149