jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
driver.cpp
Go to the documentation of this file.
1 // Jubatus: Online machine learning framework for distributed environment
2 // Copyright (C) 2014 Preferred Networks and Nippon Telegraph and Telephone Corporation.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License version 2.1 as published by the Free Software Foundation.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 
17 #include "driver.hpp"
18 #include <algorithm>
19 #include <numeric>
20 #include <string>
21 #include <set>
22 #include <vector>
23 
24 using std::find;
25 using std::set;
26 using std::string;
27 using std::vector;
34 
35 namespace jubatus {
36 namespace core {
37 namespace driver {
38 
39 namespace {
40 
41 struct internal_diff_object : diff_object_raw {
42  explicit internal_diff_object(const vector<diff_object>& diffs)
43  : diffs_(diffs) {
44  }
45 
46  void convert_binary(packer& pk) const {
47  pk.pack_array(diffs_.size());
48  for (size_t i = 0; i < diffs_.size(); i++) {
49  diffs_[i]->convert_binary(pk);
50  }
51  }
52 
53  vector<diff_object> diffs_;
54 };
55 
56 template <class T>
57 size_t count_mixable(const vector<mixable*>& mixables) {
58  size_t count = 0;
59  for (size_t i = 0; i < mixables.size(); i++) {
60  if (dynamic_cast<T*>(mixables[i])) {
61  count++;
62  }
63  }
64 
65  return count;
66 }
67 
68 } // namespace
69 
71  set<string> s;
72 
73  if (count_mixable<linear_mixable>(mixables_) > 0) {
74  s.insert("linear_mixable");
75  }
76 
77  if (count_mixable<push_mixable>(mixables_) > 0) {
78  s.insert("push_mixable");
79  }
80 
81  return s;
82 }
83 
85  if (find(mixables_.begin(), mixables_.end(), mixable) == mixables_.end()) {
86  mixables_.push_back(mixable);
87  }
88 }
89 
91  const msgpack::object& o) const {
92  if (o.type != msgpack::type::ARRAY ||
93  o.via.array.size != count_mixable<linear_mixable>(mixables_)) {
94  throw JUBATUS_EXCEPTION(
95  core::common::exception::runtime_error("conversion failed"));
96  }
97 
98  vector<diff_object> diffs;
99  for (size_t i = 0; i < mixables_.size(); i++) {
100  const linear_mixable* mixable =
101  dynamic_cast<const linear_mixable*>(mixables_[i]);
102  if (!mixable) {
103  continue;
104  }
105  diffs.push_back(mixable->convert_diff_object(o.via.array.ptr[i]));
106  }
107 
108  return diff_object(new internal_diff_object(diffs));
109 }
110 
112  const msgpack::object& o,
113  diff_object ptr) const {
114  if (o.type != msgpack::type::ARRAY ||
115  o.via.array.size != count_mixable<linear_mixable>(mixables_)) {
116  throw JUBATUS_EXCEPTION(
118  }
119 
120  internal_diff_object* diff_obj =
121  dynamic_cast<internal_diff_object*>(ptr.get());
122  if (!diff_obj) {
123  throw JUBATUS_EXCEPTION(
124  core::common::exception::runtime_error("bad diff_object"));
125  }
126 
127  if (mixables_.size() != diff_obj->diffs_.size()) {
128  throw JUBATUS_EXCEPTION(
129  core::common::exception::runtime_error("diff size is wrong"));
130  }
131 
132  for (size_t i = 0; i < mixables_.size(); i++) {
133  const linear_mixable* mixable =
134  dynamic_cast<const linear_mixable*>(mixables_[i]);
135  if (!mixable) {
136  continue;
137  }
138  mixable->mix(o.via.array.ptr[i], diff_obj->diffs_[i]);
139  }
140 }
141 
143  pk.pack_array(count_mixable<linear_mixable>(mixables_));
144  for (size_t i = 0; i < mixables_.size(); i++) {
145  const linear_mixable* mixable =
146  dynamic_cast<const linear_mixable*>(mixables_[i]);
147  if (!mixable) {
148  continue;
149  }
150  mixable->get_diff(pk);
151  }
152 }
153 
155  internal_diff_object* diff_obj =
156  dynamic_cast<internal_diff_object*>(obj.get());
157  if (!diff_obj) {
158  throw JUBATUS_EXCEPTION(
159  core::common::exception::runtime_error("bad diff_object"));
160  }
161 
162  if (count_mixable<linear_mixable>(mixables_) != diff_obj->diffs_.size()) {
163  throw JUBATUS_EXCEPTION(
164  core::common::exception::runtime_error("diff size is wrong"));
165  }
166 
167  bool success = true;
168  for (size_t i = 0; i < mixables_.size(); i++) {
169  linear_mixable* mixable = dynamic_cast<linear_mixable*>(mixables_[i]);
170  if (!mixable) {
171  continue;
172  }
173  success = mixable->put_diff(diff_obj->diffs_[i]) && success;
174  }
175 
176  return success;
177 }
178 
180  pk.pack_array(count_mixable<push_mixable>(mixables_));
181  for (size_t i = 0; i < mixables_.size(); i++) {
182  const push_mixable* mixable =
183  dynamic_cast<const push_mixable*>(mixables_[i]);
184  if (!mixable) {
185  continue;
186  }
187  mixable->get_argument(pk);
188  }
189 }
190 
192  const msgpack::object& arg,
193  packer& pk) const {
194  if (arg.type != msgpack::type::ARRAY ||
195  arg.via.array.size != count_mixable<push_mixable>(mixables_)) {
196  throw JUBATUS_EXCEPTION(
197  core::common::exception::runtime_error("pull array failed"));
198  }
199 
200  pk.pack_array(count_mixable<push_mixable>(mixables_));
201  for (size_t i = 0, obj_index = 0; i < mixables_.size(); i++) {
202  const push_mixable* mixable =
203  dynamic_cast<const push_mixable*>(mixables_[i]);
204  if (!mixable) {
205  continue;
206  }
207  mixable->pull(arg.via.array.ptr[obj_index], pk);
208  obj_index++;
209  }
210 }
211 
212 void driver_base::mixable_holder::push(const msgpack::object& o) {
213  if (o.type != msgpack::type::ARRAY ||
214  o.via.array.size != count_mixable<push_mixable>(mixables_)) {
215  throw JUBATUS_EXCEPTION(
217  }
218 
219  for (size_t i = 0, obj_index = 0; i < mixables_.size(); i++) {
220  push_mixable* mixable = dynamic_cast<push_mixable*>(mixables_[i]);
221  if (!mixable) {
222  continue;
223  }
224  mixable->push(o.via.array.ptr[obj_index]);
225  obj_index++;
226  }
227 }
228 
229 std::vector<storage::version>
231  std::vector<storage::version> ret;
232  for (size_t i = 0; i < mixables_.size(); ++i) {
233  ret.push_back(mixables_[i]->get_version());
234  }
235  return ret;
236 }
237 
238 std::vector<storage::version> driver_base::get_versions() const {
239  return holder_.get_versions();
240 }
241 
243  holder_.register_mixable(mixable);
244 }
245 
246 
247 } // namespace driver
248 } // namespace core
249 } // namespace jubatus
virtual void mix(const msgpack::object &obj, diff_object) const =0
virtual void get_argument(packer &) const =0
void get_diff(framework::packer &) const
Definition: driver.cpp:142
void pull(const msgpack::object &arg, framework::packer &) const
Definition: driver.cpp:191
jubatus::util::lang::shared_ptr< diff_object_raw > diff_object
void get_argument(framework::packer &) const
Definition: driver.cpp:179
#define JUBATUS_EXCEPTION(e)
Definition: exception.hpp:79
virtual bool put_diff(const diff_object &obj)=0
bool put_diff(const framework::diff_object &obj)
Definition: driver.cpp:154
std::vector< storage::version > get_versions() const
Definition: driver.cpp:230
void register_mixable(framework::mixable *mixable)
Definition: driver.cpp:84
framework::diff_object convert_diff_object(const msgpack::object &) const
Definition: driver.cpp:90
virtual diff_object convert_diff_object(const msgpack::object &) const =0
msgpack::packer< jubatus_packer > packer
Definition: bandit_base.hpp:31
void mix(const msgpack::object &obj, framework::diff_object) const
Definition: driver.cpp:111
std::vector< storage::version > get_versions() const
Definition: driver.cpp:238
virtual void push(const msgpack::object &)=0
std::set< std::string > mixables() const
Definition: driver.cpp:70
void register_mixable(framework::mixable *mixable)
Definition: driver.cpp:242
virtual void get_diff(packer &) const =0
virtual void pull(const msgpack::object &arg, packer &) const =0
vector< diff_object > diffs_
Definition: driver.cpp:53