Amesos2 - Direct Sparse Solver Interfaces  Version of the Day
Amesos2_cuSOLVER_FunctionMap.hpp
1 // @HEADER
2 //
3 // ***********************************************************************
4 //
5 // Amesos2: Templated Direct Sparse Solver Package
6 // Copyright 2011 Sandia Corporation
7 //
8 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
9 // the U.S. Government retains certain rights in this software.
10 //
11 // Redistribution and use in source and binary forms, with or without
12 // modification, are permitted provided that the following conditions are
13 // met:
14 //
15 // 1. Redistributions of source code must retain the above copyright
16 // notice, this list of conditions and the following disclaimer.
17 //
18 // 2. Redistributions in binary form must reproduce the above copyright
19 // notice, this list of conditions and the following disclaimer in the
20 // documentation and/or other materials provided with the distribution.
21 //
22 // 3. Neither the name of the Corporation nor the names of the
23 // contributors may be used to endorse or promote products derived from
24 // this software without specific prior written permission.
25 //
26 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
27 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
29 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
30 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
31 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
32 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
33 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
34 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
35 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
36 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37 //
38 // Questions? Contact Michael A. Heroux (maherou@sandia.gov)
39 //
40 // ***********************************************************************
41 //
42 // @HEADER
43 
44 #ifndef AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
45 #define AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
46 
47 #include "Amesos2_FunctionMap.hpp"
48 #include "Amesos2_cuSOLVER_TypeMap.hpp"
49 
50 #include <cuda.h>
51 #include <cusolverSp.h>
52 #include <cusolverDn.h>
53 #include <cusparse.h>
54 #include <cusolverSp_LOWLEVEL_PREVIEW.h>
55 
56 #ifdef HAVE_TEUCHOS_COMPLEX
57 #include <cuComplex.h>
58 #endif
59 
60 namespace Amesos2 {
61 
62  template <>
63  struct FunctionMap<cuSOLVER,double>
64  {
65  static cusolverStatus_t bufferInfo(
66  cusolverSpHandle_t handle,
67  int size,
68  int nnz,
69  cusparseMatDescr_t & desc,
70  const double * values,
71  const int * rowPtr,
72  const int * colIdx,
73  csrcholInfo_t & chol_info,
74  size_t * internalDataInBytes,
75  size_t * workspaceInBytes)
76  {
77  cusolverStatus_t status =
78  cusolverSpDcsrcholBufferInfo(handle, size, nnz, desc, values,
79  rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
80  return status;
81  }
82 
83  static cusolverStatus_t numeric(
84  cusolverSpHandle_t handle,
85  int size,
86  int nnz,
87  cusparseMatDescr_t & desc,
88  const double * values,
89  const int * rowPtr,
90  const int * colIdx,
91  csrcholInfo_t & chol_info,
92  void * buffer)
93  {
94  cusolverStatus_t status = cusolverSpDcsrcholFactor(
95  handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
96  return status;
97  }
98 
99  static cusolverStatus_t solve(
100  cusolverSpHandle_t handle,
101  int size,
102  const double * b,
103  double * x,
104  csrcholInfo_t & chol_info,
105  void * buffer)
106  {
107  cusolverStatus_t status = cusolverSpDcsrcholSolve(
108  handle, size, b, x, chol_info, buffer);
109  return status;
110  }
111  };
112 
113  template <>
114  struct FunctionMap<cuSOLVER,float>
115  {
116  static cusolverStatus_t bufferInfo(
117  cusolverSpHandle_t handle,
118  int size,
119  int nnz,
120  cusparseMatDescr_t & desc,
121  const float * values,
122  const int * rowPtr,
123  const int * colIdx,
124  csrcholInfo_t & chol_info,
125  size_t * internalDataInBytes,
126  size_t * workspaceInBytes)
127  {
128  cusolverStatus_t status =
129  cusolverSpScsrcholBufferInfo(handle, size, nnz, desc, values,
130  rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
131  return status;
132  }
133 
134  static cusolverStatus_t numeric(
135  cusolverSpHandle_t handle,
136  int size,
137  int nnz,
138  cusparseMatDescr_t & desc,
139  const float * values,
140  const int * rowPtr,
141  const int * colIdx,
142  csrcholInfo_t & chol_info,
143  void * buffer)
144  {
145  cusolverStatus_t status = cusolverSpScsrcholFactor(
146  handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
147  return status;
148  }
149 
150  static cusolverStatus_t solve(
151  cusolverSpHandle_t handle,
152  int size,
153  const float * b,
154  float * x,
155  csrcholInfo_t & chol_info,
156  void * buffer)
157  {
158  cusolverStatus_t status = cusolverSpScsrcholSolve(
159  handle, size, b, x, chol_info, buffer);
160  return status;
161  }
162  };
163 
164 #ifdef HAVE_TEUCHOS_COMPLEX
165  template <>
166  struct FunctionMap<cuSOLVER,Kokkos::complex<double>>
167  {
168  static cusolverStatus_t bufferInfo(
169  cusolverSpHandle_t handle,
170  int size,
171  int nnz,
172  cusparseMatDescr_t & desc,
173  const void * values,
174  const int * rowPtr,
175  const int * colIdx,
176  csrcholInfo_t & chol_info,
177  size_t * internalDataInBytes,
178  size_t * workspaceInBytes)
179  {
180  typedef cuDoubleComplex scalar_t;
181  const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
182  cusolverStatus_t status =
183  cusolverSpZcsrcholBufferInfo(handle, size, nnz, desc,
184  cu_values, rowPtr, colIdx, chol_info,
185  internalDataInBytes, workspaceInBytes);
186  return status;
187  }
188 
189  static cusolverStatus_t numeric(
190  cusolverSpHandle_t handle,
191  int size,
192  int nnz,
193  cusparseMatDescr_t & desc,
194  const void * values,
195  const int * rowPtr,
196  const int * colIdx,
197  csrcholInfo_t & chol_info,
198  void * buffer)
199  {
200  typedef cuDoubleComplex scalar_t;
201  const scalar_t * cu_values =
202  reinterpret_cast<const scalar_t *>(values);
203  cusolverStatus_t status = cusolverSpZcsrcholFactor(
204  handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
205  return status;
206  }
207 
208  static cusolverStatus_t solve(
209  cusolverSpHandle_t handle,
210  int size,
211  const void * b,
212  void * x,
213  csrcholInfo_t & chol_info,
214  void * buffer)
215  {
216  typedef cuDoubleComplex scalar_t;
217  const scalar_t * cu_b = reinterpret_cast<const scalar_t *>(b);
218  scalar_t * cu_x = reinterpret_cast<scalar_t *>(x);
219  cusolverStatus_t status = cusolverSpZcsrcholSolve(
220  handle, size, cu_b, cu_x, chol_info, buffer);
221  return status;
222  }
223  };
224 
225  template <>
226  struct FunctionMap<cuSOLVER,Kokkos::complex<float>>
227  {
228  static cusolverStatus_t bufferInfo(
229  cusolverSpHandle_t handle,
230  int size,
231  int nnz,
232  cusparseMatDescr_t & desc,
233  const void * values,
234  const int * rowPtr,
235  const int * colIdx,
236  csrcholInfo_t & chol_info,
237  size_t * internalDataInBytes,
238  size_t * workspaceInBytes)
239  {
240  typedef cuFloatComplex scalar_t;
241  const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
242  cusolverStatus_t status =
243  cusolverSpCcsrcholBufferInfo(handle, size, nnz, desc,
244  cu_values, rowPtr, colIdx, chol_info,
245  internalDataInBytes, workspaceInBytes);
246  return status;
247  }
248 
249  static cusolverStatus_t numeric(
250  cusolverSpHandle_t handle,
251  int size,
252  int nnz,
253  cusparseMatDescr_t & desc,
254  const void * values,
255  const int * rowPtr,
256  const int * colIdx,
257  csrcholInfo_t & chol_info,
258  void * buffer)
259  {
260  typedef cuFloatComplex scalar_t;
261  const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
262  cusolverStatus_t status = cusolverSpCcsrcholFactor(
263  handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
264  return status;
265  }
266 
267  static cusolverStatus_t solve(
268  cusolverSpHandle_t handle,
269  int size,
270  const void * b,
271  void * x,
272  csrcholInfo_t & chol_info,
273  void * buffer)
274  {
275  typedef cuFloatComplex scalar_t;
276  const scalar_t * cu_b = reinterpret_cast<const scalar_t *>(b);
277  scalar_t * cu_x = reinterpret_cast<scalar_t *>(x);
278  cusolverStatus_t status = cusolverSpCcsrcholSolve(
279  handle, size, cu_b, cu_x, chol_info, buffer);
280  return status;
281  }
282  };
283 #endif
284 
285 } // end namespace Amesos2
286 
287 #endif // AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
Declaration of the cuSOLVER function-mapping class specializations for Amesos2.