Sacado Package Browser (Single Doxygen Collection) Version of the Day
Loading...
Searching...
No Matches
dfad_dfad_example.cpp
Go to the documentation of this file.
1// @HEADER
2// ***********************************************************************
3//
4// Sacado Package
5// Copyright (2006) Sandia Corporation
6//
7// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
8// the U.S. Government retains certain rights in this software.
9//
10// This library is free software; you can redistribute it and/or modify
11// it under the terms of the GNU Lesser General Public License as
12// published by the Free Software Foundation; either version 2.1 of the
13// License, or (at your option) any later version.
14//
15// This library is distributed in the hope that it will be useful, but
16// WITHOUT ANY WARRANTY; without even the implied warranty of
17// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18// Lesser General Public License for more details.
19//
20// You should have received a copy of the GNU Lesser General Public
21// License along with this library; if not, write to the Free Software
22// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
23// USA
24// Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
25// (etphipp@sandia.gov).
26//
27// ***********************************************************************
28// @HEADER
29
30// dfad_dfad_example
31//
32// usage:
33// dfad_dfad_example
34//
35// output:
// prints the results of computing the second derivative of a simple function
// with forward nested forward mode AD using the Sacado::Fad::DFad class
// (uses dynamic memory allocation for number of derivative components).
38
39#include <iostream>
40#include <iomanip>
41
42#include "Sacado.hpp"
43
// Function whose derivatives the example computes:
//   r(a, b, c) = c * log(b + 1) / sin(a)
// Templated on ScalarT so the same code can be evaluated with plain doubles
// and with (nested) Fad types.
template <typename ScalarT>
ScalarT func(const ScalarT& a, const ScalarT& b, const ScalarT& c) {
  return c * std::log(b + 1.) / std::sin(a);
}
50
// Hand-derived first partial derivatives of r = c*log(b+1)/sin(a):
//   dr/da = -c*log(b+1)*cos(a) / sin(a)^2   (written to drda)
//   dr/db =  c / ((b+1)*sin(a))             (written to drdb)
void func_deriv(double a, double b, double c, double& drda, double& drdb)
{
  const double sin_a = std::sin(a);
  const double log_b1 = std::log(b + 1.);

  drda = -(c * log_b1 / std::pow(sin_a, 2.)) * std::cos(a);
  drdb = c / ((b + 1.) * sin_a);
}
57
// Hand-derived second partial derivatives of r = c*log(b+1)/sin(a):
//   d2rda2  = c*log(b+1)/sin(a) + 2*c*log(b+1)*cos(a)^2/sin(a)^3
//   d2rdb2  = -c / ((b+1)^2 * sin(a))
//   d2rdadb = -c*cos(a) / ((b+1) * sin(a)^2)   (the mixed partial)
void func_deriv2(double a, double b, double c, double& d2rda2, double& d2rdb2,
                 double& d2rdadb)
{
  const double sin_a = std::sin(a);
  const double cos_a = std::cos(a);
  const double b_plus_1 = b + 1.;
  const double c_log = c * std::log(b_plus_1);

  d2rda2 = c_log / sin_a
         + 2. * (c_log / std::pow(sin_a, 3.)) * std::pow(cos_a, 2.);
  d2rdb2 = -c / (std::pow(b_plus_1, 2.) * sin_a);
  d2rdadb = -c / (b_plus_1 * std::pow(sin_a, 2.)) * cos_a;
}
66
67int main(int argc, char **argv)
68{
69 double pi = std::atan(1.0)*4.0;
70
71 // Values of function arguments
72 double a = pi/4;
73 double b = 2.0;
74 double c = 3.0;
75
76 // Number of independent variables
77 int num_deriv = 2;
78
79 // Fad objects
81 Sacado::Fad::DFad<DFadType> afad(num_deriv, 0, a);
82 Sacado::Fad::DFad<DFadType> bfad(num_deriv, 1, b);
85
86 afad.val() = Sacado::Fad::DFad<double>(num_deriv, 0, a);
87 bfad.val() = Sacado::Fad::DFad<double>(num_deriv, 1, b);
88
89 // Compute function
90 double r = func(a, b, c);
91
92 // Compute derivative analytically
93 double drda, drdb;
94 func_deriv(a, b, c, drda, drdb);
95
96 // Compute second derivative analytically
97 double d2rda2, d2rdb2, d2rdadb;
98 func_deriv2(a, b, c, d2rda2, d2rdb2, d2rdadb);
99
100 // Compute function and derivative with AD
101 rfad = func(afad, bfad, cfad);
102
103 // Extract value and derivatives
104 double r_ad = rfad.val().val(); // r
105 double drda_ad = rfad.dx(0).val(); // dr/da
106 double drdb_ad = rfad.dx(1).val(); // dr/db
107 double d2rda2_ad = rfad.dx(0).dx(0); // d^2r/da^2
108 double d2rdadb_ad = rfad.dx(0).dx(1); // d^2r/dadb
109 double d2rdbda_ad = rfad.dx(1).dx(0); // d^2r/dbda
110 double d2rdb2_ad = rfad.dx(1).dx(1); // d^2/db^2
111
112 // Print the results
113 int p = 4;
114 int w = p+7;
115 std::cout.setf(std::ios::scientific);
116 std::cout.precision(p);
117 std::cout << " r = " << std::setw(w) << r << " (original) == "
118 << std::setw(w) << r_ad << " (AD) Error = " << std::setw(w)
119 << r - r_ad << std::endl
120 << " dr/da = " << std::setw(w) << drda << " (analytic) == "
121 << std::setw(w) << drda_ad << " (AD) Error = " << std::setw(w)
122 << drda - drda_ad << std::endl
123 << " dr/db = " << std::setw(w) << drdb << " (analytic) == "
124 << std::setw(w) << drdb_ad << " (AD) Error = " << std::setw(w)
125 << drdb - drdb_ad << std::endl
126 << "d^2r/da^2 = " << std::setw(w) << d2rda2 << " (analytic) == "
127 << std::setw(w) << d2rda2_ad << " (AD) Error = " << std::setw(w)
128 << d2rda2 - d2rda2_ad << std::endl
129 << "d^2r/db^2 = " << std::setw(w) << d2rdb2 << " (analytic) == "
130 << std::setw(w) << d2rdb2_ad << " (AD) Error = " << std::setw(w)
131 << d2rdb2 - d2rdb2_ad << std::endl
132 << "d^2r/dadb = " << std::setw(w) << d2rdadb << " (analytic) == "
133 << std::setw(w) << d2rdadb_ad << " (AD) Error = " << std::setw(w)
134 << d2rdadb - d2rdadb_ad << std::endl
135 << "d^2r/dbda = " << std::setw(w) << d2rdadb << " (analytic) == "
136 << std::setw(w) << d2rdbda_ad << " (AD) Error = " << std::setw(w)
137 << d2rdadb - d2rdbda_ad << std::endl;
138
139 double tol = 1.0e-14;
140 if (std::fabs(r - r_ad) < tol &&
141 std::fabs(drda - drda_ad) < tol &&
142 std::fabs(drdb - drdb_ad) < tol &&
143 std::fabs(d2rda2 - d2rda2_ad) < tol &&
144 std::fabs(d2rdb2 - d2rdb2_ad) < tol &&
145 std::fabs(d2rdadb - d2rdadb_ad) < tol) {
146 std::cout << "\nExample passed!" << std::endl;
147 return 0;
148 }
149 else {
150 std::cout <<"\nSomething is wrong, example failed!" << std::endl;
151 return 1;
152 }
153}
Sacado::Fad::DFad< double > DFadType
expr expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c *expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 c
int main()
Definition: ad_example.cpp:191
void func_deriv(double a, double b, double c, double &drda, double &drdb)
void func_deriv2(double a, double b, double c, double &d2rda2, double &d2rdb2, double &d2rdadb)
ScalarT func(const ScalarT &a, const ScalarT &b, const ScalarT &c)
const char * p
const double tol