30 #ifndef SACADO_FAD_EXP_ATOMIC_HPP
31 #define SACADO_FAD_EXP_ATOMIC_HPP
34 #if defined(HAVE_SACADO_KOKKOS)
37 #include "Kokkos_Atomic.hpp"
38 #include "impl/Kokkos_Error.hpp"
46 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
48 void atomic_add(ViewFadPtr<ValT,sl,ss,U> dst,
const Expr<T>& xx) {
49 using Kokkos::atomic_add;
51 const typename Expr<T>::derived_type&
x = xx.derived();
53 const int xsz = x.size();
54 const int sz = dst->size();
60 "Sacado error: Fad resize within atomic_add() not supported!");
62 if (xsz != sz && sz > 0 && xsz > 0)
64 "Sacado error: Fad assignment of incompatiable sizes!");
67 if (sz > 0 && xsz > 0) {
72 atomic_add(&(dst->
val()), x.val());
78 template <
typename Oper,
typename DestPtrT,
typename ValT,
typename T>
80 atomic_oper_fetch_host(
const Oper& op, DestPtrT dest, ValT* dest_val,
84 const typename Expr<T>::derived_type& val = x.derived();
86 #ifdef KOKKOS_INTERNAL_NOT_PARALLEL
87 auto scope = desul::MemoryScopeCaller();
89 auto scope = desul::MemoryScopeDevice();
92 while (!desul::Impl::lock_address((
void*)dest_val, scope))
94 desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
95 return_type return_val = op.apply(*dest, val);
97 desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
98 desul::Impl::unlock_address((
void*)dest_val, scope);
102 template <
typename Oper,
typename DestPtrT,
typename ValT,
typename T>
104 atomic_fetch_oper_host(
const Oper& op, DestPtrT dest, ValT* dest_val,
108 const typename Expr<T>::derived_type& val = x.derived();
110 #ifdef KOKKOS_INTERNAL_NOT_PARALLEL
111 auto scope = desul::MemoryScopeCaller();
113 auto scope = desul::MemoryScopeDevice();
116 while (!desul::Impl::lock_address((
void*)dest_val, scope))
118 desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
119 return_type return_val = *dest;
120 *dest = op.apply(return_val, val);
121 desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
122 desul::Impl::unlock_address((
void*)dest_val, scope);
127 #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)
129 inline bool atomics_use_team() {
130 #if defined(SACADO_VIEW_CUDA_HIERARCHICAL) || defined(SACADO_VIEW_CUDA_HIERARCHICAL_DFAD)
135 return (blockDim.x > 1);
142 #if defined(KOKKOS_ENABLE_CUDA)
146 template <
typename Oper,
typename DestPtrT,
typename ValT,
typename T>
149 atomic_oper_fetch_device(
const Oper& op, DestPtrT dest, ValT* dest_val,
153 const typename Expr<T>::derived_type& val = x.derived();
155 auto scope = desul::MemoryScopeDevice();
157 if (atomics_use_team()) {
160 if (threadIdx.x == 0)
161 go = !desul::Impl::lock_address_cuda((
void*)dest_val, scope);
162 go = Kokkos::shfl(go, 0, blockDim.x);
164 desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
165 return_type return_val = op.apply(*dest, val);
167 desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
168 if (threadIdx.x == 0)
169 desul::Impl::unlock_address_cuda((
void*)dest_val, scope);
173 return_type return_val;
176 unsigned int mask = __activemask() ;
177 unsigned int active = __ballot_sync(mask, 1);
178 unsigned int done_active = 0;
179 while (active != done_active) {
181 if (desul::Impl::lock_address_cuda((
void*)dest_val, scope)) {
182 desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
183 return_val = op.apply(*dest, val);
185 desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
186 desul::Impl::unlock_address_cuda((
void*)dest_val, scope);
190 done_active = __ballot_sync(mask, done);
196 template <
typename Oper,
typename DestPtrT,
typename ValT,
typename T>
199 atomic_fetch_oper_device(
const Oper& op, DestPtrT dest, ValT* dest_val,
203 const typename Expr<T>::derived_type& val = x.derived();
205 auto scope = desul::MemoryScopeDevice();
207 if (atomics_use_team()) {
210 if (threadIdx.x == 0)
211 go = !desul::Impl::lock_address_cuda((
void*)dest_val, scope);
212 go = Kokkos::shfl(go, 0, blockDim.x);
214 desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
215 return_type return_val = *dest;
216 *dest = op.apply(return_val, val);
217 desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
218 if (threadIdx.x == 0)
219 desul::Impl::unlock_address_cuda((
void*)dest_val, scope);
223 return_type return_val;
226 unsigned int mask = __activemask() ;
227 unsigned int active = __ballot_sync(mask, 1);
228 unsigned int done_active = 0;
229 while (active != done_active) {
231 if (desul::Impl::lock_address_cuda((
void*)dest_val, scope)) {
232 desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
234 *dest = op.apply(return_val, val);
235 desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
236 desul::Impl::unlock_address_cuda((
void*)dest_val, scope);
240 done_active = __ballot_sync(mask, done);
246 #elif defined(KOKKOS_ENABLE_HIP)
250 template <
typename Oper,
typename DestPtrT,
typename ValT,
typename T>
253 atomic_oper_fetch_device(
const Oper& op, DestPtrT dest, ValT* dest_val,
257 const typename Expr<T>::derived_type& val = x.derived();
259 auto scope = desul::MemoryScopeDevice();
261 if (atomics_use_team()) {
264 if (threadIdx.x == 0)
265 go = !desul::Impl::lock_address_hip((
void*)dest_val, scope);
266 go = Kokkos::shfl(go, 0, blockDim.x);
268 desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
269 return_type return_val = op.apply(*dest, val);
271 desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
272 if (threadIdx.x == 0)
273 desul::Impl::unlock_address_hip((
void*)dest_val, scope);
277 return_type return_val;
279 unsigned int active = __ballot(1);
280 unsigned int done_active = 0;
281 while (active != done_active) {
283 if (desul::Impl::lock_address_hip((
void*)dest_val, scope)) {
284 return_val = op.apply(*dest, val);
286 desul::Impl::unlock_address_hip((
void*)dest_val, scope);
290 done_active = __ballot(done);
296 template <
typename Oper,
typename DestPtrT,
typename ValT,
typename T>
299 atomic_fetch_oper_device(
const Oper& op, DestPtrT dest, ValT* dest_val,
303 const typename Expr<T>::derived_type& val = x.derived();
305 auto scope = desul::MemoryScopeDevice();
307 if (atomics_use_team()) {
310 if (threadIdx.x == 0)
311 go = !desul::Impl::lock_address_hip((
void*)dest_val, scope);
312 go = Kokkos::shfl(go, 0, blockDim.x);
314 desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
315 return_type return_val = *dest;
316 *dest = op.apply(return_val, val);
317 desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
318 if (threadIdx.x == 0)
319 desul::Impl::unlock_address_hip((
void*)dest_val, scope);
323 return_type return_val;
325 unsigned int active = __ballot(1);
326 unsigned int done_active = 0;
327 while (active != done_active) {
329 if (desul::Impl::lock_address_hip((
void*)dest_val, scope)) {
331 *dest = op.apply(return_val, val);
332 desul::Impl::unlock_address_hip((
void*)dest_val, scope);
336 done_active = __ballot(done);
342 #elif defined(KOKKOS_ENABLE_SYCL)
346 template <
typename Oper,
typename DestPtrT,
typename ValT,
typename T>
348 atomic_oper_fetch_device(
const Oper& op, DestPtrT dest, ValT* dest_val,
351 Kokkos::abort(
"Not implemented!");
355 template <
typename Oper,
typename DestPtrT,
typename ValT,
typename T>
357 atomic_fetch_oper_device(
const Oper& op, DestPtrT dest, ValT* dest_val,
360 Kokkos::abort(
"Not implemented!");
367 template <
typename Oper,
typename S>
369 atomic_oper_fetch(
const Oper& op, GeneralFad<S>* dest,
370 const GeneralFad<S>& val)
372 KOKKOS_IF_ON_HOST(
return Impl::atomic_oper_fetch_host(op, dest, &(dest->val()), val);)
373 KOKKOS_IF_ON_DEVICE(
return Impl::atomic_oper_fetch_device(op, dest, &(dest->val()), val);)
375 template <
typename Oper,
typename ValT,
unsigned sl,
unsigned ss,
376 typename U,
typename T>
378 atomic_oper_fetch(
const Oper& op, ViewFadPtr<ValT,sl,ss,U> dest,
381 KOKKOS_IF_ON_HOST(
return Impl::atomic_oper_fetch_host(op, dest, &dest.val(),
val);)
382 KOKKOS_IF_ON_DEVICE(
return Impl::atomic_oper_fetch_device(op, dest, &dest.val(),
val);)
385 template <typename Oper, typename S>
387 atomic_fetch_oper(
const Oper& op, GeneralFad<S>* dest,
388 const GeneralFad<S>& val)
390 KOKKOS_IF_ON_HOST(
return Impl::atomic_fetch_oper_host(op, dest, &(dest->val()), val);)
391 KOKKOS_IF_ON_DEVICE(
return Impl::atomic_fetch_oper_device(op, dest, &(dest->val()), val);)
393 template <
typename Oper,
typename ValT,
unsigned sl,
unsigned ss,
394 typename U,
typename T>
396 atomic_fetch_oper(
const Oper& op, ViewFadPtr<ValT,sl,ss,U> dest,
399 KOKKOS_IF_ON_HOST(
return Impl::atomic_fetch_oper_host(op, dest, &dest.val(),
val);)
400 KOKKOS_IF_ON_DEVICE(
return Impl::atomic_fetch_oper_device(op, dest, &dest.val(),
val);)
405 template <
class Scalar1,
class Scalar2>
406 KOKKOS_FORCEINLINE_FUNCTION
407 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
408 -> decltype(
max(val1,val2))
410 return max(val1,val2);
414 template <
class Scalar1,
class Scalar2>
415 KOKKOS_FORCEINLINE_FUNCTION
416 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
417 -> decltype(
min(val1,val2))
419 return min(val1,val2);
423 template <
class Scalar1,
class Scalar2>
424 KOKKOS_FORCEINLINE_FUNCTION
425 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
426 -> decltype(val1+val2)
432 template <
class Scalar1,
class Scalar2>
433 KOKKOS_FORCEINLINE_FUNCTION
434 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
435 -> decltype(val1-val2)
441 template <
class Scalar1,
class Scalar2>
442 KOKKOS_FORCEINLINE_FUNCTION
443 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
444 -> decltype(val1*val2)
450 template <
class Scalar1,
class Scalar2>
451 KOKKOS_FORCEINLINE_FUNCTION
452 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
453 -> decltype(val1/val2)
463 template <
typename S>
465 atomic_max_fetch(GeneralFad<S>* dest,
const GeneralFad<S>& val) {
466 return Impl::atomic_oper_fetch(Impl::MaxOper(), dest, val);
468 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
470 atomic_max_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>& val) {
471 return Impl::atomic_oper_fetch(Impl::MaxOper(), dest, val);
473 template <
typename S>
475 atomic_min_fetch(GeneralFad<S>* dest,
const GeneralFad<S>& val) {
476 return Impl::atomic_oper_fetch(Impl::MinOper(), dest, val);
478 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
480 atomic_min_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>& val) {
481 return Impl::atomic_oper_fetch(Impl::MinOper(), dest, val);
483 template <
typename S>
485 atomic_add_fetch(GeneralFad<S>* dest,
const GeneralFad<S>& val) {
486 return Impl::atomic_oper_fetch(Impl::AddOper(), dest, val);
488 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
490 atomic_add_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>& val) {
491 return Impl::atomic_oper_fetch(Impl::AddOper(), dest, val);
493 template <
typename S>
495 atomic_sub_fetch(GeneralFad<S>* dest,
const GeneralFad<S>& val) {
496 return Impl::atomic_oper_fetch(Impl::SubOper(), dest, val);
498 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
500 atomic_sub_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>& val) {
501 return Impl::atomic_oper_fetch(Impl::SubOper(), dest, val);
503 template <
typename S>
505 atomic_mul_fetch(GeneralFad<S>* dest,
const GeneralFad<S>& val) {
506 return atomic_oper_fetch(Impl::MulOper(), dest, val);
508 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
510 atomic_mul_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>& val) {
511 return Impl::atomic_oper_fetch(Impl::MulOper(), dest, val);
513 template <
typename S>
515 atomic_div_fetch(GeneralFad<S>* dest,
const GeneralFad<S>& val) {
516 return Impl::atomic_oper_fetch(Impl::DivOper(), dest, val);
518 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
520 atomic_div_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>& val) {
521 return Impl::atomic_oper_fetch(Impl::DivOper(), dest, val);
524 template <
typename S>
526 atomic_fetch_max(GeneralFad<S>* dest,
const GeneralFad<S>& val) {
527 return Impl::atomic_fetch_oper(Impl::MaxOper(), dest, val);
529 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
531 atomic_fetch_max(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>& val) {
532 return Impl::atomic_fetch_oper(Impl::MaxOper(), dest, val);
534 template <
typename S>
536 atomic_fetch_min(GeneralFad<S>* dest,
const GeneralFad<S>& val) {
537 return Impl::atomic_fetch_oper(Impl::MinOper(), dest, val);
539 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
541 atomic_fetch_min(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>& val) {
542 return Impl::atomic_fetch_oper(Impl::MinOper(), dest, val);
544 template <
typename S>
546 atomic_fetch_add(GeneralFad<S>* dest,
const GeneralFad<S>& val) {
547 return Impl::atomic_fetch_oper(Impl::AddOper(), dest, val);
549 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
551 atomic_fetch_add(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>& val) {
552 return Impl::atomic_fetch_oper(Impl::AddOper(), dest, val);
554 template <
typename S>
556 atomic_fetch_sub(GeneralFad<S>* dest,
const GeneralFad<S>& val) {
557 return Impl::atomic_fetch_oper(Impl::SubOper(), dest, val);
559 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
561 atomic_fetch_sub(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>& val) {
562 return Impl::atomic_fetch_oper(Impl::SubOper(), dest, val);
564 template <
typename S>
566 atomic_fetch_mul(GeneralFad<S>* dest,
const GeneralFad<S>& val) {
567 return Impl::atomic_fetch_oper(Impl::MulOper(), dest, val);
569 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
571 atomic_fetch_mul(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>& val) {
572 return Impl::atomic_fetch_oper(Impl::MulOper(), dest, val);
574 template <
typename S>
576 atomic_fetch_div(GeneralFad<S>* dest,
const GeneralFad<S>& val) {
577 return Impl::atomic_fetch_oper(Impl::DivOper(), dest, val);
579 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
581 atomic_fetch_div(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>& val) {
582 return Impl::atomic_fetch_oper(Impl::DivOper(), dest, val);
590 #endif // HAVE_SACADO_KOKKOS
591 #endif // SACADO_FAD_EXP_ATOMIC_HPP
#define SACADO_FAD_THREAD_SINGLE
SimpleFad< ValueT > min(const SimpleFad< ValueT > &a, const SimpleFad< ValueT > &b)
#define SACADO_FAD_DERIV_LOOP(I, SZ)
Get the base Fad type from a view/expression.
expr expr expr fastAccessDx(i)) FAD_UNARYOP_MACRO(exp
SimpleFad< ValueT > max(const SimpleFad< ValueT > &a, const SimpleFad< ValueT > &b)
#define SACADO_INLINE_FUNCTION