40 #ifndef TPETRA_TSQRADAPTOR_HPP
41 #define TPETRA_TSQRADAPTOR_HPP
47 #include "Tpetra_ConfigDefs.hpp"
49 #ifdef HAVE_TPETRA_TSQR
50 # include "Tsqr_NodeTsqrFactory.hpp"
52 # include "Tsqr_DistTsqr.hpp"
55 # include "Tsqr_TeuchosMessenger.hpp"
56 # include "Tpetra_MultiVector.hpp"
57 # include "Teuchos_ParameterListAcceptorDefaultBase.hpp"
// Adaptor class that holds TSQR objects (see the node/dist/tsqr type aliases
// below) and participates in the Teuchos parameter-list protocol by deriving
// from Teuchos::ParameterListAcceptorDefaultBase.
// NOTE(review): the template header that introduces MV and any access
// specifiers are not visible in this excerpt; MV is presumably the
// multivector template parameter -- confirm against the full file.
84 class TsqrAdaptor :
public Teuchos::ParameterListAcceptorDefaultBase {
// Entry type of the multivector, taken from MV.
86 using scalar_type =
typename MV::scalar_type;
// Index type used for local dimensions and strides (MV's local ordinal).
87 using ordinal_type =
typename MV::local_ordinal_type;
// Dense matrix type used for the R argument of factorExplicit / revealRank.
88 using dense_matrix_type =
89 Teuchos::SerialDenseMatrix<ordinal_type, scalar_type>;
// Magnitude type associated with scalar_type; used for the revealRank
// tolerance argument.
90 using magnitude_type =
91 typename Teuchos::ScalarTraits<scalar_type>::magnitudeType;
// Factory whose getNodeTsqr() the constructors call to create nodeTsqr_.
94 using node_tsqr_factory_type =
95 TSQR::NodeTsqrFactory<scalar_type, ordinal_type,
96 typename MV::device_type>;
// Types of the three TSQR objects stored as members (nodeTsqr_, distTsqr_,
// tsqr_). Note the template-argument order here is <ordinal, scalar>.
97 using node_tsqr_type = TSQR::NodeTsqr<ordinal_type, scalar_type>;
98 using dist_tsqr_type = TSQR::DistTsqr<ordinal_type, scalar_type>;
99 using tsqr_type = TSQR::Tsqr<ordinal_type, scalar_type>;
// Builds and returns a TSQR::MatView describing the local data of the
// multivector X: local row count, number of columns, raw data pointer and
// column stride.
// NOTE(review): the declarator line (function name and parameter list) is
// missing from this excerpt. The callers in factorExplicit / revealRank
// invoke it as get_mat_view(A) / get_mat_view(Q), and the body refers to the
// argument as X -- confirm the exact signature against the full file.
101 TSQR::MatView<ordinal_type, scalar_type>
// The TSQR object must exist, since it is queried below for the memory space
// it wants.
104 TEUCHOS_ASSERT( ! tsqr_.is_null() );
// Local dimensions of X, converted to ordinal_type.
109 const ordinal_type lclNumRows(X.getLocalLength());
110 const ordinal_type numCols(X.getNumVectors());
111 scalar_type* X_ptr =
nullptr;
// Stride starts at 1 rather than 0.
114 ordinal_type X_stride = 1;
// When the TSQR implementation asks for device memory, take the pointer and
// column stride from X's device view (requested with read-write access).
115 if(tsqr_->wants_device_memory()) {
116 auto X_view = X.getLocalViewDevice(Access::ReadWrite);
// The view's data pointer is reinterpreted as scalar_type*.
117 X_ptr =
reinterpret_cast<scalar_type*
>(X_view.data());
118 X_stride =
static_cast<ordinal_type
>(X_view.stride(1));
// NOTE(review): as shown, this assignment directly follows the stride read
// above; the embedded numbering skips a line here, so a guarding condition
// is presumably missing from this excerpt -- confirm against the full file.
120 X_stride = ordinal_type(1);
// Same extraction using X's host view.
// NOTE(review): the lines closing the device branch and opening this
// alternative are not visible here; presumably this is the else path.
124 auto X_view = X.getLocalViewHost(Access::ReadWrite);
125 X_ptr =
reinterpret_cast<scalar_type*
>(X_view.data());
126 X_stride =
static_cast<ordinal_type
>(X_view.stride(1));
// NOTE(review): same caveat as the device path -- a guard line appears to be
// missing before this assignment.
128 X_stride = ordinal_type(1);
// Assemble the non-owning view from the collected dimensions, pointer, stride.
131 using mat_view_type = TSQR::MatView<ordinal_type, scalar_type>;
132 return mat_view_type(lclNumRows, numCols, X_ptr, X_stride);
// Constructor taking a caller-supplied parameter list.
//
// plist: parameter list forwarded to setParameterList(). A null RCP is
//        accepted; setParameterList() then substitutes a copy of the list
//        returned by getValidParameters().
//
// Member initialization: nodeTsqr_ comes from the node TSQR factory,
// distTsqr_ is a newly allocated dist_tsqr_type, and tsqr_ is a newly
// allocated tsqr_type constructed from those two objects.
// NOTE(review): the braces delimiting the constructor body are not visible
// in this excerpt.
142 TsqrAdaptor(
const Teuchos::RCP<Teuchos::ParameterList>& plist) :
143 nodeTsqr_(node_tsqr_factory_type::getNodeTsqr()),
144 distTsqr_(new dist_tsqr_type),
145 tsqr_(new tsqr_type(nodeTsqr_, distTsqr_))
// Apply the supplied parameters to this object and its TSQR sub-objects.
147 setParameterList(plist);
// Member-initializer list and body of a second constructor.
// NOTE(review): its declarator line is missing from this excerpt; since the
// body passes Teuchos::null, this is presumably the default constructor --
// confirm against the full file.
// The members are initialized the same way as in the plist-taking
// constructor.
152 nodeTsqr_(node_tsqr_factory_type::getNodeTsqr()),
153 distTsqr_(new dist_tsqr_type),
154 tsqr_(new tsqr_type(nodeTsqr_, distTsqr_))
// A null list makes setParameterList() fall back to a copy of the list from
// getValidParameters().
156 setParameterList(Teuchos::null);
// Returns the list of valid parameters for this adaptor.
//
// The list is built on the first call and cached in defaultParams_ (a
// mutable member, which is why this const method can assign it); later calls
// return the cached list.
//
// Returns: an RCP to a const ParameterList named "TSQR implementation" with
//          entries "NodeTsqr" and "DistTsqr", each set from the valid
//          parameters reported by the corresponding sub-object.
160 Teuchos::RCP<const Teuchos::ParameterList>
161 getValidParameters()
const
// Build the list only if it has not been cached yet.
163 if(defaultParams_.is_null()) {
164 auto params = Teuchos::parameterList(
"TSQR implementation");
// The sub-objects' valid-parameter lists are dereferenced, so they are
// stored by value under their respective keys.
165 params->set(
"NodeTsqr", *(nodeTsqr_->getValidParameters()));
166 params->set(
"DistTsqr", *(distTsqr_->getValidParameters()));
167 defaultParams_ = params;
169 return defaultParams_;
// Sets the parameters of this adaptor and forwards the relevant sublists to
// the node-level and distributed TSQR objects.
//
// plist: parameter list to use. If it is null, a new list copied from
//        *getValidParameters() is used instead; otherwise plist itself is
//        used (not a copy).
// NOTE(review): the return-type line and the body braces are not visible in
// this excerpt.
198 setParameterList(
const Teuchos::RCP<Teuchos::ParameterList>& plist)
// Pick the effective list: the defaults when no list was supplied.
200 auto params = plist.is_null() ?
201 Teuchos::parameterList(*getValidParameters()) : plist;
202 using Teuchos::sublist;
// Each sub-object receives the sublist stored under its own key.
203 nodeTsqr_->setParameterList(sublist(params,
"NodeTsqr"));
204 distTsqr_->setParameterList(sublist(params,
"DistTsqr"));
// Record the effective list via the ParameterListAcceptorDefaultBase helper.
206 this->setMyParamList(params);
// Runs the TSQR explicit factorization on the local data of A by calling
// tsqr_->factorExplicitRaw with raw pointers and strides taken from A, Q, R.
//
// A: input multivector; must have constant stride.
// Q: multivector passed to factorExplicitRaw alongside A; must have constant
//    stride. NOTE(review): its parameter declaration line is missing from
//    this excerpt (embedded numbering skips from 231 to 233).
// R: dense matrix whose values() pointer and stride() are passed to
//    factorExplicitRaw.
// forceNonnegativeDiagonal: forwarded unchanged to factorExplicitRaw;
//    defaults to false.
//
// Raises std::invalid_argument (via TEUCHOS_TEST_FOR_EXCEPTION) if A or Q
// does not have constant stride.
// NOTE(review): the return-type line and body braces are not visible here.
231 factorExplicit(MV& A,
233 dense_matrix_type& R,
234 const bool forceNonnegativeDiagonal=
false)
// Both multivectors must have constant stride; otherwise this throws.
236 TEUCHOS_TEST_FOR_EXCEPTION
237 (! A.isConstantStride(), std::invalid_argument,
"TsqrAdaptor::"
238 "factorExplicit: Input MultiVector A must have constant stride.");
239 TEUCHOS_TEST_FOR_EXCEPTION
240 (! Q.isConstantStride(), std::invalid_argument,
"TsqrAdaptor::"
241 "factorExplicit: Input MultiVector Q must have constant stride.");
// tsqr_ is dereferenced below, so it must be non-null.
243 TEUCHOS_ASSERT( ! tsqr_.is_null() );
// Raw matrix views (dimensions, pointer, stride) of the two multivectors.
245 auto A_view = get_mat_view(A);
246 auto Q_view = get_mat_view(Q);
// This adaptor always passes false for the contiguous-cache-blocks flag.
247 constexpr
bool contiguousCacheBlocks =
false;
// NOTE(review): an argument line appears to be missing between the row
// extent and the A data pointer (embedded numbering skips 249); presumably
// the column count -- confirm against the full file.
248 tsqr_->factorExplicitRaw(A_view.extent(0),
250 A_view.data(), A_view.stride(1),
251 Q_view.data(), Q_view.stride(1),
252 R.values(), R.stride(),
253 contiguousCacheBlocks,
254 forceNonnegativeDiagonal);
// Rank-revealing step: forwards the local data of Q and the matrix R to
// tsqr_->revealRankRaw and returns that call's result.
// NOTE(review): the declarator line (return type, function name, and the Q
// parameter) is missing from this excerpt; the name revealRank is taken from
// the error message below -- confirm the signature against the full file.
//
// R:   dense matrix whose values() pointer and stride() are passed through.
// tol: tolerance forwarded unchanged to revealRankRaw.
//
// Raises std::invalid_argument (via TEUCHOS_TEST_FOR_EXCEPTION) if Q does
// not have constant stride.
289 dense_matrix_type& R,
290 const magnitude_type& tol)
292 TEUCHOS_TEST_FOR_EXCEPTION
293 (! Q.isConstantStride(), std::invalid_argument,
"TsqrAdaptor::"
294 "revealRank: Input MultiVector Q must have constant stride.");
// Raw matrix view (dimensions, pointer, stride) of Q's local data.
297 auto Q_view = get_mat_view(Q);
// This adaptor always passes false for the contiguous-cache-blocks flag.
298 constexpr
bool contiguousCacheBlocks =
false;
// NOTE(review): an argument line appears to be missing after the row extent
// (embedded numbering skips 300); presumably the column count -- confirm.
299 return tsqr_->revealRankRaw(Q_view.extent(0),
301 Q_view.data(), Q_view.stride(1),
302 R.values(), R.stride(),
303 tol, contiguousCacheBlocks);
// Node-level TSQR object, created by node_tsqr_factory_type::getNodeTsqr()
// in the constructors.
308 Teuchos::RCP<node_tsqr_type> nodeTsqr_;
// Distributed TSQR object; given its messenger in prepareDistTsqr().
311 Teuchos::RCP<dist_tsqr_type> distTsqr_;
// Combined TSQR object constructed from nodeTsqr_ and distTsqr_; the
// factorization and rank-revealing calls go through this.
314 Teuchos::RCP<tsqr_type> tsqr_;
// Cache for getValidParameters(); mutable so that const method can fill it
// on first use.
317 mutable Teuchos::RCP<const Teuchos::ParameterList> defaultParams_;
// Declarator of prepareTsqr, which takes a const multivector reference.
// NOTE(review): the return type and the entire body are missing from this
// excerpt, so its behavior cannot be described from here; presumably it
// prepares the TSQR objects for use with mv -- confirm against the full file.
343 prepareTsqr(
const MV& mv)
// Initializes the distributed TSQR object with a messenger built from the
// communicator of mv's map.
//
// mv: multivector whose map supplies the communicator (mv.getMap()->getComm()).
// NOTE(review): the return type, body braces, and the declaration that brings
// the unqualified name RCP into scope are not visible in this excerpt.
358 prepareDistTsqr(
const MV& mv)
361 using Teuchos::rcp_implicit_cast;
// Concrete messenger type and the base type that distTsqr_->init receives.
362 using mess_type = TSQR::TeuchosMessenger<scalar_type>;
363 using base_mess_type = TSQR::MessengerBase<scalar_type>;
365 auto comm = mv.getMap()->getComm();
// Wrap the communicator in a newly allocated messenger owned by an RCP.
366 RCP<mess_type> mess(
new mess_type(comm));
// Convert the RCP to the base messenger type before handing it over.
367 auto messBase = rcp_implicit_cast<base_mess_type>(mess);
368 distTsqr_->init(messBase);
374 #endif // HAVE_TPETRA_TSQR
376 #endif // TPETRA_TSQRADAPTOR_HPP