1 : |
agomez |
1 |
#include "armijo.h" |
2 : |
|
|
|
3 : |
|
|
// ********************************************************
|
4 : |
|
|
// Functions for class Armijo
|
5 : |
|
|
// ********************************************************
|
6 : |
|
|
// Create an Armijo line seeker with its default parameters:
//   beta  = 0.3  : factor by which successive trial steps shrink (s * beta^k)
//   sigma = 0.01 : factor in the sufficient-decrease test
//   alpha = 0.0  : no step accepted yet
// The condition object handed to the network layer is allocated here and
// released in the destructor.
Armijo::Armijo() {
  cond = new LineSeekerCondition(this);
  alpha = 0.0;
  sigma = 0.01;
  beta = 0.3;
}
|
12 : |
|
|
|
13 : |
|
|
// Free the condition object that the constructor allocated.
Armijo::~Armijo()
{
  delete cond;
}
|
16 : |
|
|
|
17 : |
|
|
double Armijo::getAlpha() {
|
18 : |
|
|
return alpha;
|
19 : |
|
|
}
|
20 : |
|
|
|
21 : |
|
|
int Armijo::conditionSatisfied(double y) {
|
22 : |
|
|
int returnID = net->getReceiveID();
|
23 : |
|
|
// Vector temp(1);
|
24 : |
|
|
// temp = (net->getX(returnID));
|
25 : |
|
|
if ((initialf - y) >= (-sigma * net->getX(returnID)[0] * df))
|
26 : |
|
|
return 1;
|
27 : |
|
|
else
|
28 : |
|
|
return 0;
|
29 : |
|
|
}
|
30 : |
|
|
|
31 : |
|
|
void Armijo::doArmijo(const DoubleVector& v1, double fx, double dery,
|
32 : |
|
|
const DoubleVector& h, NetInterface *netI, double s1) {
|
33 : |
|
|
|
34 : |
|
|
int cond_satisfied;
|
35 : |
|
|
alpha = 0.0;
|
36 : |
|
|
power = -1;
|
37 : |
|
|
numVar = netI->getNumVarsInDataGroup();
|
38 : |
|
|
x = v1;
|
39 : |
|
|
s = s1;
|
40 : |
|
|
f = fx;
|
41 : |
|
|
initialx = v1;
|
42 : |
|
|
initialf = fx;
|
43 : |
|
|
hvec = h;
|
44 : |
|
|
net = netI;
|
45 : |
|
|
df = dery;
|
46 : |
|
|
|
47 : |
|
|
prepareNewLineSearch();
|
48 : |
|
|
initiateAlphas();
|
49 : |
|
|
cond_satisfied = net->sendAndReceiveSetData(cond);
|
50 : |
|
|
if (cond_satisfied == -1) {
|
51 : |
|
|
cerr << "Error in linesearch - cannot receive or send data\n";
|
52 : |
|
|
net->stopUsingDataGroup();
|
53 : |
|
|
exit(EXIT_FAILURE);
|
54 : |
|
|
} else if (cond_satisfied == 1) {
|
55 : |
|
|
// check this better should be working???
|
56 : |
|
|
} else {
|
57 : |
|
|
net->stopNetComm();
|
58 : |
|
|
exit(EXIT_FAILURE);
|
59 : |
|
|
}
|
60 : |
|
|
net->stopUsingDataGroup();
|
61 : |
|
|
}
|
62 : |
|
|
|
63 : |
|
|
int Armijo::computeConditionFunction() {
|
64 : |
|
|
int returnID, i, cond_satisfied = 0;
|
65 : |
|
|
int counter = net->getNumDataItemsSet();
|
66 : |
|
|
int newreturns = net->getNumDataItemsAnswered();
|
67 : |
|
|
double y;
|
68 : |
|
|
// Vector temp;
|
69 : |
|
|
|
70 : |
|
|
returnID = net->getReceiveID();
|
71 : |
|
|
if (returnID >= 0) {
|
72 : |
|
|
//temp = net->getX(returnID);
|
73 : |
|
|
y = net->getY(returnID);
|
74 : |
|
|
cond_satisfied = ((conditionSatisfied(y) == 1) && (f > y));
|
75 : |
|
|
if (cond_satisfied) {
|
76 : |
|
|
cout << "New optimum value f(x) = " << y << " at \n";
|
77 : |
|
|
f = y;
|
78 : |
|
|
power = returnID - 1;
|
79 : |
|
|
alpha = net->getX(returnID)[0];
|
80 : |
|
|
x = net->makeVector(net->getX(returnID));
|
81 : |
|
|
for (i = 0; i < x.Size() ; i++)
|
82 : |
|
|
cout << x[i] << " ";
|
83 : |
|
|
cout << endl << endl;
|
84 : |
|
|
}
|
85 : |
|
|
}
|
86 : |
|
|
|
87 : |
|
|
if (power == -1) {
|
88 : |
|
|
if (net->dataGroupFull()) {
|
89 : |
|
|
if ((0.8 * counter) <= newreturns) {
|
90 : |
|
|
// cannot set more data and have received 8/10 of all data - bailing out
|
91 : |
|
|
power = 1;
|
92 : |
|
|
alpha = 0.0;
|
93 : |
|
|
return 1;
|
94 : |
|
|
|
95 : |
|
|
} else
|
96 : |
|
|
return 0;
|
97 : |
|
|
|
98 : |
|
|
} else {
|
99 : |
|
|
// dataGroup not full
|
100 : |
|
|
i = setData();
|
101 : |
|
|
return cond_satisfied;
|
102 : |
|
|
}
|
103 : |
|
|
|
104 : |
|
|
} else if (power >= 0) {
|
105 : |
|
|
// not setting any more data, have already found optimal value
|
106 : |
|
|
if ((0.8 * counter) <= newreturns) {
|
107 : |
|
|
// have found optimal value and received 8/10 of all data set
|
108 : |
|
|
return 1;
|
109 : |
|
|
} else
|
110 : |
|
|
return 0;
|
111 : |
|
|
|
112 : |
|
|
} else
|
113 : |
|
|
return 1;
|
114 : |
|
|
}
|
115 : |
|
|
|
116 : |
|
|
void Armijo::prepareNewLineSearch() {
|
117 : |
|
|
net->startNewDataGroup(numAlpha, x, hvec);
|
118 : |
|
|
if (df > 0) {
|
119 : |
|
|
cerr << "Error in linesearch - bad derivative\n";
|
120 : |
|
|
net->stopUsingDataGroup();
|
121 : |
|
|
}
|
122 : |
|
|
DoubleVector tempx(1,0.0);
|
123 : |
|
|
// tempx[0] = 0.0;
|
124 : |
|
|
net->setDataPair(tempx, initialf);
|
125 : |
|
|
}
|
126 : |
|
|
|
127 : |
|
|
void Armijo::initiateAlphas() {
|
128 : |
|
|
int i = net->getTotalNumProc();
|
129 : |
|
|
int j;
|
130 : |
|
|
DoubleVector tempx(1);
|
131 : |
|
|
assert(beta > 0.0);
|
132 : |
|
|
assert(beta <= 0.5);
|
133 : |
|
|
for (j = 0; j < i; j++) {
|
134 : |
|
|
tempx[0] = pow(beta, j) * s;
|
135 : |
|
|
net->setX(tempx);
|
136 : |
|
|
}
|
137 : |
|
|
}
|
138 : |
|
|
|
139 : |
|
|
void Armijo::setSigma(double s) {
|
140 : |
|
|
sigma = s;
|
141 : |
|
|
}
|
142 : |
|
|
|
143 : |
|
|
void Armijo::setBeta(double b) {
|
144 : |
|
|
beta = b;
|
145 : |
|
|
}
|
146 : |
|
|
|
147 : |
|
|
double Armijo::getBeta() {
|
148 : |
|
|
return beta;
|
149 : |
|
|
}
|
150 : |
|
|
|
151 : |
|
|
int Armijo::getPower() {
|
152 : |
|
|
return power;
|
153 : |
|
|
}
|
154 : |
|
|
|
155 : |
|
|
int Armijo::outstandingRequests() {
|
156 : |
|
|
int out;
|
157 : |
|
|
int pending = net->getNumNotAns();
|
158 : |
|
|
out = (pending > 0);
|
159 : |
|
|
return out;
|
160 : |
|
|
}
|
161 : |
|
|
|
162 : |
|
|
int Armijo::setData() {
|
163 : |
|
|
DoubleVector tempx(1);
|
164 : |
|
|
int counter = net->getNumDataItemsSet() - 1;
|
165 : |
|
|
double a = pow(beta, counter) * s;
|
166 : |
|
|
tempx[0] = a;
|
167 : |
|
|
int ok = -1;
|
168 : |
|
|
if (net->dataGroupFull()) {
|
169 : |
|
|
cerr << "Error in armijo - have set too many values\n";
|
170 : |
|
|
ok = 0;
|
171 : |
|
|
} else {
|
172 : |
|
|
net->setXFirstToSend(tempx);
|
173 : |
|
|
ok = 1;
|
174 : |
|
|
}
|
175 : |
|
|
return ok;
|
176 : |
|
|
}
|
177 : |
|
|
|