jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
burst_result.cpp
Go to the documentation of this file.
1 // Jubatus: Online machine learning framework for distributed environment
2 // Copyright (C) 2014 Preferred Networks and Nippon Telegraph and Telephone Corporation.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License version 2.1 as published by the Free Software Foundation.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 
17 #include "burst_result.hpp"
18 
19 #include <vector>
20 #include <utility>
21 
22 #include "input_window.hpp"
23 #include "result_window.hpp"
24 #include "engine.hpp"
25 #include "window_intersection.hpp"
26 
27 using jubatus::util::lang::shared_ptr;
28 
29 namespace jubatus {
30 namespace core {
31 namespace burst {
32 
33 namespace {
34 
35 int accumulate_d_vec(const burst_result& r) {
36  const std::vector<batch_result>& batches = r.get_batches();
37  int result = 0;
38  for (size_t i = 0; i < batches.size(); ++i) {
39  result += batches[i].d;
40  }
41  return result;
42 }
43 
44 } // namespace
45 
47 }
48 
50  const input_window& input,
51  double scaling_param,
52  double gamma,
53  double costcut_threshold,
54  const burst_result& prev_result,
55  int max_reuse_batches) {
56  const std::vector<batch_input>& input_batches = input.get_batches();
57  const size_t n = input.get_batch_size();
58  const int max_reuse = (std::min)(max_reuse_batches, static_cast<int>(n));
59 
60  // make vectors for engine
61  std::vector<uint32_t> d_vec, r_vec;
62  std::vector<double> burst_weights;
63  d_vec.reserve(n);
64  r_vec.reserve(n);
65  burst_weights.reserve(n);
66  for (size_t i = 0; i < n; ++i) {
67  d_vec.push_back(input_batches[i].d);
68  r_vec.push_back(input_batches[i].r);
69  burst_weights.push_back(-1); // uncalculated
70  }
71 
72  // reuse batch weights
73  if (prev_result.p_) {
74  const result_window& prev = *prev_result.p_;
75  if (prev.get_start_pos() <= input.get_start_pos()) {
76  const std::pair<int, int> intersection = get_intersection(prev, input);
77  const std::vector<batch_result>& prev_results = prev.get_batches();
78  for (int i = 0, j = intersection.first;
79  i < max_reuse && j < intersection.second;
80  ++i, ++j) {
81  burst_weights[i] = prev_results[j].burst_weight;
82  }
83  }
84  }
85 
86  // doit
87  burst::burst_detect(d_vec, r_vec, burst_weights,
88  scaling_param, gamma, costcut_threshold);
89 
90  // store result
91  p_.reset(new result_window(input, burst_weights));
92 }
93 
95  : p_(new result_window(src)) {
96 }
97 
98 bool burst_result::is_valid() const {
99  return p_.get() != NULL;
100 }
101 
102 const double burst_result::invalid_pos = -1;
103 
105  return p_ ? p_->get_start_pos() : invalid_pos;
106 }
108  return p_ ? p_->get_end_pos() : invalid_pos;
109 }
110 bool burst_result::contains(double pos) const {
111  return p_ && p_->contains(pos);
112 }
113 
115  return p_ ? p_->get_batch_size() : 0;
116 }
118  return p_ ? p_->get_batch_interval() : 1;
119 }
121  return p_ ? p_->get_all_interval() : 0;
122 }
123 
125  if (!p_) {
126  return false;
127  }
128  double pos0 = p_->get_start_pos();
129  if (pos0 < pos) {
130  return !window_position_near(pos0, pos, p_->get_batch_interval());
131  } else {
132  return false;
133  }
134 }
136  if (!p_) {
137  return false;
138  }
139  double pos0 = p_->get_start_pos();
140  if (pos0 > pos) {
141  return !window_position_near(pos0, pos, p_->get_batch_interval());
142  } else {
143  return false;
144  }
145 }
146 bool burst_result::has_same_start_pos_to(double pos) const {
147  if (!p_) {
148  return false;
149  }
150  double pos0 = p_->get_start_pos();
151  return window_position_near(pos0, pos, p_->get_batch_interval());
152 }
153 
155  if (!p_ || !x.p_) {
156  return false;
157  }
158  double interval_x = x.p_->get_batch_interval();
159  return intersection_helper(*p_).has_batch_interval_equals_to(interval_x);
160 }
161 
162 const std::vector<batch_result> empty_batch_results;
163 
164 const std::vector<batch_result>& burst_result::get_batches() const {
165  return p_ ? p_->get_batches() : empty_batch_results;
166 }
167 
168 const batch_result& burst_result::get_batch_at(double pos) const {
169  int i = p_ ? p_->get_index(pos) : -1;
170  if (i < 0) {
171  throw std::out_of_range("burst_result: pos is out of range");
172  }
173  return p_->get_batches()[i];
174 }
175 
176 bool burst_result::is_bursted_at(double pos) const {
177  int i = p_ ? p_->get_index(pos) : -1;
178  if (i < 0) {
179  return false;
180  }
181  return p_->get_batches()[i].is_bursted();
182 }
183 
185  if (!p_) {
186  return false;
187  }
188  const std::vector<batch_result>& batches = p_->get_batches();
189  return !batches.empty() && batches.back().is_bursted();
190 }
191 
195  get_batch_size() != w.get_batch_size()) {
196  return false;
197  }
198 
199  if (p_ != w.p_) {
200  if (accumulate_d_vec(*this) < accumulate_d_vec(w)) {
201  p_ = w.p_;
202  }
203  }
204 
205  return true;
206 }
207 
209  if (!p_) {
211  packer.pack(r);
212  } else {
213  packer.pack(*p_);
214  }
215 }
216 
217 void burst_result::msgpack_unpack(msgpack::object o) {
218  shared_ptr<result_window> unpacked(new result_window());
219  unpacked->msgpack_unpack(o);
220  p_ = unpacked;
221 }
222 
223 } // namespace burst
224 } // namespace core
225 } // namespace jubatus
bool mix(const burst_result &w)
bool has_same_start_pos_to(double start_pos) const
void msgpack_unpack(msgpack::object o)
bool has_same_batch_interval(const burst_result &x) const
bool contains(double pos) const
bool is_bursted_at(double pos) const
bool has_start_pos_newer_than(double start_pos) const
const std::vector< batch_result > & get_batches() const
std::pair< int, int > get_intersection(const W1 &w1, const W2 &w2)
bool has_batch_interval_equals_to(double interval1) const
void msgpack_pack(framework::packer &packer) const
const std::vector< batch_result > empty_batch_results
msgpack::packer< jubatus_packer > packer
Definition: bandit_base.hpp:31
bool has_start_pos_older_than(double start_pos) const
const std::vector< batch_type > & get_batches() const
bool window_position_near(double pos0, double pos1, double batch_interval)
const batch_result & get_batch_at(double pos) const
jubatus::util::lang::shared_ptr< const result_window > p_
void burst_detect(const std::vector< uint32_t > &d_vector, const std::vector< uint32_t > &r_vector, std::vector< double > &batch_weights, double scaling_param, double gamma, double burst_cut_threshold)
Definition: engine.cpp:159