jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
confidence_weighted.cpp
Go to the documentation of this file.
1 // Jubatus: Online machine learning framework for distributed environment
2 // Copyright (C) 2011 Preferred Networks and Nippon Telegraph and Telephone Corporation.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License version 2.1 as published by the Free Software Foundation.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 
17 #include "confidence_weighted.hpp"
18 
19 #include <algorithm>
20 #include <cmath>
21 #include <string>
22 
23 #include "jubatus/util/concurrent/lock.h"
24 #include "classifier_util.hpp"
25 #include "../common/exception.hpp"
26 
27 using std::string;
28 
29 namespace jubatus {
30 namespace core {
31 namespace classifier {
32 
34  : linear_classifier(storage) {
35 }
36 
38  const classifier_config& config,
39  storage_ptr storage)
40  : linear_classifier(storage),
41  config_(config) {
42 
43  if (!(0.f < config.regularization_weight)) {
44  throw JUBATUS_EXCEPTION(
45  common::invalid_parameter("0.0 < regularization_weight"));
46  }
47 }
48 
49 void confidence_weighted::train(const common::sfv_t& sfv, const string& label) {
50  check_touchable(label);
51 
52  const float C = config_.regularization_weight;
53  string incorrect_label;
54  float variance = 0.f;
55  float margin = -calc_margin_and_variance(sfv, label, incorrect_label,
56  variance);
57  float b = 1.f + 2 * C * margin;
58  float gamma = -b + std::sqrt(b * b - 8 * C * (margin - C * variance));
59 
60  if (gamma <= 0.f) {
61  storage_->register_label(label);
62  return;
63  }
64  gamma /= 4 * C * variance;
65  update(sfv, gamma, label, incorrect_label);
66 }
67 
69  const common::sfv_t& sfv,
70  float step_width,
71  const string& pos_label,
72  const string& neg_label) {
73  util::concurrent::scoped_lock lk(storage_->get_lock());
74  for (common::sfv_t::const_iterator it = sfv.begin(); it != sfv.end(); ++it) {
75  const string& feature = it->first;
76  float val = it->second;
78  storage_->get2_nolock(feature, val2);
79 
80  storage::val2_t pos_val(0.f, 1.f);
81  storage::val2_t neg_val(0.f, 1.f);
82  ClassifierUtil::get_two(val2, pos_label, neg_label, pos_val, neg_val);
83 
84  const float C = config_.regularization_weight;
85  float covar_pos_step = 2.f * step_width * val * val * C;
86  float covar_neg_step = 2.f * step_width * val * val * C;
87 
88  storage_->set2_nolock(
89  feature,
90  pos_label,
91  storage::val2_t(pos_val.v1 + step_width * pos_val.v2 * val,
92  1.f / (1.f / pos_val.v2 + covar_pos_step)));
93  if (neg_label != "") {
94  storage_->set2_nolock(
95  feature,
96  neg_label,
97  storage::val2_t(neg_val.v1 - step_width * neg_val.v2 * val,
98  1.f / (1.f / neg_val.v2 + covar_neg_step)));
99  }
100  }
101  touch(pos_label);
102 }
103 
105  return string("confidence_weighted");
106 }
107 
108 } // namespace classifier
109 } // namespace core
110 } // namespace jubatus
jubatus::util::lang::shared_ptr< jubatus::core::storage::storage_base > storage_ptr
void train(const common::sfv_t &fv, const std::string &label)
static void get_two(const T &t, const std::string &label1, const std::string &label2, U &u1, U &u2)
void update(const common::sfv_t &fv, float step_weigth, const std::string &pos_label, const std::string &neg_label)
float calc_margin_and_variance(const common::sfv_t &sfv, const std::string &label, std::string &incorrect_label, float &variance) const
#define JUBATUS_EXCEPTION(e)
Definition: exception.hpp:79
void check_touchable(const std::string &label)
std::vector< std::pair< std::string, float > > sfv_t
Definition: type.hpp:29
std::vector< std::pair< std::string, val2_t > > feature_val2_t