Ginkgo Generated from branch based on master. Ginkgo version 1.7.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
batch_multi_vector.hpp
1/*******************************<GINKGO LICENSE>******************************
2Copyright (c) 2017-2023, the Ginkgo authors
3All rights reserved.
4
5Redistribution and use in source and binary forms, with or without
6modification, are permitted provided that the following conditions
7are met:
8
91. Redistributions of source code must retain the above copyright
10notice, this list of conditions and the following disclaimer.
11
122. Redistributions in binary form must reproduce the above copyright
13notice, this list of conditions and the following disclaimer in the
14documentation and/or other materials provided with the distribution.
15
163. Neither the name of the copyright holder nor the names of its
17contributors may be used to endorse or promote products derived from
18this software without specific prior written permission.
19
20THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
21IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
22TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
23PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31******************************<GINKGO LICENSE>*******************************/
32
33#ifndef GKO_PUBLIC_CORE_BASE_BATCH_MULTI_VECTOR_HPP_
34#define GKO_PUBLIC_CORE_BASE_BATCH_MULTI_VECTOR_HPP_
35
36
37#include <initializer_list>
38#include <vector>
39
40
41#include <ginkgo/core/base/array.hpp>
42#include <ginkgo/core/base/batch_dim.hpp>
43#include <ginkgo/core/base/dim.hpp>
44#include <ginkgo/core/base/executor.hpp>
45#include <ginkgo/core/base/mtx_io.hpp>
46#include <ginkgo/core/base/polymorphic_object.hpp>
47#include <ginkgo/core/base/range_accessors.hpp>
48#include <ginkgo/core/base/types.hpp>
49#include <ginkgo/core/base/utils.hpp>
50#include <ginkgo/core/matrix/dense.hpp>
51
52
53namespace gko {
54namespace batch {
55
56
80template <typename ValueType = default_precision>
82 : public EnablePolymorphicObject<MultiVector<ValueType>>,
83 public EnablePolymorphicAssignment<MultiVector<ValueType>>,
84 public EnableCreateMethod<MultiVector<ValueType>>,
85 public ConvertibleTo<MultiVector<next_precision<ValueType>>> {
86 friend class EnableCreateMethod<MultiVector>;
88 friend class MultiVector<to_complex<ValueType>>;
89 friend class MultiVector<next_precision<ValueType>>;
90
91public:
96
97 using value_type = ValueType;
98 using index_type = int32;
100 using absolute_type = remove_complex<MultiVector<ValueType>>;
101 using complex_type = to_complex<MultiVector<ValueType>>;
102
109 static std::unique_ptr<MultiVector> create_with_config_of(
111
112 void convert_to(
114
115 void move_to(MultiVector<next_precision<ValueType>>* result) override;
116
127 std::unique_ptr<unbatch_type> create_view_for_item(size_type item_id);
128
132 std::unique_ptr<const unbatch_type> create_const_view_for_item(
133 size_type item_id) const;
134
140 batch_dim<2> get_size() const { return batch_size_; }
141
148 {
149 return batch_size_.get_num_batch_items();
150 }
151
157 dim<2> get_common_size() const { return batch_size_.get_common_size(); }
158
164 value_type* get_values() noexcept { return values_.get_data(); }
165
173 const value_type* get_const_values() const noexcept
174 {
175 return values_.get_const_data();
176 }
177
187 {
189 return values_.get_data() + this->get_cumulative_offset(batch_id);
190 }
191
199 const value_type* get_const_values_for_item(
200 size_type batch_id) const noexcept
201 {
203 return values_.get_const_data() + this->get_cumulative_offset(batch_id);
204 }
205
214 {
215 return values_.get_num_elems();
216 }
217
226 {
227 return batch_id * this->get_common_size()[0] *
228 this->get_common_size()[1];
229 }
230
242 value_type& at(size_type batch_id, size_type row, size_type col)
243 {
245 return values_.get_data()[linearize_index(batch_id, row, col)];
246 }
247
251 value_type at(size_type batch_id, size_type row, size_type col) const
252 {
254 return values_.get_const_data()[linearize_index(batch_id, row, col)];
255 }
256
271 ValueType& at(size_type batch_id, size_type idx) noexcept
272 {
273 return values_.get_data()[linearize_index(batch_id, idx)];
274 }
275
279 ValueType at(size_type batch_id, size_type idx) const noexcept
280 {
281 return values_.get_const_data()[linearize_index(batch_id, idx)];
282 }
283
296
311
322
335
345
359 static std::unique_ptr<const MultiVector<ValueType>> create_const(
360 std::shared_ptr<const Executor> exec, const batch_dim<2>& sizes,
361 gko::detail::const_array_view<ValueType>&& values);
362
368 void fill(ValueType value);
369
370private:
371 inline size_type compute_num_elems(const batch_dim<2>& size)
372 {
373 return size.get_num_batch_items() * size.get_common_size()[0] *
374 size.get_common_size()[1];
375 }
376
377protected:
383 void set_size(const batch_dim<2>& value) noexcept;
384
392 MultiVector(std::shared_ptr<const Executor> exec,
393 const batch_dim<2>& size = batch_dim<2>{});
394
409 template <typename ValuesArray>
410 MultiVector(std::shared_ptr<const Executor> exec, const batch_dim<2>& size,
411 ValuesArray&& values)
412 : EnablePolymorphicObject<MultiVector<ValueType>>(exec),
413 batch_size_(size),
414 values_{exec, std::forward<ValuesArray>(values)}
415 {
416 // Ensure that the values array has the correct size
417 auto num_elems = compute_num_elems(size);
418 GKO_ENSURE_IN_BOUNDS(num_elems, values_.get_num_elems() + 1);
419 }
420
428 std::unique_ptr<MultiVector> create_with_same_config() const;
429
430 size_type linearize_index(size_type batch, size_type row,
431 size_type col) const noexcept
432 {
433 return this->get_cumulative_offset(batch) +
434 row * batch_size_.get_common_size()[1] + col;
435 }
436
437 size_type linearize_index(size_type batch, size_type idx) const noexcept
438 {
439 return linearize_index(batch, idx / this->get_common_size()[1],
440 idx % this->get_common_size()[1]);
441 }
442
443private:
444 batch_dim<2> batch_size_;
445 array<value_type> values_;
446};
447
448
449} // namespace batch
450} // namespace gko
451
452
453#endif // GKO_PUBLIC_CORE_BASE_BATCH_MULTI_VECTOR_HPP_
ConvertibleTo interface is used to mark that the implementer can be converted to the object of Result...
Definition polymorphic_object.hpp:499
This mixin implements a static create() method on ConcreteType that dynamically allocates the memory,...
Definition polymorphic_object.hpp:776
This mixin is used to enable a default PolymorphicObject::copy_from() implementation for objects that...
Definition polymorphic_object.hpp:752
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a ne...
Definition polymorphic_object.hpp:691
value_type * get_data() noexcept
Returns a pointer to the block of memory used to store the elements of the array.
Definition array.hpp:646
const value_type * get_const_data() const noexcept
Returns a constant pointer to the block of memory used to store the elements of the array.
Definition array.hpp:655
size_type get_num_elems() const noexcept
Returns the number of elements in the array.
Definition array.hpp:637
MultiVector stores multiple vectors in a batched fashion and is useful for batched operations.
Definition batch_multi_vector.hpp:85
value_type * get_values_for_item(size_type batch_id) noexcept
Returns a pointer to the array of values of the multi-vector for a specific batch item.
Definition batch_multi_vector.hpp:186
void compute_conj_dot(ptr_param< const MultiVector< ValueType > > b, ptr_param< MultiVector< ValueType > > result) const
Computes the column-wise conjugate dot product of each multi-vector in this batch and its correspondi...
size_type get_cumulative_offset(size_type batch_id) const
Get the cumulative storage size offset.
Definition batch_multi_vector.hpp:225
const value_type * get_const_values_for_item(size_type batch_id) const noexcept
Returns a pointer to the array of values of the multi-vector for a specific batch item.
Definition batch_multi_vector.hpp:199
void scale(ptr_param< const MultiVector< ValueType > > alpha)
Scales the vector with a scalar (aka: BLAS scal).
value_type * get_values() noexcept
Returns a pointer to the array of values of the multi-vector.
Definition batch_multi_vector.hpp:164
dim< 2 > get_common_size() const
Returns the common size of the batch items.
Definition batch_multi_vector.hpp:157
void compute_dot(ptr_param< const MultiVector< ValueType > > b, ptr_param< MultiVector< ValueType > > result) const
Computes the column-wise dot product of each multi-vector in this batch and its corresponding entry i...
void fill(ValueType value)
Fills the input MultiVector with a given value.
static std::unique_ptr< const MultiVector< ValueType > > create_const(std::shared_ptr< const Executor > exec, const batch_dim< 2 > &sizes, gko::detail::const_array_view< ValueType > &&values)
Creates a constant (immutable) batch multi-vector from a constant array.
ValueType & at(size_type batch_id, size_type idx) noexcept
Returns a single element for a particular batch item.
Definition batch_multi_vector.hpp:271
size_type get_num_batch_items() const
Returns the number of batch items.
Definition batch_multi_vector.hpp:147
size_type get_num_stored_elements() const noexcept
Returns the number of elements explicitly stored in the batch multi-vector, cumulative across all the batch items.
Definition batch_multi_vector.hpp:213
ValueType at(size_type batch_id, size_type idx) const noexcept
Returns a single element for a particular batch item.
Definition batch_multi_vector.hpp:279
batch_dim< 2 > get_size() const
Returns the batch size.
Definition batch_multi_vector.hpp:140
std::unique_ptr< unbatch_type > create_view_for_item(size_type item_id)
Creates a mutable view (of matrix::Dense type) of one item of the Batch MultiVector object.
const value_type * get_const_values() const noexcept
Returns a pointer to the array of values of the multi-vector.
Definition batch_multi_vector.hpp:173
void add_scaled(ptr_param< const MultiVector< ValueType > > alpha, ptr_param< const MultiVector< ValueType > > b)
Adds b scaled by alpha to the vector (aka: BLAS axpy).
void compute_norm2(ptr_param< MultiVector< remove_complex< ValueType > > > result) const
Computes the Euclidean (L^2) norm of each multi-vector in this batch.
value_type at(size_type batch_id, size_type row, size_type col) const
Returns a single element for a particular batch item.
Definition batch_multi_vector.hpp:251
static std::unique_ptr< MultiVector > create_with_config_of(ptr_param< const MultiVector > other)
Creates a MultiVector with the configuration of another MultiVector.
value_type & at(size_type batch_id, size_type row, size_type col)
Returns a single element for a particular batch item.
Definition batch_multi_vector.hpp:242
std::unique_ptr< const unbatch_type > create_const_view_for_item(size_type item_id) const
Creates an immutable (const) view (of matrix::Dense type) of one item of the Batch MultiVector object.
Dense is a matrix format which explicitly stores all values of the matrix.
Definition dense.hpp:136
This class is used for function parameters in the place of raw pointers.
Definition utils_helper.hpp:71
The Ginkgo namespace.
Definition abstract_factory.hpp:48
constexpr T one()
Returns the multiplicative identity for T.
Definition math.hpp:803
typename detail::remove_complex_s< T >::type remove_complex
Obtain the type which removed the complex of complex/scalar type or the template parameter of class b...
Definition math.hpp:354
std::int32_t int32
32-bit signed integral type.
Definition types.hpp:137
typename detail::next_precision_impl< T >::type next_precision
Obtains the next type in the singly-linked precision list.
Definition math.hpp:490
typename detail::to_complex_s< T >::type to_complex
Obtain the type which adds the complex of complex/scalar type or the template parameter of class by a...
Definition math.hpp:373
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:120
A type representing the dimensions of a multidimensional batch object.
Definition batch_dim.hpp:56
dim< dimensionality, dimension_type > get_common_size() const
Get the common size of the batch items.
Definition batch_dim.hpp:72
size_type get_num_batch_items() const
Get the number of batch items stored.
Definition batch_dim.hpp:65
A type representing the dimensions of a multidimensional object.
Definition dim.hpp:55