jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
normal_herd.cpp
Go to the documentation of this file.
1 // Jubatus: Online machine learning framework for distributed environment
2 // Copyright (C) 2011 Preferred Networks and Nippon Telegraph and Telephone Corporation.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License version 2.1 as published by the Free Software Foundation.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 
17 #include "normal_herd.hpp"
18 
19 #include <algorithm>
20 #include <cmath>
21 #include <string>
22 
23 #include "jubatus/util/concurrent/lock.h"
24 #include "classifier_util.hpp"
25 #include "../common/exception.hpp"
26 
27 using std::string;
28 
29 namespace jubatus {
30 namespace core {
31 namespace classifier {
32 
// NOTE(review): this listing is missing the constructor's signature line
// (source line 33) and one body line (source line 35); presumably the
// signature is normal_herd::normal_herd(storage_ptr storage). Only the
// base-class initializer from `storage` is visible here, so whatever line 35
// does to the object (e.g. to config_) cannot be confirmed from this view.
34  : linear_classifier(storage) {
36 }
37 
39  const classifier_config& config,
40  storage_ptr storage)
41  : linear_classifier(storage),
42  config_(config) {
43 
44  if (!(0.f < config.regularization_weight)) {
45  throw JUBATUS_EXCEPTION(
46  common::invalid_parameter("0.0 < regularization_weight"));
47  }
48 }
49 
50 void normal_herd::train(const common::sfv_t& sfv, const string& label) {
51  check_touchable(label);
52 
53  string incorrect_label;
54  float variance = 0.f;
55  float margin = -calc_margin_and_variance(sfv, label, incorrect_label,
56  variance);
57  if (margin >= 1.f) {
58  storage_->register_label(label);
59  return;
60  }
61  update(sfv, margin, variance, label, incorrect_label);
62 }
63 
65  const common::sfv_t& sfv,
66  float margin,
67  float variance,
68  const string& pos_label,
69  const string& neg_label) {
70  util::concurrent::scoped_lock lk(storage_->get_lock());
71  for (common::sfv_t::const_iterator it = sfv.begin(); it != sfv.end(); ++it) {
72  const string& feature = it->first;
73  float val = it->second;
75  storage_->get2_nolock(feature, ret);
76 
77  storage::val2_t pos_val(0.f, 1.f);
78  storage::val2_t neg_val(0.f, 1.f);
79  ClassifierUtil::get_two(ret, pos_label, neg_label, pos_val, neg_val);
80 
81  float val_covariance_pos = val * pos_val.v2;
82  float val_covariance_neg = val * neg_val.v2;
83 
84  const float C = config_.regularization_weight;
85  storage_->set2_nolock(
86  feature,
87  pos_label,
89  pos_val.v1
90  + (1.f - margin) * val_covariance_pos
91  / (variance + 1.f / C),
92  1.f
93  / ((1.f / pos_val.v2) + (2 * C + C * C * variance)
94  * val * val)));
95  if (neg_label != "") {
96  storage_->set2_nolock(
97  feature,
98  neg_label,
100  neg_val.v1
101  - (1.f - margin) * val_covariance_neg
102  / (variance + 1.f / C),
103  1.f
104  / ((1.f / neg_val.v2) + (2 * C + C * C * variance)
105  * val * val)));
106  }
107  }
108  touch(pos_label);
109 }
110 
111 std::string normal_herd::name() const {
112  return string("normal_herd");
113 }
114 
115 } // namespace classifier
116 } // namespace core
117 } // namespace jubatus
jubatus::util::lang::shared_ptr< jubatus::core::storage::storage_base > storage_ptr
void update(const common::sfv_t &sfv, float margin, float variance, const std::string &pos_label, const std::string &neg_label)
Definition: normal_herd.cpp:64
static void get_two(const T &t, const std::string &label1, const std::string &label2, U &u1, U &u2)
float calc_margin_and_variance(const common::sfv_t &sfv, const std::string &label, std::string &incorrect_label, float &variance) const
#define JUBATUS_EXCEPTION(e)
Definition: exception.hpp:79
void check_touchable(const std::string &label)
std::vector< std::pair< std::string, float > > sfv_t
Definition: type.hpp:29
void train(const common::sfv_t &fv, const std::string &label)
Definition: normal_herd.cpp:50
std::vector< std::pair< std::string, val2_t > > feature_val2_t