40#ifndef TPETRA_BLOCKMULTIVECTOR_DEF_HPP
41#define TPETRA_BLOCKMULTIVECTOR_DEF_HPP
45#include "Teuchos_OrdinalTraits.hpp"
50template<
class Scalar,
class LO,
class GO,
class Node>
58template<
class Scalar,
class LO,
class GO,
class Node>
59Teuchos::RCP<const BlockMultiVector<Scalar, LO, GO, Node> >
64 const BMV* src_bmv =
dynamic_cast<const BMV*
> (&src);
65 TEUCHOS_TEST_FOR_EXCEPTION(
66 src_bmv ==
nullptr, std::invalid_argument,
"Tpetra::"
67 "BlockMultiVector: The source object of an Import or Export to a "
68 "BlockMultiVector, must also be a BlockMultiVector.");
69 return Teuchos::rcp (src_bmv,
false);
72template<
class Scalar,
class LO,
class GO,
class Node>
75 const Teuchos::DataAccess copyOrView) :
77 meshMap_ (in.meshMap_),
78 pointMap_ (in.pointMap_),
79 mv_ (in.mv_, copyOrView),
80 blockSize_ (in.blockSize_)
83template<
class Scalar,
class LO,
class GO,
class Node>
90 pointMap_ (makePointMap (meshMap, blockSize)),
91 mv_ (Teuchos::rcpFromRef (pointMap_), numVecs),
92 blockSize_ (blockSize)
95template<
class Scalar,
class LO,
class GO,
class Node>
103 pointMap_ (pointMap),
104 mv_ (Teuchos::rcpFromRef (pointMap_), numVecs),
105 blockSize_ (blockSize)
108template<
class Scalar,
class LO,
class GO,
class Node>
112 const LO blockSize) :
115 blockSize_ (blockSize)
131 RCP<const mv_type> X_view_const;
134 Teuchos::Array<size_t> cols (0);
135 X_view_const = X_mv.
subView (cols ());
137 X_view_const = X_mv.
subView (Teuchos::Range1D (0, numCols-1));
139 TEUCHOS_TEST_FOR_EXCEPTION(
140 X_view_const.is_null (), std::logic_error,
"Tpetra::"
141 "BlockMultiVector constructor: X_mv.subView(...) returned null. This "
142 "should never happen. Please report this bug to the Tpetra developers.");
147 RCP<mv_type> X_view = Teuchos::rcp_const_cast<mv_type> (X_view_const);
148 TEUCHOS_TEST_FOR_EXCEPTION(
149 X_view->getCopyOrView () != Teuchos::View, std::logic_error,
"Tpetra::"
150 "BlockMultiVector constructor: We just set a MultiVector "
151 "to have view semantics, but it claims that it doesn't have view "
152 "semantics. This should never happen. "
153 "Please report this bug to the Tpetra developers.");
158 Teuchos::RCP<const map_type> pointMap =
mv_.
getMap ();
159 if (! pointMap.is_null ()) {
160 pointMap_ = *pointMap;
164template<
class Scalar,
class LO,
class GO,
class Node>
169 const size_t offset) :
171 meshMap_ (newMeshMap),
172 pointMap_ (newPointMap),
173 mv_ (X.mv_, newPointMap, offset * X.getBlockSize ()),
174 blockSize_ (X.getBlockSize ())
177template<
class Scalar,
class LO,
class GO,
class Node>
181 const size_t offset) :
183 meshMap_ (newMeshMap),
184 pointMap_ (makePointMap (newMeshMap, X.getBlockSize ())),
185 mv_ (X.mv_, pointMap_, offset * X.getBlockSize ()),
186 blockSize_ (X.getBlockSize ())
189template<
class Scalar,
class LO,
class GO,
class Node>
196template<
class Scalar,
class LO,
class GO,
class Node>
202 typedef typename Teuchos::ArrayView<const GO>::size_type size_type;
204 const GST gblNumMeshMapInds =
206 const size_t lclNumMeshMapIndices =
208 const GST gblNumPointMapInds =
209 gblNumMeshMapInds *
static_cast<GST
> (blockSize);
210 const size_t lclNumPointMapInds =
211 lclNumMeshMapIndices *
static_cast<size_t> (blockSize);
215 return map_type (gblNumPointMapInds, lclNumPointMapInds, indexBase,
223 const size_type lclNumMeshGblInds = lclMeshGblInds.size ();
224 Teuchos::Array<GO> lclPointGblInds (lclNumPointMapInds);
225 for (size_type g = 0; g < lclNumMeshGblInds; ++g) {
226 const GO meshGid = lclMeshGblInds[g];
227 const GO pointGidStart = indexBase +
228 (meshGid - indexBase) *
static_cast<GO
> (blockSize);
229 const size_type offset = g *
static_cast<size_type
> (blockSize);
230 for (LO k = 0; k < blockSize; ++k) {
231 const GO pointGid = pointGidStart +
static_cast<GO
> (k);
232 lclPointGblInds[offset +
static_cast<size_type
> (k)] = pointGid;
235 return map_type (gblNumPointMapInds, lclPointGblInds (), indexBase,
241template<
class Scalar,
class LO,
class GO,
class Node>
248 auto X_dst = getLocalBlockHost (localRowIndex, colIndex, Access::ReadWrite);
249 typename const_little_vec_type::HostMirror::const_type X_src (
reinterpret_cast<const impl_scalar_type*
> (vals),
252 using exec_space =
typename device_type::execution_space;
253 Kokkos::deep_copy (exec_space(), X_dst, X_src);
257template<
class Scalar,
class LO,
class GO,
class Node>
264 if (! meshMap_.isNodeLocalElement (localRowIndex)) {
267 replaceLocalValuesImpl (localRowIndex, colIndex, vals);
272template<
class Scalar,
class LO,
class GO,
class Node>
279 const LO localRowIndex = meshMap_.getLocalElement (globalRowIndex);
280 if (localRowIndex == Teuchos::OrdinalTraits<LO>::invalid ()) {
283 replaceLocalValuesImpl (localRowIndex, colIndex, vals);
288template<
class Scalar,
class LO,
class GO,
class Node>
295 auto X_dst = getLocalBlockHost (localRowIndex, colIndex, Access::ReadWrite);
296 typename const_little_vec_type::HostMirror::const_type X_src (
reinterpret_cast<const impl_scalar_type*
> (vals),
298 AXPY (
static_cast<impl_scalar_type
> (STS::one ()), X_src, X_dst);
301template<
class Scalar,
class LO,
class GO,
class Node>
308 if (! meshMap_.isNodeLocalElement (localRowIndex)) {
311 sumIntoLocalValuesImpl (localRowIndex, colIndex, vals);
316template<
class Scalar,
class LO,
class GO,
class Node>
323 const LO localRowIndex = meshMap_.getLocalElement (globalRowIndex);
324 if (localRowIndex == Teuchos::OrdinalTraits<LO>::invalid ()) {
327 sumIntoLocalValuesImpl (localRowIndex, colIndex, vals);
333template<
class Scalar,
class LO,
class GO,
class Node>
334typename BlockMultiVector<Scalar, LO, GO, Node>::const_little_host_vec_type
338 const Access::ReadOnlyStruct)
const
340 if (!isValidLocalMeshIndex(localRowIndex)) {
341 return const_little_host_vec_type();
343 const size_t blockSize = getBlockSize();
344 auto hostView = mv_.getLocalViewHost(Access::ReadOnly);
345 LO startRow = localRowIndex*blockSize;
346 LO endRow = startRow + blockSize;
347 return Kokkos::subview(hostView, Kokkos::make_pair(startRow, endRow),
352template<
class Scalar,
class LO,
class GO,
class Node>
353typename BlockMultiVector<Scalar, LO, GO, Node>::little_host_vec_type
357 const Access::OverwriteAllStruct)
359 if (!isValidLocalMeshIndex(localRowIndex)) {
360 return little_host_vec_type();
362 const size_t blockSize = getBlockSize();
363 auto hostView = mv_.getLocalViewHost(Access::OverwriteAll);
364 LO startRow = localRowIndex*blockSize;
365 LO endRow = startRow + blockSize;
366 return Kokkos::subview(hostView, Kokkos::make_pair(startRow, endRow),
371template<
class Scalar,
class LO,
class GO,
class Node>
372typename BlockMultiVector<Scalar, LO, GO, Node>::little_host_vec_type
376 const Access::ReadWriteStruct)
378 if (!isValidLocalMeshIndex(localRowIndex)) {
379 return little_host_vec_type();
381 const size_t blockSize = getBlockSize();
382 auto hostView = mv_.getLocalViewHost(Access::ReadWrite);
383 LO startRow = localRowIndex*blockSize;
384 LO endRow = startRow + blockSize;
385 return Kokkos::subview(hostView, Kokkos::make_pair(startRow, endRow),
390template<
class Scalar,
class LO,
class GO,
class Node>
391Teuchos::RCP<const typename BlockMultiVector<Scalar, LO, GO, Node>::mv_type>
392BlockMultiVector<Scalar, LO, GO, Node>::
396 using Teuchos::rcpFromRef;
402 typedef BlockMultiVector<Scalar, LO, GO, Node> this_BMV_type;
403 const this_BMV_type* srcBlkVec =
dynamic_cast<const this_BMV_type*
> (&src);
404 if (srcBlkVec ==
nullptr) {
405 const mv_type* srcMultiVec =
dynamic_cast<const mv_type*
> (&src);
406 if (srcMultiVec ==
nullptr) {
410 return rcp (
new mv_type ());
412 return rcp (srcMultiVec,
false);
415 return rcpFromRef (srcBlkVec->mv_);
419template<
class Scalar,
class LO,
class GO,
class Node>
423 return ! getMultiVectorFromSrcDistObject (src).is_null ();
426template<
class Scalar,
class LO,
class GO,
class Node>
430 const size_t numSameIDs,
431 const Kokkos::DualView<
const local_ordinal_type*,
432 buffer_device_type>& permuteToLIDs,
433 const Kokkos::DualView<
const local_ordinal_type*,
434 buffer_device_type>& permuteFromLIDs,
437 TEUCHOS_TEST_FOR_EXCEPTION
438 (
true, std::logic_error,
439 "Tpetra::BlockMultiVector::copyAndPermute: Do NOT use this "
440 "instead, create a point importer using makePointMap function.");
443template<
class Scalar,
class LO,
class GO,
class Node>
444void BlockMultiVector<Scalar, LO, GO, Node>::
446(
const SrcDistObject& src,
447 const Kokkos::DualView<
const local_ordinal_type*,
448 buffer_device_type>& exportLIDs,
449 Kokkos::DualView<packet_type*,
450 buffer_device_type>& exports,
451 Kokkos::DualView<
size_t*,
452 buffer_device_type> numPacketsPerLID,
453 size_t& constantNumPackets)
455 TEUCHOS_TEST_FOR_EXCEPTION
456 (
true, std::logic_error,
457 "Tpetra::BlockMultiVector::copyAndPermute: Do NOT use this; "
458 "instead, create a point importer using makePointMap function.");
461template<
class Scalar,
class LO,
class GO,
class Node>
462void BlockMultiVector<Scalar, LO, GO, Node>::
464(
const Kokkos::DualView<
const local_ordinal_type*,
465 buffer_device_type>& importLIDs,
466 Kokkos::DualView<packet_type*,
467 buffer_device_type> imports,
468 Kokkos::DualView<
size_t*,
469 buffer_device_type> numPacketsPerLID,
470 const size_t constantNumPackets,
473 TEUCHOS_TEST_FOR_EXCEPTION
474 (
true, std::logic_error,
475 "Tpetra::BlockMultiVector::copyAndPermute: Do NOT use this; "
476 "instead, create a point importer using makePointMap function.");
479template<
class Scalar,
class LO,
class GO,
class Node>
483 return meshLocalIndex != Teuchos::OrdinalTraits<LO>::invalid () &&
484 meshMap_.isNodeLocalElement (meshLocalIndex);
487template<
class Scalar,
class LO,
class GO,
class Node>
494template<
class Scalar,
class LO,
class GO,
class Node>
496scale (
const Scalar& val)
501template<
class Scalar,
class LO,
class GO,
class Node>
503update (
const Scalar& alpha,
507 mv_.update (alpha, X.
mv_, beta);
512template <
typename Scalar,
typename ViewY,
typename ViewD,
typename ViewX>
513struct BlockWiseMultiply {
514 typedef typename ViewD::size_type Size;
517 typedef typename ViewD::device_type Device;
518 typedef typename ViewD::non_const_value_type ImplScalar;
519 typedef Kokkos::MemoryTraits<Kokkos::Unmanaged> Unmanaged;
521 template <
typename View>
522 using UnmanagedView = Kokkos::View<
typename View::data_type,
typename View::array_layout,
523 typename View::device_type, Unmanaged>;
524 typedef UnmanagedView<ViewY> UnMViewY;
525 typedef UnmanagedView<ViewD> UnMViewD;
526 typedef UnmanagedView<ViewX> UnMViewX;
528 const Size block_size_;
535 BlockWiseMultiply (
const Size block_size,
const Scalar& alpha,
536 const ViewY& Y,
const ViewD& D,
const ViewX& X)
537 : block_size_(block_size), alpha_(alpha), Y_(Y), D_(D), X_(X)
540 KOKKOS_INLINE_FUNCTION
541 void operator() (
const Size k)
const {
542 const auto zero = Kokkos::Details::ArithTraits<Scalar>::zero();
543 auto D_curBlk = Kokkos::subview(D_, k, Kokkos::ALL (), Kokkos::ALL ());
544 const auto num_vecs = X_.extent(1);
545 for (Size i = 0; i < num_vecs; ++i) {
546 Kokkos::pair<Size, Size> kslice(k*block_size_, (k+1)*block_size_);
547 auto X_curBlk = Kokkos::subview(X_, kslice, i);
548 auto Y_curBlk = Kokkos::subview(Y_, kslice, i);
557template <
typename Scalar,
typename ViewY,
typename ViewD,
typename ViewX>
558inline BlockWiseMultiply<Scalar, ViewY, ViewD, ViewX>
559createBlockWiseMultiply (
const int block_size,
const Scalar& alpha,
560 const ViewY& Y,
const ViewD& D,
const ViewX& X) {
561 return BlockWiseMultiply<Scalar, ViewY, ViewD, ViewX>(block_size, alpha, Y, D, X);
564template <
typename ViewY,
568 typename LO =
typename ViewY::size_type>
569class BlockJacobiUpdate {
571 typedef typename ViewD::device_type Device;
572 typedef typename ViewD::non_const_value_type ImplScalar;
573 typedef Kokkos::MemoryTraits<Kokkos::Unmanaged> Unmanaged;
575 template <
typename ViewType>
576 using UnmanagedView = Kokkos::View<
typename ViewType::data_type,
577 typename ViewType::array_layout,
578 typename ViewType::device_type,
580 typedef UnmanagedView<ViewY> UnMViewY;
581 typedef UnmanagedView<ViewD> UnMViewD;
582 typedef UnmanagedView<ViewZ> UnMViewZ;
592 BlockJacobiUpdate (
const ViewY& Y,
596 const Scalar& beta) :
597 blockSize_ (D.extent (1)),
605 static_assert (
static_cast<int> (ViewY::rank) == 1,
606 "Y must have rank 1.");
607 static_assert (
static_cast<int> (ViewD::rank) == 3,
"D must have rank 3.");
608 static_assert (
static_cast<int> (ViewZ::rank) == 1,
609 "Z must have rank 1.");
615 KOKKOS_INLINE_FUNCTION
void
616 operator() (
const LO& k)
const
619 using Kokkos::subview;
620 typedef Kokkos::pair<LO, LO> range_type;
621 typedef Kokkos::Details::ArithTraits<Scalar> KAT;
625 auto D_curBlk = subview (D_, k, ALL (), ALL ());
626 const range_type kslice (k*blockSize_, (k+1)*blockSize_);
630 auto Z_curBlk = subview (Z_, kslice);
631 auto Y_curBlk = subview (Y_, kslice);
633 if (beta_ == KAT::zero ()) {
636 else if (beta_ != KAT::one ()) {
647 class LO =
typename ViewD::size_type>
649blockJacobiUpdate (
const ViewY& Y,
655 static_assert (Kokkos::is_view<ViewY>::value,
"Y must be a Kokkos::View.");
656 static_assert (Kokkos::is_view<ViewD>::value,
"D must be a Kokkos::View.");
657 static_assert (Kokkos::is_view<ViewZ>::value,
"Z must be a Kokkos::View.");
658 static_assert (
static_cast<int> (ViewY::rank) ==
static_cast<int> (ViewZ::rank),
659 "Y and Z must have the same rank.");
660 static_assert (
static_cast<int> (ViewD::rank) == 3,
"D must have rank 3.");
662 const auto lclNumMeshRows = D.extent (0);
664#ifdef HAVE_TPETRA_DEBUG
668 const auto blkSize = D.extent (1);
669 const auto lclNumPtRows = lclNumMeshRows * blkSize;
670 TEUCHOS_TEST_FOR_EXCEPTION
671 (Y.extent (0) != lclNumPtRows, std::invalid_argument,
672 "blockJacobiUpdate: Y.extent(0) = " << Y.extent (0) <<
" != "
673 "D.extent(0)*D.extent(1) = " << lclNumMeshRows <<
" * " << blkSize
674 <<
" = " << lclNumPtRows <<
".");
675 TEUCHOS_TEST_FOR_EXCEPTION
676 (Y.extent (0) != Z.extent (0), std::invalid_argument,
677 "blockJacobiUpdate: Y.extent(0) = " << Y.extent (0) <<
" != "
678 "Z.extent(0) = " << Z.extent (0) <<
".");
679 TEUCHOS_TEST_FOR_EXCEPTION
680 (Y.extent (1) != Z.extent (1), std::invalid_argument,
681 "blockJacobiUpdate: Y.extent(1) = " << Y.extent (1) <<
" != "
682 "Z.extent(1) = " << Z.extent (1) <<
".");
685 BlockJacobiUpdate<ViewY, Scalar, ViewD, ViewZ, LO> functor (Y, alpha, D, Z, beta);
686 typedef Kokkos::RangePolicy<typename ViewY::execution_space, LO> range_type;
688 range_type range (0,
static_cast<LO
> (lclNumMeshRows));
689 Kokkos::parallel_for (range, functor);
694template<
class Scalar,
class LO,
class GO,
class Node>
702 typedef typename device_type::execution_space exec_space;
703 const LO lclNumMeshRows = meshMap_.getLocalNumElements ();
705 if (alpha == STS::zero ()) {
706 this->putScalar (STS::zero ());
709 const LO blockSize = this->getBlockSize ();
712 auto Y_lcl = this->mv_.getLocalViewDevice (Access::ReadWrite);
713 auto bwm = Impl::createBlockWiseMultiply (blockSize, alphaImpl, Y_lcl, D, X_lcl);
720 Kokkos::RangePolicy<exec_space, LO> range (0, lclNumMeshRows);
721 Kokkos::parallel_for (range, bwm);
725template<
class Scalar,
class LO,
class GO,
class Node>
735 using Kokkos::subview;
738 const IST alphaImpl =
static_cast<IST
> (alpha);
739 const IST betaImpl =
static_cast<IST
> (beta);
740 const LO numVecs = mv_.getNumVectors ();
742 if (alpha == STS::zero ()) {
746 Z.
update (STS::one (), X, -STS::one ());
747 for (LO j = 0; j < numVecs; ++j) {
748 auto Y_lcl = this->mv_.getLocalViewDevice (Access::ReadWrite);
750 auto Y_lcl_j = subview (Y_lcl, ALL (), j);
751 auto Z_lcl_j = subview (Z_lcl, ALL (), j);
752 Impl::blockJacobiUpdate (Y_lcl_j, alphaImpl, D, Z_lcl_j, betaImpl);
764#define TPETRA_BLOCKMULTIVECTOR_INSTANT(S,LO,GO,NODE) \
765 template class BlockMultiVector< S, LO, GO, NODE >;
Linear algebra kernels for small dense matrices and vectors.
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.
MultiVector for multiple degrees of freedom per mesh point.
virtual bool checkSizes(const Tpetra::SrcDistObject &source)
Compare the source and target (this) objects for compatibility.
void putScalar(const Scalar &val)
Fill all entries with the given value val.
void blockWiseMultiply(const Scalar &alpha, const Kokkos::View< const impl_scalar_type ***, device_type, Kokkos::MemoryUnmanaged > &D, const BlockMultiVector< Scalar, LO, GO, Node > &X)
*this := alpha * D * X, where D is a block diagonal matrix.
bool sumIntoLocalValues(const LO localRowIndex, const LO colIndex, const Scalar vals[])
Sum into all values at the given mesh point, using a local index.
void blockJacobiUpdate(const Scalar &alpha, const Kokkos::View< const impl_scalar_type ***, device_type, Kokkos::MemoryUnmanaged > &D, const BlockMultiVector< Scalar, LO, GO, Node > &X, BlockMultiVector< Scalar, LO, GO, Node > &Z, const Scalar &beta)
Block Jacobi update $Y = \beta Y + \alpha D (X - Z)$, where $Y$ is *this.
void update(const Scalar &alpha, const BlockMultiVector< Scalar, LO, GO, Node > &X, const Scalar &beta)
Update: this = beta*this + alpha*X.
typename mv_type::impl_scalar_type impl_scalar_type
The implementation type of entries in the object.
bool sumIntoGlobalValues(const GO globalRowIndex, const LO colIndex, const Scalar vals[])
Sum into all values at the given mesh point, using a global index.
BlockMultiVector()
Default constructor.
mv_type getMultiVectorView() const
Get a Tpetra::MultiVector that views this BlockMultiVector's data.
bool replaceLocalValues(const LO localRowIndex, const LO colIndex, const Scalar vals[])
Replace all values at the given mesh point, using local row and column indices.
static map_type makePointMap(const map_type &meshMap, const LO blockSize)
Create and return the point Map corresponding to the given mesh Map and block size.
Tpetra::MultiVector< Scalar, LO, GO, Node > mv_type
The specialization of Tpetra::MultiVector that this class uses.
typename mv_type::device_type device_type
The Kokkos Device type.
void scale(const Scalar &val)
Multiply all entries in place by the given value val.
bool replaceGlobalValues(const GO globalRowIndex, const LO colIndex, const Scalar vals[])
Replace all values at the given mesh point, using a global index.
bool isValidLocalMeshIndex(const LO meshLocalIndex) const
True if and only if meshLocalIndex is a valid local index in the mesh Map.
mv_type mv_
The Tpetra::MultiVector used to represent the data.
virtual Teuchos::RCP< const map_type > getMap() const
The Map describing the parallel distribution of this object.
Teuchos::ArrayView< const global_ordinal_type > getLocalElementList() const
Return a NONOWNING view of the global indices owned by this process.
Teuchos::RCP< const Teuchos::Comm< int > > getComm() const
Accessors for the Teuchos::Comm and Kokkos Node objects.
global_ordinal_type getIndexBase() const
The index base for this Map.
global_size_t getGlobalNumElements() const
The number of elements in this Map.
bool isContiguous() const
True if this Map is distributed contiguously, else false.
size_t getLocalNumElements() const
The number of elements belonging to the calling process.
Teuchos::RCP< const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > > subView(const Teuchos::Range1D &colRng) const
Return a const MultiVector with const views of selected columns.
size_t getNumVectors() const
Number of columns in the multivector.
Teuchos::DataAccess getCopyOrView() const
Get whether this has copy (copyOrView = Teuchos::Copy) or view (copyOrView = Teuchos::View) semantics...
dual_view_type::t_dev::const_type getLocalViewDevice(Access::ReadOnlyStruct) const
Return a read-only, up-to-date view of this MultiVector's local data on device. This requires that th...
Abstract base class for objects that can be the source of an Import or Export operation.
Namespace Tpetra contains the class and methods constituting the Tpetra library.
KOKKOS_INLINE_FUNCTION void GEMV(const CoeffType &alpha, const BlkType &A, const VecType1 &x, const VecType2 &y)
y := y + alpha * A * x (dense matrix-vector multiply)
KOKKOS_INLINE_FUNCTION void AXPY(const CoefficientType &alpha, const ViewType1 &x, const ViewType2 &y)
y := y + alpha * x (dense vector or matrix update)
KOKKOS_INLINE_FUNCTION void FILL(const ViewType &x, const InputType &val)
Set every entry of x to val.
size_t global_size_t
Global size_t object.
KOKKOS_INLINE_FUNCTION void SCAL(const CoefficientType &alpha, const ViewType &x)
x := alpha*x, where x is either rank 1 (a vector) or rank 2 (a matrix).
CombineMode
Rule for combining data in an Import or Export.