ROL
ROL_FDivergence.hpp
Go to the documentation of this file.
1// @HEADER
2// ************************************************************************
3//
4// Rapid Optimization Library (ROL) Package
5// Copyright (2014) Sandia Corporation
6//
7// Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
8// license for use of this work by or on behalf of the U.S. Government.
9//
10// Redistribution and use in source and binary forms, with or without
11// modification, are permitted provided that the following conditions are
12// met:
13//
14// 1. Redistributions of source code must retain the above copyright
15// notice, this list of conditions and the following disclaimer.
16//
17// 2. Redistributions in binary form must reproduce the above copyright
18// notice, this list of conditions and the following disclaimer in the
19// documentation and/or other materials provided with the distribution.
20//
21// 3. Neither the name of the Corporation nor the names of the
22// contributors may be used to endorse or promote products derived from
23// this software without specific prior written permission.
24//
25// THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
26// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
28// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
29// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
30// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
32// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36//
37// Questions? Contact lead developers:
38// Drew Kouri (dpkouri@sandia.gov) and
39// Denis Ridzal (dridzal@sandia.gov)
40//
41// ************************************************************************
42// @HEADER
43
44#ifndef ROL_FDIVERGENCE_HPP
45#define ROL_FDIVERGENCE_HPP
46
48#include "ROL_Types.hpp"
49
84namespace ROL {
85
86template<class Real>
87class FDivergence : public RandVarFunctional<Real> {
88private:
89 Real thresh_;
90
91 Real valLam_;
93 Real valMu_;
94 Real valMu2_;
95
96 using RandVarFunctional<Real>::val_;
97 using RandVarFunctional<Real>::gv_;
98 using RandVarFunctional<Real>::g_;
99 using RandVarFunctional<Real>::hv_;
101
102 using RandVarFunctional<Real>::point_;
103 using RandVarFunctional<Real>::weight_;
104
109
110 void checkInputs(void) const {
111 Real zero(0);
112 ROL_TEST_FOR_EXCEPTION((thresh_ <= zero), std::invalid_argument,
113 ">>> ERROR (ROL::FDivergence): Threshold must be positive!");
114 }
115
116public:
121 FDivergence(const Real thresh) : RandVarFunctional<Real>(), thresh_(thresh),
122 valLam_(0),valLam2_(0), valMu_(0), valMu2_(0) {
123 checkInputs();
124 }
125
134 FDivergence(ROL::ParameterList &parlist) : RandVarFunctional<Real>(),
135 valLam_(0),valLam2_(0), valMu_(0), valMu2_(0) {
136 ROL::ParameterList &list
137 = parlist.sublist("SOL").sublist("Risk Measure").sublist("F-Divergence");
138 thresh_ = list.get<Real>("Threshold");
139 checkInputs();
140 }
141
149 virtual Real Fprimal(Real x, int deriv = 0) = 0;
150
163 virtual Real Fdual(Real x, int deriv = 0) = 0;
164
165 bool check(std::ostream &outStream = std::cout) const {
166 const Real tol(std::sqrt(ROL_EPSILON<Real>()));
167 bool flag = true;
168
169 Real x = static_cast<Real>(rand())/static_cast<Real>(RAND_MAX);
170 Real t = static_cast<Real>(rand())/static_cast<Real>(RAND_MAX);
171 Real fp = Fprimal(x);
172 Real fd = Fdual(t);
173 outStream << "Check Fenchel-Young Inequality: F(x) + F*(t) >= xt" << std::endl;
174 outStream << "x = " << x << std::endl;
175 outStream << "t = " << t << std::endl;
176 outStream << "F(x) = " << fp << std::endl;
177 outStream << "F*(t) = " << fd << std::endl;
178 outStream << "Is Valid? " << (fp+fd >= x*t) << std::endl;
179 flag = (fp+fd >= x*t) ? flag : false;
180
181 x = static_cast<Real>(rand())/static_cast<Real>(RAND_MAX);
182 t = Fprimal(x,1);
183 fp = Fprimal(x);
184 fd = Fdual(t);
185 outStream << "Check Fenchel-Young Equality: F(x) + F(t) = xt for t = d/dx F(x)" << std::endl;
186 outStream << "x = " << x << std::endl;
187 outStream << "t = " << t << std::endl;
188 outStream << "F(x) = " << fp << std::endl;
189 outStream << "F*(t) = " << fd << std::endl;
190 outStream << "Is Valid? " << (std::abs(fp+fd - x*t)<=tol) << std::endl;
191 flag = (std::abs(fp+fd - x*t)<=tol) ? flag : false;
192
193 t = static_cast<Real>(rand())/static_cast<Real>(RAND_MAX);
194 x = Fdual(t,1);
195 fp = Fprimal(x);
196 fd = Fdual(t);
197 outStream << "Check Fenchel-Young Equality: F(x) + F(t) = xt for x = d/dt F*(t)" << std::endl;
198 outStream << "x = " << x << std::endl;
199 outStream << "t = " << t << std::endl;
200 outStream << "F(x) = " << fp << std::endl;
201 outStream << "F*(t) = " << fd << std::endl;
202 outStream << "Is Valid? " << (std::abs(fp+fd - x*t)<=tol) << std::endl;
203 flag = (std::abs(fp+fd - x*t)<=tol) ? flag : false;
204
205 return flag;
206 }
207
208 void initialize(const Vector<Real> &x) {
210 valLam_ = 0; valLam2_ = 0; valMu_ = 0; valMu2_ = 0;
211 }
212
213 // Value update and get functions
215 const Vector<Real> &x,
216 const std::vector<Real> &xstat,
217 Real &tol) {
218 Real val = computeValue(obj,x,tol);
219 Real xlam = xstat[0];
220 Real xmu = xstat[1];
221 Real r = Fdual((val-xmu)/xlam,0);
222 val_ += weight_ * r;
223 }
224
225 Real getValue(const Vector<Real> &x,
226 const std::vector<Real> &xstat,
227 SampleGenerator<Real> &sampler) {
228 Real val(0);
229 sampler.sumAll(&val_,&val,1);
230 Real xlam = xstat[0];
231 Real xmu = xstat[1];
232 return xlam*(thresh_ + val) + xmu;
233 }
234
235 // Gradient update and get functions
237 const Vector<Real> &x,
238 const std::vector<Real> &xstat,
239 Real &tol) {
240 Real val = computeValue(obj,x,tol);
241 Real xlam = xstat[0];
242 Real xmu = xstat[1];
243 Real inp = (val-xmu)/xlam;
244 Real r0 = Fdual(inp,0), r1 = Fdual(inp,1);
245
246 if (std::abs(r0) >= ROL_EPSILON<Real>()) {
247 val_ += weight_ * r0;
248 }
249 if (std::abs(r1) >= ROL_EPSILON<Real>()) {
250 valLam_ -= weight_ * r1 * inp;
251 valMu_ -= weight_ * r1;
252 computeGradient(*dualVector_,obj,x,tol);
253 g_->axpy(weight_*r1,*dualVector_);
254 }
255 }
256
258 std::vector<Real> &gstat,
259 const Vector<Real> &x,
260 const std::vector<Real> &xstat,
261 SampleGenerator<Real> &sampler) {
262 std::vector<Real> mygval(3), gval(3);
263 mygval[0] = val_;
264 mygval[1] = valLam_;
265 mygval[2] = valMu_;
266 sampler.sumAll(&mygval[0],&gval[0],3);
267
268 gstat[0] = thresh_ + gval[0] + gval[1];
269 gstat[1] = static_cast<Real>(1) + gval[2];
270
271 sampler.sumAll(*g_,g);
272 }
273
275 const Vector<Real> &v,
276 const std::vector<Real> &vstat,
277 const Vector<Real> &x,
278 const std::vector<Real> &xstat,
279 Real &tol) {
280 Real val = computeValue(obj,x,tol);
281 Real xlam = xstat[0];
282 Real xmu = xstat[1];
283 Real vlam = vstat[0];
284 Real vmu = vstat[1];
285 Real inp = (val-xmu)/xlam;
286 Real r1 = Fdual(inp,1), r2 = Fdual(inp,2);
287 if (std::abs(r2) >= ROL_EPSILON<Real>()) {
288 Real gv = computeGradVec(*dualVector_,obj,v,x,tol);
289 val_ += weight_ * r2 * inp;
290 valLam_ += weight_ * r2 * inp * inp;
291 valLam2_ -= weight_ * r2 * gv * inp;
292 valMu_ += weight_ * r2;
293 valMu2_ -= weight_ * r2 * gv;
294 hv_->axpy(weight_ * r2 * (gv - vmu - vlam*inp)/xlam, *dualVector_);
295 }
296 if (std::abs(r1) >= ROL_EPSILON<Real>()) {
297 computeHessVec(*dualVector_,obj,v,x,tol);
298 hv_->axpy(weight_ * r1, *dualVector_);
299 }
300 }
301
303 std::vector<Real> &hvstat,
304 const Vector<Real> &v,
305 const std::vector<Real> &vstat,
306 const Vector<Real> &x,
307 const std::vector<Real> &xstat,
308 SampleGenerator<Real> &sampler) {
309 std::vector<Real> myhval(5), hval(5);
310 myhval[0] = val_;
311 myhval[1] = valLam_;
312 myhval[2] = valLam2_;
313 myhval[3] = valMu_;
314 myhval[4] = valMu2_;
315 sampler.sumAll(&myhval[0],&hval[0],5);
316
317 std::vector<Real> stat(2);
318 Real xlam = xstat[0];
319 //Real xmu = xstat[1];
320 Real vlam = vstat[0];
321 Real vmu = vstat[1];
322 hvstat[0] = (vlam * hval[1] + vmu * hval[0] + hval[2])/xlam;
323 hvstat[1] = (vlam * hval[0] + vmu * hval[3] + hval[4])/xlam;
324
325 sampler.sumAll(*hv_,hv);
326 }
327};
328
329}
330
331#endif
Objective_SerialSimOpt(const Ptr< Obj > &obj, const V &ui) z0_ zero()
Contains definitions of custom data types in ROL.
Provides a general interface for the F-divergence distributionally robust expectation.
void updateGradient(Objective< Real > &obj, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal risk measure storage for gradient computation.
void updateValue(Objective< Real > &obj, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal storage for value computation.
FDivergence(ROL::ParameterList &parlist)
Constructor.
void getHessVec(Vector< Real > &hv, std::vector< Real > &hvstat, const Vector< Real > &v, const std::vector< Real > &vstat, const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure Hessian-times-a-vector.
void getGradient(Vector< Real > &g, std::vector< Real > &gstat, const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure (sub)gradient.
Real getValue(const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure value.
void initialize(const Vector< Real > &x)
Initialize temporary variables.
FDivergence(const Real thresh)
Constructor.
void checkInputs(void) const
bool check(std::ostream &outStream=std::cout) const
virtual Real Fprimal(Real x, int deriv=0)=0
Implementation of the scalar primal F function.
void updateHessVec(Objective< Real > &obj, const Vector< Real > &v, const std::vector< Real > &vstat, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal risk measure storage for Hessian-times-a-vector computation.
virtual Real Fdual(Real x, int deriv=0)=0
Implementation of the scalar dual F function.
Provides the interface to evaluate objective functions.
Provides the interface to implement any functional that maps a random variable to an (extended) real number.
Real computeValue(Objective< Real > &obj, const Vector< Real > &x, Real &tol)
Ptr< Vector< Real > > g_
virtual void initialize(const Vector< Real > &x)
Initialize temporary variables.
void computeHessVec(Vector< Real > &hv, Objective< Real > &obj, const Vector< Real > &v, const Vector< Real > &x, Real &tol)
Ptr< Vector< Real > > hv_
void computeGradient(Vector< Real > &g, Objective< Real > &obj, const Vector< Real > &x, Real &tol)
Ptr< Vector< Real > > dualVector_
Real computeGradVec(Vector< Real > &g, Objective< Real > &obj, const Vector< Real > &v, const Vector< Real > &x, Real &tol)
void sumAll(Real *input, Real *output, int dim) const
Defines the linear algebra or vector space interface.
Definition: ROL_Vector.hpp:84