jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
linear_function_mixer.cpp
Go to the documentation of this file.
1 // Jubatus: Online machine learning framework for distributed environment
2 // Copyright (C) 2013 Preferred Networks and Nippon Telegraph and Telephone Corporation.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License version 2.1 as published by the Free Software Foundation.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 
18 
19 #include <string>
20 #include <algorithm>
21 
22 #include "jubatus/util/lang/bind.h"
23 
24 using std::string;
25 using jubatus::util::lang::bind;
26 using jubatus::util::lang::_1;
27 using jubatus::util::lang::_2;
28 
32 
33 namespace jubatus {
34 namespace core {
35 namespace framework {
36 
37 namespace {
38 
39 val3_t mix_val3(double w1, double w2, const val3_t& lhs, const val3_t& rhs) {
40  return val3_t(
41  (w1 * lhs.v1 + w2 * rhs.v1) / (w1 + w2),
42  std::min(lhs.v2, rhs.v2),
43  (w1 * lhs.v3 + w2 * rhs.v3) / (w1 + w2));
44 }
45 
46 feature_val3_t mix_feature(
47  double w1,
48  double w2,
49  const feature_val3_t& lhs,
50  const feature_val3_t& rhs) {
51  val3_t def(0, 1, 0);
52  feature_val3_t ret(lhs);
53  storage::detail::binop(ret, rhs, bind(mix_val3, w1, w2, _1, _2), def);
54  return ret;
55 }
56 
57 struct internal_diff_object : diff_object_raw {
58  void convert_binary(packer& pk) const {
59  pk.pack(diff_);
60  }
61 
62  diffv diff_;
63 };
64 
65 } // namespace
66 
67 void linear_function_mixer::mix(const diffv& lhs, diffv& mixed) const {
68  if (lhs.v.expect_version == mixed.v.expect_version) {
69  features3_t l(lhs.v.diff);
70  const features3_t& r(mixed.v.diff);
72  l,
73  r,
74  bind(mix_feature, lhs.count, mixed.count, _1, _2));
75  mixed.v.diff.swap(l);
76  mixed.count = lhs.count + mixed.count;
77  } else if (lhs.v.expect_version > mixed.v.expect_version) {
78  mixed = lhs;
79  }
80 }
81 
83  diff.count = 1; // TODO(kuenishi) mixer_->get_count();
84  get_model()->get_diff(diff.v);
85 }
86 
88  if (label_unlearner_) {
89  for (size_t i = 0; i < v.v.diff.size(); ++i) {
90  const feature_val3_t& classes = v.v.diff[i].second;
91  for (size_t j = 0; j < classes.size(); ++j) {
92  // ignore error returned by touch
93  label_unlearner_->touch(classes[j].first);
94  }
95  }
96 
97  features3_t parameters(v.v.diff.size());
98  for (size_t i = 0; i < v.v.diff.size(); ++i) {
99  parameters[i].first = v.v.diff[i].first;
100 
101  // Copy weights of classes except unlearned classes.
102  const feature_val3_t& source_classes = v.v.diff[i].second;
103  feature_val3_t& target_classes = parameters[i].second;
104 
105  target_classes.reserve(source_classes.size());
106  for (size_t j = 0; j < source_classes.size(); ++j) {
107  if (label_unlearner_->exists_in_memory(source_classes[j].first)) {
108  target_classes.push_back(source_classes[j]);
109  }
110  }
111  }
112 
113  storage::diff_t unlearned_diff;
114  std::swap(unlearned_diff.diff, parameters);
115  unlearned_diff.expect_version = v.v.expect_version;
116 
117  return get_model()->set_average_and_clear_diff(unlearned_diff);
118  } else {
119  return get_model()->set_average_and_clear_diff(v.v);
120  }
121 }
122 
124  const msgpack::object& obj) const {
125  internal_diff_object* diff = new internal_diff_object;
126  diff_object diff_obj(diff);
127  obj.convert(&diff->diff_);
128  return diff_obj;
129 }
130 
132  const msgpack::object& obj,
133  diff_object ptr) const {
134  diffv diff;
135  internal_diff_object* diff_obj =
136  dynamic_cast<internal_diff_object*>(ptr.get());
137  if (!diff_obj) {
138  throw JUBATUS_EXCEPTION(
139  core::common::exception::runtime_error("bad diff_object"));
140  }
141  obj.convert(&diff);
142  mix(diff, diff_obj->diff_);
143 }
144 
146  diffv diff;
147  get_diff(diff);
148  pk.pack(diff);
149 }
150 
152  internal_diff_object* diff_obj =
153  dynamic_cast<internal_diff_object*>(ptr.get());
154  if (!diff_obj) {
155  throw JUBATUS_EXCEPTION(
156  core::common::exception::runtime_error("bad diff_object"));
157  }
158  return put_diff(diff_obj->diff_);
159 }
160 
161 } // namespace framework
162 } // namespace core
163 } // namespace jubatus
void mix(const diffv &lhs, diffv &mixed) const
std::vector< std::pair< std::string, feature_val3_t > > features3_t
jubatus::util::lang::shared_ptr< diff_object_raw > diff_object
diffv diff_
diff_object convert_diff_object(const msgpack::object &) const
#define JUBATUS_EXCEPTION(e)
Definition: exception.hpp:79
void swap(weighted_point &p1, weighted_point &p2)
Definition: types.hpp:47
msgpack::packer< jubatus_packer > packer
Definition: bandit_base.hpp:31
std::vector< T > v(size)
std::vector< std::pair< std::string, val3_t > > feature_val3_t
jubatus::util::lang::shared_ptr< unlearner::unlearner_base > label_unlearner_
std::vector< std::pair< std::string, E > > & binop(std::vector< std::pair< std::string, E > > &lhs, std::vector< std::pair< std::string, E > > rhs, F f, E default_value=E())