jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
arow.cpp
Go to the documentation of this file.
1 // Jubatus: Online machine learning framework for distributed environment
2 // Copyright (C) 2011 Preferred Networks and Nippon Telegraph and Telephone Corporation.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License version 2.1 as published by the Free Software Foundation.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 
17 #include "arow.hpp"
18 
19 #include <algorithm>
20 #include <cmath>
21 #include <string>
22 
23 #include "jubatus/util/concurrent/lock.h"
24 #include "classifier_util.hpp"
25 #include "../common/exception.hpp"
26 
27 using std::string;
28 
29 namespace jubatus {
30 namespace core {
31 namespace classifier {
32 
34  : linear_classifier(storage) {
35 }
36 
38  const classifier_config& config,
39  storage_ptr storage)
40  : linear_classifier(storage),
41  config_(config) {
42 
43  if (!(0.f < config.regularization_weight)) {
44  throw JUBATUS_EXCEPTION(
45  common::invalid_parameter("0.0 < regularization_weight"));
46  }
47 }
48 
49 void arow::train(const common::sfv_t& sfv, const string& label) {
50  check_touchable(label);
51 
52  string incorrect_label;
53  float variance = 0.f;
54  float margin = -calc_margin_and_variance(sfv, label, incorrect_label,
55  variance);
56  if (margin >= 1.f) {
57  storage_->register_label(label);
58  return;
59  }
60 
61  float beta = 1.f / (variance + 1.f / config_.regularization_weight);
62  float alpha = (1.f - margin) * beta; // max(0, 1 - margin) = 1 - margin
63  update(sfv, alpha, beta, label, incorrect_label);
64 }
65 
67  const common::sfv_t& sfv,
68  float alpha,
69  float beta,
70  const std::string& pos_label,
71  const std::string& neg_label) {
72  util::concurrent::scoped_lock lk(storage_->get_lock());
73  for (common::sfv_t::const_iterator it = sfv.begin(); it != sfv.end(); ++it) {
74  const string& feature = it->first;
75  float val = it->second;
77  storage_->get2_nolock(feature, ret);
78 
79  storage::val2_t pos_val(0.f, 1.f);
80  storage::val2_t neg_val(0.f, 1.f);
81  ClassifierUtil::get_two(ret, pos_label, neg_label, pos_val, neg_val);
82 
83  storage_->set2_nolock(
84  feature,
85  pos_label,
87  pos_val.v1 + alpha * pos_val.v2 * val,
88  pos_val.v2 - beta * pos_val.v2 * pos_val.v2 * val * val));
89  if (neg_label != "") {
90  storage_->set2_nolock(
91  feature,
92  neg_label,
94  neg_val.v1 - alpha * neg_val.v2 * val,
95  neg_val.v2 - beta * neg_val.v2 * neg_val.v2 * val * val));
96  }
97  }
98  touch(pos_label);
99 }
100 
101 string arow::name() const {
102  return string("arow");
103 }
104 
105 } // namespace classifier
106 } // namespace core
107 } // namespace jubatus
jubatus::util::lang::shared_ptr< jubatus::core::storage::storage_base > storage_ptr
static void get_two(const T &t, const std::string &label1, const std::string &label2, U &u1, U &u2)
arow(storage_ptr storage)
Definition: arow.cpp:33
float calc_margin_and_variance(const common::sfv_t &sfv, const std::string &label, std::string &incorrect_label, float &variance) const
void train(const common::sfv_t &fv, const std::string &label)
Definition: arow.cpp:49
#define JUBATUS_EXCEPTION(e)
Definition: exception.hpp:79
void update(const common::sfv_t &fv, float alpha, float beta, const std::string &pos_label, const std::string &neg_label)
Definition: arow.cpp:66
void check_touchable(const std::string &label)
classifier_config config_
Definition: arow.hpp:41
std::vector< std::pair< std::string, float > > sfv_t
Definition: type.hpp:29
std::vector< std::pair< std::string, val2_t > > feature_val2_t
std::string name() const
Definition: arow.cpp:101