Ginkgo — generated from a branch based on master. Ginkgo version 1.7.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
batch_lin_op.hpp
1/*******************************<GINKGO LICENSE>******************************
2Copyright (c) 2017-2023, the Ginkgo authors
3All rights reserved.
4
5Redistribution and use in source and binary forms, with or without
6modification, are permitted provided that the following conditions
7are met:
8
91. Redistributions of source code must retain the above copyright
10notice, this list of conditions and the following disclaimer.
11
122. Redistributions in binary form must reproduce the above copyright
13notice, this list of conditions and the following disclaimer in the
14documentation and/or other materials provided with the distribution.
15
163. Neither the name of the copyright holder nor the names of its
17contributors may be used to endorse or promote products derived from
18this software without specific prior written permission.
19
20THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
21IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
22TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
23PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31******************************<GINKGO LICENSE>*******************************/
32
33#ifndef GKO_PUBLIC_CORE_BASE_BATCH_LIN_OP_HPP_
34#define GKO_PUBLIC_CORE_BASE_BATCH_LIN_OP_HPP_
35
36
37#include <memory>
38#include <type_traits>
39#include <utility>
40
41
42#include <ginkgo/core/base/abstract_factory.hpp>
43#include <ginkgo/core/base/batch_multi_vector.hpp>
44#include <ginkgo/core/base/dim.hpp>
45#include <ginkgo/core/base/exception_helpers.hpp>
46#include <ginkgo/core/base/math.hpp>
47#include <ginkgo/core/base/matrix_assembly_data.hpp>
48#include <ginkgo/core/base/matrix_data.hpp>
49#include <ginkgo/core/base/polymorphic_object.hpp>
50#include <ginkgo/core/base/types.hpp>
51#include <ginkgo/core/base/utils.hpp>
52#include <ginkgo/core/log/logger.hpp>
53
54
55namespace gko {
56namespace batch {
57
58
88class BatchLinOp : public EnableAbstractPolymorphicObject<BatchLinOp> {
89public:
99
106
/** Returns the batch dimensions of this operator (number of batch items and the common size of each item), by const reference to the stored member. */
112 const batch_dim<2>& get_size() const noexcept { return size_; }
113
119 template <typename ValueType>
121 MultiVector<ValueType>* x) const
122 {
123 GKO_ASSERT_EQ(b->get_num_batch_items(), this->get_num_batch_items());
124 GKO_ASSERT_EQ(this->get_num_batch_items(), x->get_num_batch_items());
125
126 GKO_ASSERT_CONFORMANT(this->get_common_size(), b->get_common_size());
127 GKO_ASSERT_EQUAL_ROWS(this->get_common_size(), x->get_common_size());
128 GKO_ASSERT_EQUAL_COLS(b->get_common_size(), x->get_common_size());
129 }
130
136 template <typename ValueType>
138 const MultiVector<ValueType>* b,
139 const MultiVector<ValueType>* beta,
140 MultiVector<ValueType>* x) const
141 {
142 GKO_ASSERT_EQ(b->get_num_batch_items(), this->get_num_batch_items());
143 GKO_ASSERT_EQ(this->get_num_batch_items(), x->get_num_batch_items());
144
145 GKO_ASSERT_CONFORMANT(this->get_common_size(), b->get_common_size());
146 GKO_ASSERT_EQUAL_ROWS(this->get_common_size(), x->get_common_size());
147 GKO_ASSERT_EQUAL_COLS(b->get_common_size(), x->get_common_size());
148 GKO_ASSERT_EQUAL_DIMENSIONS(alpha->get_common_size(),
149 gko::dim<2>(1, 1));
150 GKO_ASSERT_EQUAL_DIMENSIONS(beta->get_common_size(), gko::dim<2>(1, 1));
151 }
152
153protected:
/** Overwrites the stored batch dimensions; protected so that only derived operators can change their own size. */
159 void set_size(const batch_dim<2>& size) { size_ = size; }
160
167 explicit BatchLinOp(std::shared_ptr<const Executor> exec,
169 : EnableAbstractPolymorphicObject<BatchLinOp>(exec), size_{batch_size}
170 {}
171
180 explicit BatchLinOp(std::shared_ptr<const Executor> exec,
181 const size_type num_batch_items = 0,
182 const dim<2>& common_size = dim<2>{})
183 : BatchLinOp{std::move(exec),
185 ? batch_dim<2>(num_batch_items, common_size)
186 : batch_dim<2>{}}
187 {}
188
189private:
190 batch_dim<2> size_{};
191};
192
193
224 : public AbstractFactory<BatchLinOp, std::shared_ptr<const BatchLinOp>> {
225public:
227 std::shared_ptr<const BatchLinOp>>::AbstractFactory;
228
/**
 * Generates a new BatchLinOp from `input` using this factory.
 *
 * If `input` lives on a different executor than this factory, a clone of it
 * on the factory's executor is passed to AbstractFactory::generate;
 * otherwise `input` itself is passed through. The
 * batch_linop_factory_generate_started / _completed logger events are
 * emitted before and after generation.
 *
 * @param input  the batch operator to generate from. NOTE(review): it is
 *               dereferenced without a null check, so callers must pass a
 *               non-null pointer.
 * @return the batch operator produced by AbstractFactory::generate
 */
229 std::unique_ptr<BatchLinOp> generate(
230 std::shared_ptr<const BatchLinOp> input) const
231 {
// Notify attached loggers that generation is starting.
232 this->template log<
233 gko::log::Logger::batch_linop_factory_generate_started>(
234 this, input.get());
235 const auto exec = this->get_executor();
236 std::unique_ptr<BatchLinOp> generated;
// Only copy the input when it is not already on this factory's executor.
237 if (input->get_executor() == exec) {
238 generated = this->AbstractFactory::generate(input);
239 } else {
240 generated =
241 this->AbstractFactory::generate(gko::clone(exec, input));
242 }
// The completed event reports the original input pointer (not the clone,
// if one was made) together with the generated operator.
243 this->template log<
244 gko::log::Logger::batch_linop_factory_generate_completed>(
245 this, input.get(), generated.get());
246 return generated;
247 }
248};
249
250
278template <typename ConcreteBatchLinOp, typename PolymorphicBase = BatchLinOp>
280 : public EnablePolymorphicObject<ConcreteBatchLinOp, PolymorphicBase>,
281 public EnablePolymorphicAssignment<ConcreteBatchLinOp> {
282public:
283 using EnablePolymorphicObject<ConcreteBatchLinOp,
284 PolymorphicBase>::EnablePolymorphicObject;
285};
286
287
304template <typename ConcreteFactory, typename ConcreteBatchLinOp,
305 typename ParametersType, typename PolymorphicBase = BatchLinOpFactory>
308 PolymorphicBase>;
309
310
/*
 * GKO_ENABLE_BATCH_LIN_OP_FACTORY(_batch_lin_op, _parameters_name, _factory_name)
 *
 * Intended to be expanded inside the body of the batch operator class
 * `_batch_lin_op`, since it emits `public:` / `private:` access specifiers.
 * A type named `<_parameters_name>_type` must already be declared in that
 * class. The expansion adds:
 *  - a public getter `get_<_parameters_name>()` returning a const reference
 *    to the stored parameters;
 *  - a nested class `_factory_name` deriving from
 *    ::gko::batch::EnableDefaultBatchLinOpFactory<_factory_name,
 *    _batch_lin_op, <_parameters_name>_type>. Its two constructors
 *    (executor only; executor + parameters) are private and are reached
 *    through the befriended ::gko::EnablePolymorphicObject and
 *    ::gko::enable_parameters_type;
 *  - a friend declaration granting EnableDefaultBatchLinOpFactory access to
 *    the enclosing class;
 *  - a private data member `<_parameters_name>_` holding the parameters.
 *    NOTE(review): this macro does not initialize that member; presumably
 *    the operator's constructor does — confirm.
 * The expansion ends in `public:` followed by a no-op static_assert, so an
 * invocation may be terminated with `;` without triggering extra-semicolon
 * warnings. Access after the macro is therefore `public`.
 */
387#define GKO_ENABLE_BATCH_LIN_OP_FACTORY(_batch_lin_op, _parameters_name, \
388 _factory_name) \
389public: \
390 const _parameters_name##_type& get_##_parameters_name() const \
391 { \
392 return _parameters_name##_; \
393 } \
394 \
395 class _factory_name \
396 : public ::gko::batch::EnableDefaultBatchLinOpFactory< \
397 _factory_name, _batch_lin_op, _parameters_name##_type> { \
398 friend class ::gko::EnablePolymorphicObject< \
399 _factory_name, ::gko::batch::BatchLinOpFactory>; \
400 friend class ::gko::enable_parameters_type<_parameters_name##_type, \
401 _factory_name>; \
402 explicit _factory_name(std::shared_ptr<const ::gko::Executor> exec) \
403 : ::gko::batch::EnableDefaultBatchLinOpFactory< \
404 _factory_name, _batch_lin_op, _parameters_name##_type>( \
405 std::move(exec)) \
406 {} \
407 explicit _factory_name(std::shared_ptr<const ::gko::Executor> exec, \
408 const _parameters_name##_type& parameters) \
409 : ::gko::batch::EnableDefaultBatchLinOpFactory< \
410 _factory_name, _batch_lin_op, _parameters_name##_type>( \
411 std::move(exec), parameters) \
412 {} \
413 }; \
414 friend ::gko::batch::EnableDefaultBatchLinOpFactory< \
415 _factory_name, _batch_lin_op, _parameters_name##_type>; \
416 \
417 \
418private: \
419 _parameters_name##_type _parameters_name##_; \
420 \
421public: \
422 static_assert(true, \
423 "This assert is used to counter the false positive extra " \
424 "semi-colon warnings")
425
426
427} // namespace batch
428} // namespace gko
429
430
431#endif // GKO_PUBLIC_CORE_BASE_BATCH_LIN_OP_HPP_
The AbstractFactory is a generic interface template that enables easy implementation of the abstract factory design pattern.
Definition abstract_factory.hpp:75
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a new abstract object.
Definition polymorphic_object.hpp:374
This mixin provides a default implementation of a concrete factory.
Definition abstract_factory.hpp:154
This mixin is used to enable a default PolymorphicObject::copy_from() implementation for objects that have implemented conversions between them.
Definition polymorphic_object.hpp:752
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a new concrete polymorphic object.
Definition polymorphic_object.hpp:691
A BatchLinOpFactory represents a higher order mapping which transforms one batch linear operator into another.
Definition batch_lin_op.hpp:224
Definition batch_lin_op.hpp:88
const batch_dim< 2 > & get_size() const noexcept
Returns the size of the batch operator.
Definition batch_lin_op.hpp:112
void validate_application_parameters(const MultiVector< ValueType > *b, MultiVector< ValueType > *x) const
Validates the sizes for the apply(b,x) operation in the concrete BatchLinOp.
Definition batch_lin_op.hpp:120
void validate_application_parameters(const MultiVector< ValueType > *alpha, const MultiVector< ValueType > *b, const MultiVector< ValueType > *beta, MultiVector< ValueType > *x) const
Validates the sizes for the apply(alpha, b, beta, x) operation in the concrete BatchLinOp.
Definition batch_lin_op.hpp:137
dim< 2 > get_common_size() const
Returns the common size of the batch items.
Definition batch_lin_op.hpp:105
size_type get_num_batch_items() const noexcept
Returns the number of items in the batch operator.
Definition batch_lin_op.hpp:95
The EnableBatchLinOp mixin can be used to provide sensible default implementations of the majority of the BatchLinOp and PolymorphicObject interface.
Definition batch_lin_op.hpp:281
MultiVector stores multiple vectors in a batched fashion and is useful for batched operations.
Definition batch_multi_vector.hpp:85
dim< 2 > get_common_size() const
Returns the common size of the batch items.
Definition batch_multi_vector.hpp:157
size_type get_num_batch_items() const
Returns the number of batch items.
Definition batch_multi_vector.hpp:147
The Ginkgo namespace.
Definition abstract_factory.hpp:48
constexpr T one()
Returns the multiplicative identity for T.
Definition math.hpp:803
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:120
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition utils_helper.hpp:203
A type representing the dimensions of a multidimensional batch object.
Definition batch_dim.hpp:56
dim< dimensionality, dimension_type > get_common_size() const
Get the common size of the batch items.
Definition batch_dim.hpp:72
size_type get_num_batch_items() const
Get the number of batch items stored.
Definition batch_dim.hpp:65
A type representing the dimensions of a multidimensional object.
Definition dim.hpp:55