1 #ifndef VERTEXCFD_LINEARSOLVERS_CUSOLVERGLU_HPP
2 #define VERTEXCFD_LINEARSOLVERS_CUSOLVERGLU_HPP
4 #include "VertexCFD_LinearSolvers_CusolverNonpublic.hpp"
5 #include "VertexCFD_LinearSolvers_LocalDirectSolver.hpp"
7 #include <Teuchos_RCP.hpp>
8 #include <Tpetra_CrsMatrix.hpp>
9 #include <Tpetra_RowMatrix.hpp>
11 #include <cusolverSp.h>
12 #include <cusolverSp_LOWLEVEL_PREVIEW.h>
13 #include <thrust/device_vector.h>
17 namespace LinearSolvers
// Local matrix handed to the solver, held through a reference-counted
// Teuchos::RCP to const (the solver cannot modify it through this pointer).
// NOTE(review): presumably assigned in setMatrix() -- confirm in the .cpp.
40 Teuchos::RCP<const Tpetra::RowMatrix<>> _A;
// Matrix arrays stored in GPU memory (thrust::device_vector). The names
// suggest a CSR layout (row pointers, column indices, values) with 32-bit
// int indices -- TODO confirm against the code that fills them.
43 thrust::device_vector<int> _A_rowptr;
44 thrust::device_vector<int> _A_colind;
45 thrust::device_vector<double> _A_values;
// Host-memory (std::vector) counterparts of the device arrays above.
// NOTE(review): presumably staging buffers copied to the device -- confirm.
48 std::vector<int> _A_rowptr_host;
49 std::vector<int> _A_colind_host;
50 std::vector<double> _A_values_host;
// Pivot threshold for the LU factorization. NOTE(review): default value and
// valid range are set outside this view -- confirm.
53 double _pivot_threshold;
// Untyped byte buffer in GPU memory. NOTE(review): presumably the workspace
// sized/requested by the cuSOLVER sparse calls -- confirm in the .cpp.
57 thrust::device_vector<char> _work;
// cuSOLVER sparse (cusolverSp) library handle.
60 cusolverSpHandle_t _handle;
// Host-side LU info object; the type comes from the low-level preview
// cuSOLVER sparse API (cusolverSp_LOWLEVEL_PREVIEW.h is included above).
61 csrluInfoHost_t _lu_info;
// cuSPARSE matrix descriptors. NOTE(review): _A_descr presumably describes
// the input matrix _A and _M_descr the factored/reordered matrix -- confirm.
63 cusparseMatDescr_t _A_descr, _M_descr;
// Installs the matrix to be solved, taking shared (reference-counted)
// ownership via Teuchos::RCP. Overrides a virtual from a base class that is
// not visible here (presumably the one in
// VertexCFD_LinearSolvers_LocalDirectSolver.hpp) -- confirm.
75 void setMatrix(Teuchos::RCP<
const Tpetra::RowMatrix<>> A)
override;
// Setup hooks overriding the base-class interface. NOTE(review): by analogy
// with other direct-solver interfaces these are presumably the symbolic
// (analysis) phase and the numeric factorization phase respectively --
// confirm against the implementations in the .cpp.
78 void initialize()
override;
79 void compute()
override;
83 solve(
const Tpetra::MultiVector<>& b, Tpetra::MultiVector<>& x)
override;
88 check_status(cusolverStatus_t stat,
const std::string& identifier)
const;
// Const size queries on the locally owned matrix. NOTE(review): presumably
// computed from _A; the implementations are not visible here -- confirm.
91 int num_local_rows()
const;
92 std::size_t num_local_entries()
const;
93 std::size_t max_entries_per_row()
const;
101 #endif // VERTEXCFD_LINEARSOLVERS_CUSOLVERGLU_HPP