
random_forest_deprec.hxx
/************************************************************************/
/*                                                                      */
/*                   Copyright 2008 by Ullrich Koethe                   */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/

#ifndef VIGRA_RANDOM_FOREST_DEPREC_HXX
#define VIGRA_RANDOM_FOREST_DEPREC_HXX

#include <algorithm>
#include <map>
#include <numeric>
#include <iostream>
#include <ctime>
#include <cstdlib>
#include "vigra/mathutil.hxx"
#include "vigra/array_vector.hxx"
#include "vigra/sized_int.hxx"
#include "vigra/matrix.hxx"
#include "vigra/random.hxx"
#include "vigra/functorexpression.hxx"


namespace vigra
{

/** \addtogroup MachineLearning
**/
//@{

namespace detail
{

template<class DataMatrix>
class RandomForestDeprecFeatureSorter
{
    DataMatrix const & data_;
    MultiArrayIndex sortColumn_;

  public:

    RandomForestDeprecFeatureSorter(DataMatrix const & data, MultiArrayIndex sortColumn)
    : data_(data),
      sortColumn_(sortColumn)
    {}

    void setColumn(MultiArrayIndex sortColumn)
    {
        sortColumn_ = sortColumn;
    }

    bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
    {
        return data_(l, sortColumn_) < data_(r, sortColumn_);
    }
};

template<class LabelArray>
class RandomForestDeprecLabelSorter
{
    LabelArray const & labels_;

  public:

    RandomForestDeprecLabelSorter(LabelArray const & labels)
    : labels_(labels)
    {}

    bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
    {
        return labels_[l] < labels_[r];
    }
};

template <class CountArray>
class RandomForestDeprecClassCounter
{
    ArrayVector<int> const & labels_;
    CountArray & counts_;

  public:

    RandomForestDeprecClassCounter(ArrayVector<int> const & labels, CountArray & counts)
    : labels_(labels),
      counts_(counts)
    {
        reset();
    }

    void reset()
    {
        counts_.init(0);
    }

    void operator()(MultiArrayIndex l) const
    {
        ++counts_[labels_[l]];
    }
};

struct DecisionTreeDeprecCountNonzeroFunctor
{
    double operator()(double old, double other) const
    {
        if(other != 0.0)
            ++old;
        return old;
    }
};

struct DecisionTreeDeprecNode
{
    DecisionTreeDeprecNode(int t, MultiArrayIndex bestColumn)
    : thresholdIndex(t), splitColumn(bestColumn)
    {}

    int children[2];
    int thresholdIndex;
    Int32 splitColumn;
};

template <class INT>
struct DecisionTreeDeprecNodeProxy
{
    DecisionTreeDeprecNodeProxy(ArrayVector<INT> const & tree, INT n)
    : node(const_cast<ArrayVector<INT> &>(tree).begin()+n)
    {}

    INT & child(INT l) const
    {
        return node[l];
    }

    INT & decisionWeightsIndex() const
    {
        return node[2];
    }

    typename ArrayVector<INT>::iterator decisionColumns() const
    {
        return node+3;
    }

    mutable typename ArrayVector<INT>::iterator node;
};
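
// Layout of one interior node inside the flat tree array, as read by
// DecisionTreeDeprecNodeProxy above (cf. sizeofNode() == 4 below):
//   node[0] - index of the left child node (a value <= 0 marks a leaf;
//             its negative is the start of that leaf's class weights
//             in the terminalWeights_ array)
//   node[1] - index of the right child node (same convention)
//   node[2] - index of the split threshold in terminalWeights_
//   node[3] - column (feature) used for the split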

struct DecisionTreeDeprecAxisSplitFunctor
{
    ArrayVector<Int32> splitColumns;
    ArrayVector<double> classCounts, currentCounts[2], bestCounts[2], classWeights;
    double threshold;
    double totalCounts[2], bestTotalCounts[2];
    int mtry, classCount, bestSplitColumn;
    bool pure[2], isWeighted;

    void init(int mtry, int cols, int classCount, ArrayVector<double> const & weights)
    {
        this->mtry = mtry;
        splitColumns.resize(cols);
        for(int k=0; k<cols; ++k)
            splitColumns[k] = k;

        this->classCount = classCount;
        classCounts.resize(classCount);
        currentCounts[0].resize(classCount);
        currentCounts[1].resize(classCount);
        bestCounts[0].resize(classCount);
        bestCounts[1].resize(classCount);

        isWeighted = weights.size() > 0;
        if(isWeighted)
            classWeights = weights;
        else
            classWeights.resize(classCount, 1.0);
    }

    bool isPure(int k) const
    {
        return pure[k];
    }

    unsigned int totalCount(int k) const
    {
        return (unsigned int)bestTotalCounts[k];
    }

    int sizeofNode() const { return 4; }

    int writeSplitParameters(ArrayVector<Int32> & tree,
                             ArrayVector<double> & terminalWeights)
    {
        int currentWeightIndex = terminalWeights.size();
        terminalWeights.push_back(threshold);

        int currentNodeIndex = tree.size();
        tree.push_back(-1);  // left child
        tree.push_back(-1);  // right child
        tree.push_back(currentWeightIndex);
        tree.push_back(bestSplitColumn);

        return currentNodeIndex;
    }

    void writeWeights(int l, ArrayVector<double> & terminalWeights)
    {
        for(int k=0; k<classCount; ++k)
            terminalWeights.push_back(isWeighted
                                         ? bestCounts[l][k]
                                         : bestCounts[l][k] / totalCount(l));
    }

    template <class U, class C, class AxesIterator, class WeightIterator>
    bool decideAtNode(MultiArrayView<2, U, C> const & features,
                      AxesIterator a, WeightIterator w) const
    {
        return (features(0, *a) < *w);
    }

    template <class U, class C, class IndexIterator, class Random>
    IndexIterator findBestSplit(MultiArrayView<2, U, C> const & features,
                                ArrayVector<int> const & labels,
                                IndexIterator indices, int exampleCount,
                                Random & randint);
};


template <class U, class C, class IndexIterator, class Random>
IndexIterator
DecisionTreeDeprecAxisSplitFunctor::findBestSplit(MultiArrayView<2, U, C> const & features,
                                                  ArrayVector<int> const & labels,
                                                  IndexIterator indices, int exampleCount,
                                                  Random & randint)
{
    // select columns to be tried for the split
    for(int k=0; k<mtry; ++k)
        std::swap(splitColumns[k], splitColumns[k+randint(columnCount(features)-k)]);

    RandomForestDeprecFeatureSorter<MultiArrayView<2, U, C> > sorter(features, 0);
    RandomForestDeprecClassCounter<ArrayVector<double> > counter(labels, classCounts);
    std::for_each(indices, indices+exampleCount, counter);

    // find the best Gini index
    double minGini = NumericTraits<double>::max();
    IndexIterator bestSplit = indices;
    for(int k=0; k<mtry; ++k)
    {
        sorter.setColumn(splitColumns[k]);
        std::sort(indices, indices+exampleCount, sorter);

        currentCounts[0].init(0);
        std::transform(classCounts.begin(), classCounts.end(), classWeights.begin(),
                       currentCounts[1].begin(), std::multiplies<double>());
        totalCounts[0] = 0;
        totalCounts[1] = std::accumulate(currentCounts[1].begin(), currentCounts[1].end(), 0.0);
        for(int m = 0; m < exampleCount-1; ++m)
        {
            int label = labels[indices[m]];
            double w = classWeights[label];
            currentCounts[0][label] += w;
            totalCounts[0] += w;
            currentCounts[1][label] -= w;
            totalCounts[1] -= w;

            if (m < exampleCount-2 &&
                features(indices[m], splitColumns[k]) == features(indices[m+1], splitColumns[k]))
                continue;

            double gini = 0.0;
            if(classCount == 2)
            {
                gini = currentCounts[0][0]*currentCounts[0][1] / totalCounts[0] +
                       currentCounts[1][0]*currentCounts[1][1] / totalCounts[1];
            }
            else
            {
                for(int l=0; l<classCount; ++l)
                    gini += currentCounts[0][l]*(1.0 - currentCounts[0][l] / totalCounts[0]) +
                            currentCounts[1][l]*(1.0 - currentCounts[1][l] / totalCounts[1]);
            }
            if(gini < minGini)
            {
                minGini = gini;
                bestSplit = indices+m;
                bestSplitColumn = splitColumns[k];
                bestCounts[0] = currentCounts[0];
                bestCounts[1] = currentCounts[1];
            }
        }
    }

    // split using the best feature
    sorter.setColumn(bestSplitColumn);
    std::sort(indices, indices+exampleCount, sorter);

    for(int k=0; k<2; ++k)
    {
        bestTotalCounts[k] = std::accumulate(bestCounts[k].begin(), bestCounts[k].end(), 0.0);
    }

    threshold = (features(bestSplit[0], bestSplitColumn) + features(bestSplit[1], bestSplitColumn)) / 2.0;
    ++bestSplit;

    counter.reset();
    std::for_each(indices, bestSplit, counter);
    pure[0] = 1.0 == std::accumulate(classCounts.begin(), classCounts.end(), 0.0, DecisionTreeDeprecCountNonzeroFunctor());
    counter.reset();
    std::for_each(bestSplit, indices+exampleCount, counter);
    pure[1] = 1.0 == std::accumulate(classCounts.begin(), classCounts.end(), 0.0, DecisionTreeDeprecCountNonzeroFunctor());

    return bestSplit;
}
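
// Note on the Gini computation above: for a candidate split, each side holds
// weighted class counts c_l with total t. The (count-scaled) Gini impurity is
//
//     t * (1 - sum_l (c_l/t)^2)  =  sum_l c_l * (1 - c_l/t),
//
// which is the multi-class expression used in the else-branch. For two
// classes this reduces to 2*c_0*c_1/t; the constant factor 2 is dropped in
// the code since it does not change the location of the minimum. The chosen
// split minimizes the sum of the two sides' impurities.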

enum { DecisionTreeDeprecNoParent = -1 };

template <class Iterator>
struct DecisionTreeDeprecStackEntry
{
    DecisionTreeDeprecStackEntry(Iterator i, int c,
                                 int lp = DecisionTreeDeprecNoParent, int rp = DecisionTreeDeprecNoParent)
    : indices(i), exampleCount(c),
      leftParent(lp), rightParent(rp)
    {}

    Iterator indices;
    int exampleCount, leftParent, rightParent;
};

class DecisionTreeDeprec
{
  public:

    typedef Int32 TreeInt;
    ArrayVector<TreeInt> tree_;
    ArrayVector<double> terminalWeights_;
    unsigned int classCount_;
    DecisionTreeDeprecAxisSplitFunctor split;

  public:

    DecisionTreeDeprec(unsigned int classCount)
    : classCount_(classCount)
    {}

    void reset(unsigned int classCount = 0)
    {
        if(classCount)
            classCount_ = classCount;
        tree_.clear();
        terminalWeights_.clear();
    }

    template <class U, class C, class Iterator, class Options, class Random>
    void learn(MultiArrayView<2, U, C> const & features,
               ArrayVector<int> const & labels,
               Iterator indices, int exampleCount,
               Options const & options,
               Random & randint);

    template <class U, class C>
    ArrayVector<double>::const_iterator
    predict(MultiArrayView<2, U, C> const & features) const
    {
        int nodeindex = 0;
        for(;;)
        {
            DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, nodeindex);
            nodeindex = split.decideAtNode(features, node.decisionColumns(),
                                           terminalWeights_.begin() + node.decisionWeightsIndex())
                            ? node.child(0)
                            : node.child(1);
            if(nodeindex <= 0)
                return terminalWeights_.begin() + (-nodeindex);
        }
    }

    template <class U, class C>
    int
    predictLabel(MultiArrayView<2, U, C> const & features) const
    {
        ArrayVector<double>::const_iterator weights = predict(features);
        return argMax(weights, weights+classCount_) - weights;
    }

    template <class U, class C>
    int
    leafID(MultiArrayView<2, U, C> const & features) const
    {
        int nodeindex = 0;
        for(;;)
        {
            DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, nodeindex);
            nodeindex = split.decideAtNode(features, node.decisionColumns(),
                                           terminalWeights_.begin() + node.decisionWeightsIndex())
                            ? node.child(0)
                            : node.child(1);
            if(nodeindex <= 0)
                return -nodeindex;
        }
    }

    void depth(int & maxDep, int & interiorCount, int & leafCount, int k = 0, int d = 1) const
    {
        DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, k);
        ++interiorCount;
        ++d;
        for(int l=0; l<2; ++l)
        {
            int child = node.child(l);
            if(child > 0)
                depth(maxDep, interiorCount, leafCount, child, d);
            else
            {
                ++leafCount;
                if(maxDep < d)
                    maxDep = d;
            }
        }
    }

    void printStatistics(std::ostream & o) const
    {
        int maxDep = 0, interiorCount = 0, leafCount = 0;
        depth(maxDep, interiorCount, leafCount);

        o << "interior nodes: " << interiorCount <<
             ", terminal nodes: " << leafCount <<
             ", depth: " << maxDep << "\n";
    }

    void print(std::ostream & o, int k = 0, std::string s = "") const
    {
        DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, k);
        o << s << (*node.decisionColumns()) << " " << terminalWeights_[node.decisionWeightsIndex()] << "\n";

        for(int l=0; l<2; ++l)
        {
            int child = node.child(l);
            if(child <= 0)
                o << s << " weights " << terminalWeights_[-child] << " "
                  << terminalWeights_[-child+1] << "\n";
            else
                print(o, child, s+" ");
        }
    }
};


template <class U, class C, class Iterator, class Options, class Random>
void DecisionTreeDeprec::learn(MultiArrayView<2, U, C> const & features,
                               ArrayVector<int> const & labels,
                               Iterator indices, int exampleCount,
                               Options const & options,
                               Random & randint)
{
    ArrayVector<double> const & classLoss = options.class_weights;

    vigra_precondition(classLoss.size() == 0 || classLoss.size() == classCount_,
        "DecisionTreeDeprec::learn(): class weights array has wrong size.");

    reset();

    unsigned int mtry = options.mtry;
    MultiArrayIndex cols = columnCount(features);

    split.init(mtry, cols, classCount_, classLoss);

    typedef DecisionTreeDeprecStackEntry<Iterator> Entry;
    ArrayVector<Entry> stack;
    stack.push_back(Entry(indices, exampleCount));

    while(!stack.empty())
    {
        indices = stack.back().indices;
        exampleCount = stack.back().exampleCount;
        int leftParent = stack.back().leftParent,
            rightParent = stack.back().rightParent;

        stack.pop_back();

        Iterator bestSplit = split.findBestSplit(features, labels, indices, exampleCount, randint);

        int currentNode = split.writeSplitParameters(tree_, terminalWeights_);

        if(leftParent != DecisionTreeDeprecNoParent)
            DecisionTreeDeprecNodeProxy<TreeInt>(tree_, leftParent).child(0) = currentNode;
        if(rightParent != DecisionTreeDeprecNoParent)
            DecisionTreeDeprecNodeProxy<TreeInt>(tree_, rightParent).child(1) = currentNode;
        leftParent = currentNode;
        rightParent = DecisionTreeDeprecNoParent;

        for(int l=0; l<2; ++l)
        {
            if(!split.isPure(l) && split.totalCount(l) >= options.min_split_node_size)
            {
                // sample is still large enough and not yet perfectly separated => split
                stack.push_back(Entry(indices, split.totalCount(l), leftParent, rightParent));
            }
            else
            {
                DecisionTreeDeprecNodeProxy<TreeInt>(tree_, currentNode).child(l) = -(TreeInt)terminalWeights_.size();

                split.writeWeights(l, terminalWeights_);
            }
            std::swap(leftParent, rightParent);
            indices = bestSplit;
        }
    }
}

} // namespace detail

class RandomForestOptionsDeprec
{
  public:

        /** Initialize all options with default values.
        */
    RandomForestOptionsDeprec()
    : training_set_proportion(1.0),
      mtry(0),
      min_split_node_size(1),
      training_set_size(0),
      sample_with_replacement(true),
      sample_classes_individually(false),
      treeCount(255)
    {}

        /** Number of features considered in each node.

            If \a n is 0 (the default), the number of features tried in every node
            is determined by the square root of the total number of features.
            According to Breiman, this quantity should always be optimized by means
            of the out-of-bag error.<br>
            Default: 0 (use <tt>sqrt(columnCount(featureMatrix))</tt>)
        */
    RandomForestOptionsDeprec & featuresPerNode(unsigned int n)
    {
        mtry = n;
        return *this;
    }
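
    // Example: with 100 feature columns and the default n = 0, each node
    // examines floor(sqrt(100) + 0.5) = 10 randomly chosen features
    // (this is the rounding used in RandomForestDeprec::learn() below).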

        /** How to sample the subset of the training data for each tree.

            Each tree is only trained with a subset of the entire training data.
            If \a r is <tt>true</tt>, this subset is sampled from the entire training set with
            replacement.<br>
            Default: <tt>true</tt> (use sampling with replacement)
        */
    RandomForestOptionsDeprec & sampleWithReplacement(bool r)
    {
        sample_with_replacement = r;
        return *this;
    }

        /** Number of trees in the forest.<br>
            Default: 255
        */
    RandomForestOptionsDeprec & setTreeCount(unsigned int cnt)
    {
        treeCount = cnt;
        return *this;
    }

        /** Proportion of training examples used for each tree.

            If \a p is 1.0 (the default) and samples are drawn with replacement,
            the training set of each tree will contain as many examples as the entire
            training set, but some are drawn multiple times and others not at all.
            On average, each tree is then actually trained on about 63% of the
            distinct examples in the full training set (see the note below).
            Changing the proportion mainly makes sense when sampleWithReplacement()
            is set to <tt>false</tt>. trainingSetSizeProportional() is overridden by
            trainingSetSizeAbsolute().<br>
            Default: 1.0
        */
    RandomForestOptionsDeprec & trainingSetSizeProportional(double p)
    {
        vigra_precondition(p >= 0.0 && p <= 1.0,
            "RandomForestOptionsDeprec::trainingSetSizeProportional(): proportion must be in [0, 1].");
        if(training_set_size == 0)  // otherwise, the absolute size gets priority
            training_set_proportion = p;
        return *this;
    }
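
    // Why about 63%: when m examples are drawn m times with replacement, the
    // probability that a given example is never drawn is (1 - 1/m)^m, which
    // approaches 1/e ~ 0.368 for large m. Hence roughly 1 - 1/e ~ 63.2% of the
    // distinct examples appear in each tree's training set; the remainder
    // serve as that tree's out-of-bag (OOB) sample.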

        /** Size of the training set for each tree.

            If this option is set, it overrides the proportion set by
            trainingSetSizeProportional(). When classes are sampled individually,
            the number of examples is divided by the number of classes (rounded upwards)
            to determine the number of examples drawn from every class.<br>
            Default: <tt>0</tt> (determine size by proportion)
        */
    RandomForestOptionsDeprec & trainingSetSizeAbsolute(unsigned int s)
    {
        training_set_size = s;
        if(s > 0)
            training_set_proportion = 0.0;
        return *this;
    }

        /** Are the classes sampled individually?

            If \a s is <tt>false</tt> (the default), the training set for each tree is sampled
            without considering class labels. Otherwise, samples are drawn from each
            class independently. The latter is especially useful in connection
            with the specification of an absolute training set size: then, the same number of
            examples is drawn from every class. This can be used as a counter-measure when the
            classes are very unbalanced in size.<br>
            Default: <tt>false</tt>
        */
    RandomForestOptionsDeprec & sampleClassesIndividually(bool s)
    {
        sample_classes_individually = s;
        return *this;
    }

        /** Number of examples required for a node to be split.

            When the number of examples in a node is below this number, the node is not
            split even if class separation is not yet perfect. Instead, the node returns
            the proportion of each class (among the remaining examples) during the
            prediction phase.<br>
            Default: 1 (complete growing)
        */
    RandomForestOptionsDeprec & minSplitNodeSize(unsigned int n)
    {
        if(n == 0)
            n = 1;
        min_split_node_size = n;
        return *this;
    }

        /** Use a weighted random forest.

            This is usually used to penalize the errors for the minority class.
            Weights must be convertible to <tt>double</tt>, and the array of weights
            must contain as many entries as there are classes.<br>
            Default: do not use weights
        */
    template <class WeightIterator>
    RandomForestOptionsDeprec & weights(WeightIterator weights, unsigned int classCount)
    {
        class_weights.clear();
        if(weights != 0)
            class_weights.insert(weights, classCount);
        return *this;
    }

    RandomForestOptionsDeprec & oobData(MultiArrayView<2, UInt8> & data)
    {
        oob_data = data;
        return *this;
    }

    MultiArrayView<2, UInt8> oob_data;
    ArrayVector<double> class_weights;
    double training_set_proportion;
    unsigned int mtry, min_split_node_size, training_set_size;
    bool sample_with_replacement, sample_classes_individually;
    unsigned int treeCount;
};
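
// Illustrative option setup (a sketch, not part of the original header):
// all setters return *this, so options can be chained, e.g.
//
//     RandomForestOptionsDeprec options = RandomForestOptionsDeprec()
//         .setTreeCount(100)
//         .featuresPerNode(10)
//         .sampleWithReplacement(false)
//         .trainingSetSizeProportional(0.8)
//         .sampleClassesIndividually(true);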

/*****************************************************************/
/*                                                               */
/*                      RandomForestDeprec                       */
/*                                                               */
/*****************************************************************/

template <class ClassLabelType>
class RandomForestDeprec
{
  public:

    ArrayVector<ClassLabelType> classes_;
    ArrayVector<detail::DecisionTreeDeprec> trees_;
    MultiArrayIndex columnCount_;
    RandomForestOptionsDeprec options_;

  public:

    // The first two constructors are straightforward: they take either
    // iterators over an array of class labels, or the two label values directly.
    template<class ClassLabelIterator>
    RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
                       unsigned int treeCount = 255,
                       RandomForestOptionsDeprec const & options = RandomForestOptionsDeprec())
    : classes_(cl, cend),
      trees_(treeCount, detail::DecisionTreeDeprec(classes_.size())),
      columnCount_(0),
      options_(options)
    {
        vigra_precondition(options.training_set_proportion == 0.0 ||
                           options.training_set_size == 0,
            "RandomForestOptionsDeprec: absolute and proportional training set sizes "
            "cannot be specified at the same time.");
        vigra_precondition(classes_.size() > 1,
            "RandomForestOptionsDeprec::weights(): need at least two classes.");
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == classes_.size(),
            "RandomForestOptionsDeprec::weights(): wrong number of classes.");
    }

    RandomForestDeprec(ClassLabelType const & c1, ClassLabelType const & c2,
                       unsigned int treeCount = 255,
                       RandomForestOptionsDeprec const & options = RandomForestOptionsDeprec())
    : classes_(2),
      trees_(treeCount, detail::DecisionTreeDeprec(2)),
      columnCount_(0),
      options_(options)
    {
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == 2,
            "RandomForestOptionsDeprec::weights(): wrong number of classes.");
        classes_[0] = c1;
        classes_[1] = c2;
    }
    // This constructor exists especially for the CrossValidator class:
    // the tree count is taken from the options object.
    template<class ClassLabelIterator>
    RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
                       RandomForestOptionsDeprec const & options)
    : classes_(cl, cend),
      trees_(options.treeCount, detail::DecisionTreeDeprec(classes_.size())),
      columnCount_(0),
      options_(options)
    {
        vigra_precondition(options.training_set_proportion == 0.0 ||
                           options.training_set_size == 0,
            "RandomForestOptionsDeprec: absolute and proportional training set sizes "
            "cannot be specified at the same time.");
        vigra_precondition(classes_.size() > 1,
            "RandomForestOptionsDeprec::weights(): need at least two classes.");
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == classes_.size(),
            "RandomForestOptionsDeprec::weights(): wrong number of classes.");
    }

    // Reconstructs a forest from previously trained trees and their terminal
    // weights; the feature count is passed directly instead of an options object.
    template<class ClassLabelIterator, class TreeIterator, class WeightIterator>
    RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
                       unsigned int treeCount, unsigned int columnCount,
                       TreeIterator trees, WeightIterator weights)
    : classes_(cl, cend),
      trees_(treeCount, detail::DecisionTreeDeprec(classes_.size())),
      columnCount_(columnCount)
    {
        for(unsigned int k=0; k<treeCount; ++k, ++trees, ++weights)
        {
            trees_[k].tree_ = *trees;
            trees_[k].terminalWeights_ = *weights;
        }
    }

    int featureCount() const
    {
        vigra_precondition(columnCount_ > 0,
            "RandomForestDeprec::featureCount(): Random forest has not been trained yet.");
        return columnCount_;
    }

    int labelCount() const
    {
        return classes_.size();
    }

    int treeCount() const
    {
        return trees_.size();
    }

    // An empty class_weights array means: unweighted random forest.
    template <class U, class C, class Array, class Random>
    double learn(MultiArrayView<2, U, C> const & features, Array const & labels,
                 Random const & random);

    template <class U, class C, class Array>
    double learn(MultiArrayView<2, U, C> const & features, Array const & labels)
    {
        RandomNumberGenerator<> generator(RandomSeed);
        return learn(features, labels, generator);
    }

    template <class U, class C>
    ClassLabelType predictLabel(MultiArrayView<2, U, C> const & features) const;

    template <class U, class C1, class T, class C2>
    void predictLabels(MultiArrayView<2, U, C1> const & features,
                       MultiArrayView<2, T, C2> & labels) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForestDeprec::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
            labels(k,0) = predictLabel(rowVector(features, k));
    }

    template <class U, class C, class Iterator>
    ClassLabelType predictLabel(MultiArrayView<2, U, C> const & features,
                                Iterator priors) const;

    template <class U, class C1, class T, class C2>
    void predictProbabilities(MultiArrayView<2, U, C1> const & features,
                              MultiArrayView<2, T, C2> & prob) const;

    template <class U, class C1, class T, class C2>
    void predictNodes(MultiArrayView<2, U, C1> const & features,
                      MultiArrayView<2, T, C2> & NodeIDs) const;
};
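
// Illustrative end-to-end usage (a sketch; 'features' is assumed to be an
// n x p feature matrix and 'labels' an array of n class labels):
//
//     ArrayVector<int> classes;
//     classes.push_back(0);
//     classes.push_back(1);
//     RandomForestDeprec<int> rf(classes.begin(), classes.end(), 255);
//     double oobError = rf.learn(features, labels);   // returns OOB error estimate
//     Matrix<double> prob(rowCount(features), rf.labelCount());
//     rf.predictProbabilities(features, prob);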

template <class ClassLabelType>
template <class U, class C1, class Array, class Random>
double
RandomForestDeprec<ClassLabelType>::learn(MultiArrayView<2, U, C1> const & features,
                                          Array const & labels,
                                          Random const & random)
{
    unsigned int classCount = classes_.size();
    unsigned int m = rowCount(features);
    unsigned int n = columnCount(features);
    vigra_precondition((unsigned int)(m) == (unsigned int)labels.size(),
        "RandomForestDeprec::learn(): Label array has wrong size.");

    vigra_precondition(options_.training_set_size <= m || options_.sample_with_replacement,
        "RandomForestDeprec::learn(): Requested training set size exceeds total number of examples.");

    MultiArrayIndex mtry = (options_.mtry == 0)
                               ? int(std::floor(std::sqrt(double(n)) + 0.5))
                               : options_.mtry;

    vigra_precondition(mtry <= (MultiArrayIndex)n,
        "RandomForestDeprec::learn(): mtry must not exceed the number of features.");

    MultiArrayIndex msamples = options_.training_set_size;
    if(options_.sample_classes_individually)
        msamples = int(std::ceil(double(msamples) / classCount));

    ArrayVector<int> intLabels(m), classExampleCounts(classCount);

    // verify the input labels
    int minClassCount;
    {
        typedef std::map<ClassLabelType, int> LabelChecker;
        typedef typename LabelChecker::iterator LabelCheckerIterator;
        LabelChecker labelChecker;
        for(unsigned int k=0; k<classCount; ++k)
            labelChecker[classes_[k]] = k;

        for(unsigned int k=0; k<m; ++k)
        {
            LabelCheckerIterator found = labelChecker.find(labels[k]);
            vigra_precondition(found != labelChecker.end(),
                "RandomForestDeprec::learn(): Unknown class label encountered.");
            intLabels[k] = found->second;
            ++classExampleCounts[intLabels[k]];
        }
        minClassCount = *argMin(classExampleCounts.begin(), classExampleCounts.end());
        vigra_precondition(minClassCount > 0,
            "RandomForestDeprec::learn(): At least one class is missing in the training set.");
        if(msamples > 0 && options_.sample_classes_individually &&
           !options_.sample_with_replacement)
        {
            vigra_precondition(msamples <= minClassCount,
                "RandomForestDeprec::learn(): Too few examples in smallest class to reach "
                "requested training set size.");
        }
    }
    columnCount_ = n;
    ArrayVector<int> indices(m);
    for(unsigned int k=0; k<m; ++k)
        indices[k] = k;

    if(options_.sample_classes_individually)
    {
        detail::RandomForestDeprecLabelSorter<ArrayVector<int> > sorter(intLabels);
        std::sort(indices.begin(), indices.end(), sorter);
    }

    ArrayVector<int> usedIndices(m), oobCount(m), oobErrorCount(m);

    UniformIntRandomFunctor<Random> randint(random);

    for(unsigned int k=0; k<trees_.size(); ++k)
    {
        ArrayVector<int> trainingSet;
        usedIndices.init(0);

        if(options_.sample_classes_individually)
        {
            int first = 0;
            for(unsigned int l=0; l<classCount; ++l)
            {
                int lc = classExampleCounts[l];
                int lsamples = (msamples == 0)
                                   ? int(std::ceil(options_.training_set_proportion*lc))
                                   : msamples;

                if(options_.sample_with_replacement)
                {
                    for(int ll=0; ll<lsamples; ++ll)
                    {
                        trainingSet.push_back(indices[first+randint(lc)]);
                        ++usedIndices[trainingSet.back()];
                    }
                }
                else
                {
                    for(int ll=0; ll<lsamples; ++ll)
                    {
                        std::swap(indices[first+ll], indices[first+ll+randint(lc-ll)]);
                        trainingSet.push_back(indices[first+ll]);
                        ++usedIndices[trainingSet.back()];
                    }
                }
                first += lc;
            }
        }
        else
        {
            if(msamples == 0)
                msamples = int(std::ceil(options_.training_set_proportion*m));

            if(options_.sample_with_replacement)
            {
                for(int l=0; l<msamples; ++l)
                {
                    trainingSet.push_back(indices[randint(m)]);
                    ++usedIndices[trainingSet.back()];
                }
            }
            else
            {
                for(int l=0; l<msamples; ++l)
                {
                    std::swap(indices[l], indices[l+randint(m-l)]);
                    trainingSet.push_back(indices[l]);
                    ++usedIndices[trainingSet.back()];
                }
            }
        }
        trees_[k].learn(features, intLabels,
                        trainingSet.begin(), trainingSet.size(),
                        options_.featuresPerNode(mtry), randint);

        for(unsigned int l=0; l<m; ++l)
        {
            if(!usedIndices[l])
            {
                ++oobCount[l];
                if(trees_[k].predictLabel(rowVector(features, l)) != intLabels[l])
                {
                    ++oobErrorCount[l];
                    if(options_.oob_data.data() != 0)
                        options_.oob_data(l, k) = 2;
                }
                else if(options_.oob_data.data() != 0)
                {
                    options_.oob_data(l, k) = 1;
                }
            }
        }
        // TODO: default value for oob_data
        // TODO: implement variable importance
#ifdef VIGRA_RF_VERBOSE
        trees_[k].printStatistics(std::cerr);
#endif
    }
    double oobError = 0.0;
    int totalOobCount = 0;
    for(unsigned int l=0; l<m; ++l)
        if(oobCount[l])
        {
            oobError += double(oobErrorCount[l]) / oobCount[l];
            ++totalOobCount;
        }
    return oobError / totalOobCount;
}
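
// The return value of learn() is the average per-example out-of-bag error:
// for every example l that was left out of at least one tree's sample,
// oobErrorCount[l]/oobCount[l] estimates the misclassification rate of that
// example over the trees that did not train on it, and these rates are
// averaged over all such examples.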

template <class ClassLabelType>
template <class U, class C>
ClassLabelType
RandomForestDeprec<ClassLabelType>::predictLabel(MultiArrayView<2, U, C> const & features) const
{
    vigra_precondition(columnCount(features) >= featureCount(),
        "RandomForestDeprec::predictLabel(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForestDeprec::predictLabel(): Feature matrix must have a single row.");
    Matrix<double> prob(1, classes_.size());
    predictProbabilities(features, prob);
    return classes_[argMax(prob)];
}

// Same as above, but the vote for each label is multiplied by its prior.
template <class ClassLabelType>
template <class U, class C, class Iterator>
ClassLabelType
RandomForestDeprec<ClassLabelType>::predictLabel(MultiArrayView<2, U, C> const & features,
                                                 Iterator priors) const
{
    using namespace functor;
    vigra_precondition(columnCount(features) >= featureCount(),
        "RandomForestDeprec::predictLabel(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForestDeprec::predictLabel(): Feature matrix must have a single row.");
    Matrix<double> prob(1, classes_.size());
    predictProbabilities(features, prob);
    std::transform(prob.begin(), prob.end(), priors, prob.begin(), Arg1()*Arg2());
    return classes_[argMax(prob)];
}

template <class ClassLabelType>
template <class U, class C1, class T, class C2>
void
RandomForestDeprec<ClassLabelType>::predictProbabilities(MultiArrayView<2, U, C1> const & features,
                                                         MultiArrayView<2, T, C2> & prob) const
{
    // features is n x p; prob is n x labelCount() and receives, for each row,
    // the probability of that feature vector belonging to each class
    vigra_precondition(rowCount(features) == rowCount(prob),
        "RandomForestDeprec::predictProbabilities(): Feature matrix and probability matrix size mismatch.");

    // the feature matrix may have extra columns, but must at least cover
    // all features that were used for training
    vigra_precondition(columnCount(features) >= featureCount(),
        "RandomForestDeprec::predictProbabilities(): Too few columns in feature matrix.");
    vigra_precondition(columnCount(prob) == (MultiArrayIndex)labelCount(),
        "RandomForestDeprec::predictProbabilities(): Probability matrix must have as many columns as there are classes.");

    // classify each row
    for(int row=0; row < rowCount(features); ++row)
    {
        // class weights returned by a single tree; for an unweighted forest
        // these are the class proportions in the leaf the row falls into
        ArrayVector<double>::const_iterator weights;

        // sum of all votes cast for this row
        double totalWeight = 0.0;

        // initialize the vote counts; prob(row, l) accumulates votes until
        // the final normalisation
        for(unsigned int l=0; l<classes_.size(); ++l)
            prob(row, l) = 0.0;

        // let each tree cast its votes...
        for(unsigned int k=0; k<trees_.size(); ++k)
        {
            // get the weights predicted by this tree
            weights = trees_[k].predict(rowVector(features, row));

            // update the vote counts
            for(unsigned int l=0; l<classes_.size(); ++l)
            {
                prob(row, l) += detail::RequiresExplicitCast<T>::cast(weights[l]);
                totalWeight += weights[l];
            }
        }

        // normalise the votes in this row by the total vote count
        for(unsigned int l=0; l<classes_.size(); ++l)
            prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
    }
}
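
// In formula form: with w_kl denoting the weight tree k assigns to class l
// for a given row, the returned probabilities are
//
//     prob(row, l) = ( sum_k w_kl ) / ( sum_k sum_j w_kj ),
//
// i.e. the trees' leaf weights are pooled and normalised to sum to one
// across classes.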


template <class ClassLabelType>
template <class U, class C1, class T, class C2>
void
RandomForestDeprec<ClassLabelType>::predictNodes(MultiArrayView<2, U, C1> const & features,
                                                 MultiArrayView<2, T, C2> & NodeIDs) const
{
    vigra_precondition(columnCount(features) >= featureCount(),
        "RandomForestDeprec::predictNodes(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) <= rowCount(NodeIDs),
        "RandomForestDeprec::predictNodes(): Too few rows in NodeIDs matrix.");
    vigra_precondition(columnCount(NodeIDs) >= treeCount(),
        "RandomForestDeprec::predictNodes(): Too few columns in NodeIDs matrix.");
    NodeIDs.init(0);
    for(unsigned int k=0; k<trees_.size(); ++k)
    {
        for(int row=0; row < rowCount(features); ++row)
        {
            NodeIDs(row, k) = trees_[k].leafID(rowVector(features, row));
        }
    }
}

//@}

} // namespace vigra


#endif // VIGRA_RANDOM_FOREST_DEPREC_HXX