29 #include "../common/exception.hpp"
// Number of states in the model; used as the length of the per-state
// vectors and as the upper bound of the per-state loops.
38 const int kStatesNum = 2;
// Index of the base state in per-state arrays.
40 const int kBaseState = 0;
// Index of the burst state in per-state arrays.
42 const int kBurstState = 1;
// NOTE(review): not referenced in the visible part of the file. Presumably the
// sentinel for a batch weight that has not been calculated yet (negative
// entries are dropped by erase_uncalc_batches) -- confirm against callers.
44 const double kDefaultBatchWeight = -1;
// Builds the per-state probability vector from the window totals.
//   d_vector      - per-batch counts, summed into D.
//   r_vector      - per-batch counts, summed into R.
//   scaling_param - factor applied to the base-state probability to obtain
//                   the burst-state probability.
// Result layout: ret[kBaseState] = R / D,
//                ret[kBurstState] = scaling_param * ret[kBaseState].
// NOTE(review): presumably d counts all documents and r the relevant ones per
// batch -- confirm against the caller.
// NOTE(review): the closing lines of this function (return statement and
// brace) are not visible in this excerpt.
46 std::vector<double> get_p_vector(
47 const std::vector<uint32_t>& d_vector,
48 const std::vector<uint32_t>& r_vector,
49 double scaling_param) {
50 std::vector<double> ret(kStatesNum);
// NOTE(review): the initial value 0 is an int, so std::accumulate sums in int;
// large uint32_t totals could overflow -- confirm the expected value ranges.
52 const int32_t D = std::accumulate(d_vector.begin(), d_vector.end(), 0);
54 const int32_t R = std::accumulate(r_vector.begin(), r_vector.end(), 0);
// NOTE(review): D == 0 is not guarded here; the floating-point division then
// produces inf or NaN rather than a probability.
55 ret[kBaseState] =
static_cast<double>(R) / static_cast<double>(D);
56 ret[kBurstState] = scaling_param * ret[kBaseState];
// Cost term for moving from state i to state j; callers add it to the
// accumulated cost of the previous state.
//   i, j        - source and destination state indices.
//   gamma       - multiplier of the cost.
//   window_size - number of batches in the window; enters through its log.
// The visible expression is (j - i) * gamma * log(window_size).
// NOTE(review): the original lines between the signature and this return are
// not visible in this excerpt; there may be an additional branch (e.g. for
// i >= j) -- confirm against the full file before relying on the sign.
60 double tau(
int i,
int j,
double gamma,
int window_size) {
64 return (j - i) * gamma * std::log(window_size);
// Returns the natural log of i! using the identity log(i!) = lgamma(i + 1),
// which avoids overflowing the factorial itself.
67 double log_factorial(
int i) {
68 return ::lgamma(i + 1);
// Log of the binomial coefficient C(n, k), computed as
// log(n!) - log(k!) - log((n - k)!).
// Arguments with n < 0, k < 0 or n < k take a separate branch.
// NOTE(review): the body of that branch is not visible in this excerpt, so
// its result (error or fallback value) must be checked in the full file.
71 double log_choose(
int n,
int k) {
72 if (n < 0 || k < 0 || n < k) {
75 return log_factorial(n) - log_factorial(k) - log_factorial(n - k);
// Accumulates log C(d, r) + r * log(p) + (d - r) * log(1 - p), i.e. the log
// of the binomial probability of r hits in d trials with success rate p.
// NOTE(review): the return statement is not visible in this excerpt, so the
// sign of the returned value (log-likelihood vs. its negation) is not
// confirmed here.
// NOTE(review): d and r are unsigned, so (d - r) wraps if r > d; burst_detect
// checks d >= r per batch before use, other callers should be verified.
78 double sigma(
double p, uint32_t d, uint32_t r) {
79 double ret = log_choose(d, r);
80 ret += r * std::log(p);
81 ret += (d - r) * std::log(1 - p);
// Weight of a single batch: sigma under the base-state probability minus
// sigma under the burst-state probability for that batch, clamped so that a
// non-positive difference yields 0.
//   d_vector, r_vector - per-batch counts, indexed by batch_id.
//   p_vector           - per-state probabilities (kBaseState / kBurstState).
// NOTE(review): the batch_id parameter and the declaration of ret are on
// lines not visible in this excerpt.
85 double get_batch_weight(
86 const std::vector<uint32_t>& d_vector,
87 const std::vector<uint32_t>& r_vector,
88 const std::vector<double>& p_vector,
91 sigma(p_vector[kBaseState], d_vector[batch_id], r_vector[batch_id]) -
92 sigma(p_vector[kBurstState], d_vector[batch_id], r_vector[batch_id]);
93 return (ret > 0) ? ret : 0;
// Picks the cheaper predecessor state for reaching now_state.
// For each candidate predecessor (base, burst) the cost is its accumulated
// optimal cost plus the tau transition term into now_state.
// Returns (chosen predecessor state, cost of arriving in now_state via it).
// NOTE(review): the now_state, gamma and window_size parameters are declared
// on lines not visible in this excerpt.
96 std::pair<int, double> calc_previous_optimal_state(
98 double prev_base_optimal_cost,
99 double prev_burst_optimal_cost,
// Cost of arriving from the base state.
103 const double prev_base_optimal_to_now_state_cost
104 = prev_base_optimal_cost
105 + tau(kBaseState, now_state, gamma, window_size);
// Cost of arriving from the burst state.
107 const double prev_burst_optimal_to_now_state_cost
108 = prev_burst_optimal_cost
109 + tau(kBurstState, now_state, gamma, window_size);
// Default to the base state; switch to burst only when it is strictly
// cheaper, so ties resolve to the base state.
111 int prev_optimal_state = kBaseState;
112 double prev_optimal_in_now_state_cost
113 = prev_base_optimal_to_now_state_cost;
114 if (prev_base_optimal_to_now_state_cost
115 > prev_burst_optimal_to_now_state_cost) {
116 prev_optimal_state = kBurstState;
117 prev_optimal_in_now_state_cost
118 = prev_burst_optimal_to_now_state_cost;
121 return std::make_pair(prev_optimal_state,
122 prev_optimal_in_now_state_cost);
// Tests whether sigma under the burst-state probability exceeds sigma under
// the base-state probability by more than
// burst_cut_threshold * log(window_size), where window_size is
// d_vector.size().
// NOTE(review): the batch index parameter, the remaining sigma arguments and
// the values returned from each branch are on lines not visible in this
// excerpt; callers treat a true result as "force the base state as
// predecessor".
125 bool check_branch_cuttable(
126 const std::vector<uint32_t>& d_vector,
127 const std::vector<uint32_t>& r_vector,
128 const std::vector<double> & p_vector,
130 double burst_cut_threshold) {
131 const int window_size = d_vector.size();
133 if (sigma(p_vector[kBurstState],
136 - sigma(p_vector[kBaseState],
139 > burst_cut_threshold * std::log(window_size)) {
// Const call operator taking a single double and returning bool.
// NOTE(review): the enclosing type and the body are not visible in this
// excerpt; presumably this is the predicate of is_negative used by
// erase_uncalc_batches -- confirm in the full file.
146 bool operator()(
double x)
const {
// Removes, in place, every element of batch_weights for which the
// is_negative predicate returns true, preserving the order of the rest
// (std::remove_if followed by erase of the leftover tail).
151 void erase_uncalc_batches(std::vector<double>& batch_weights) {
152 std::vector<double>::iterator iter = std::remove_if(
153 batch_weights.begin(), batch_weights.end(), is_negative());
154 batch_weights.erase(iter, batch_weights.end());
// burst_detect: fills batch_weights with one weight per batch of the window.
// NOTE(review): the start of the signature (function name, d_vector and the
// gamma parameter) and several statement lines are not visible in this
// excerpt; the comments below describe only what the visible lines show.
//   r_vector            - per-batch counts, same length as d_vector.
//   batch_weights       - in/out; entries already present are reused, the
//                         remaining ones are appended.
//   scaling_param       - ratio of burst-state to base-state probability.
//   burst_cut_threshold - threshold passed to check_branch_cuttable.
160 const std::vector<uint32_t> & r_vector,
161 std::vector<double>& batch_weights,
162 double scaling_param,
164 double burst_cut_threshold) {
165 const int window_size = d_vector.size();
// Argument validation. The bodies of these checks are not visible here.
170 if (scaling_param <= 1) {
174 if (burst_cut_threshold <= 0) {
178 if (d_vector.size() != r_vector.size()) {
// Every batch must satisfy d >= r.
182 for (
int batch_id = 0; batch_id < window_size; batch_id++) {
183 if (d_vector[batch_id] < r_vector[batch_id]) {
186 "d_vector[batch_id] < r_vector[batch_id]"));
189 const std::vector<double> p_vector
190 = get_p_vector(d_vector, r_vector, scaling_param);
// Drop entries rejected by is_negative so that only reusable weights remain.
192 erase_uncalc_batches(batch_weights);
// Degenerate probabilities: resize batch_weights to window_size, filling new
// entries with INFINITY when the burst probability exceeds 1, or with 0 when
// the base probability is exactly 0.
197 if (1 < p_vector[kBurstState]) {
198 batch_weights.resize(window_size, INFINITY);
200 }
else if (p_vector[kBaseState] == 0) {
201 batch_weights.resize(window_size, 0);
// Number of leading batches whose weights are already present.
205 const int reuse_batch_size = batch_weights.size();
// Per-state cost of the best path ending in that state at the current batch.
210 double prev_optimal_in_now_states_costs[] = {-1, -1};
// Per-state cost of the best path ending in that state at the previous
// batch. Initially the base state is free and the burst state unreachable.
217 double prev_optimal_costs[] = {0, INFINITY};
// If the last reused weight is positive, start from the burst state instead.
218 if (batch_weights.size() != 0 && 0 < batch_weights.back()) {
221 prev_optimal_costs[kBaseState] = INFINITY;
222 prev_optimal_costs[kBurstState] = 0;
// Per-state state sequences matching the two cost arrays above.
228 std::vector<std::vector<int> > prev_optimal_in_now_states_seq(kStatesNum);
233 std::vector<std::vector<int> > prev_optimal_states_seq(kStatesNum);
// Walk over the batches that still need a weight.
235 for (
int update_batch_id = 0;
236 update_batch_id < window_size - reuse_batch_size;
238 for (
int now_state = kBaseState; now_state < kStatesNum; now_state++) {
239 std::pair<int, double> prev_optimal_pair;
// Previous batch has d == 0: the predecessor is forced to the base state.
241 if ((0 < update_batch_id + reuse_batch_size) &&
242 (d_vector[update_batch_id + reuse_batch_size - 1] == 0)) {
246 prev_optimal_pair.first = kBaseState;
247 prev_optimal_pair.second =
248 prev_optimal_costs[kBaseState] +
249 tau(kBaseState, now_state, gamma, window_size);
250 }
// Previous batch is cuttable: the predecessor is likewise forced to base.
else if (0 < update_batch_id + reuse_batch_size &&
251 check_branch_cuttable(d_vector, r_vector, p_vector,
252 update_batch_id + reuse_batch_size - 1,
253 burst_cut_threshold)) {
254 prev_optimal_pair.first = kBaseState;
255 prev_optimal_pair.second =
256 prev_optimal_costs[kBaseState] +
257 tau(kBaseState, now_state, gamma, window_size);
// Otherwise the cheaper predecessor is chosen.
260 calc_previous_optimal_state(now_state,
261 prev_optimal_costs[kBaseState],
262 prev_optimal_costs[kBurstState],
// Cost in now_state = cost of arriving + sigma of the current batch under
// that state's probability.
266 prev_optimal_in_now_states_costs[now_state] =
267 prev_optimal_pair.second +
268 sigma(p_vector[now_state],
269 d_vector[update_batch_id + reuse_batch_size],
270 r_vector[update_batch_id + reuse_batch_size]);
// Extend the chosen predecessor's sequence with now_state.
272 prev_optimal_in_now_states_seq[now_state] =
273 prev_optimal_states_seq[prev_optimal_pair.first];
274 prev_optimal_in_now_states_seq[now_state].push_back(now_state);
// Current-batch results become the previous-batch values of the next step.
280 for (
int state = kBaseState; state < kStatesNum; state++) {
281 prev_optimal_costs[state] = prev_optimal_in_now_states_costs[state];
282 prev_optimal_states_seq[state] =
283 prev_optimal_in_now_states_seq[state];
// Select the final sequence: base-ending when the last batch has d == 0 or
// is cuttable, otherwise the cheaper of the two (ties go to base).
287 std::vector<int> optimal_states_seq;
289 if (d_vector[window_size - 1] == 0) {
293 optimal_states_seq = prev_optimal_in_now_states_seq[kBaseState];
294 }
else if (check_branch_cuttable(d_vector, r_vector, p_vector,
296 burst_cut_threshold)) {
297 optimal_states_seq = prev_optimal_in_now_states_seq[kBaseState];
300 prev_optimal_in_now_states_costs[kBaseState] <=
301 prev_optimal_in_now_states_costs[kBurstState] ?
302 prev_optimal_in_now_states_seq[kBaseState] :
303 prev_optimal_in_now_states_seq[kBurstState];
// Reused batches with a positive weight are recomputed against the current
// p_vector.
311 for (
int batch_id = 0; batch_id < reuse_batch_size; batch_id++) {
312 if (0 < batch_weights[batch_id]) {
313 batch_weights[batch_id] =
314 get_batch_weight(d_vector, r_vector, p_vector, batch_id);
// New batches: append a weight chosen by the state assigned to each batch.
// Burst-state batches use get_batch_weight; the value used for other batches
// is not visible in this excerpt.
318 for (
int batch_id = reuse_batch_size; batch_id < window_size; batch_id++) {
319 int state = optimal_states_seq[batch_id - reuse_batch_size];
320 batch_weights.push_back(state == kBurstState ?
321 get_batch_weight(d_vector, r_vector,
322 p_vector, batch_id) :
#define JUBATUS_EXCEPTION(e)
void burst_detect(const std::vector< uint32_t > &d_vector, const std::vector< uint32_t > &r_vector, std::vector< double > &batch_weights, double scaling_param, double gamma, double burst_cut_threshold)