jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
summation_storage.cpp
Go to the documentation of this file.
1 // Jubatus: Online machine learning framework for distributed environment
2 // Copyright (C) 2015 Preferred Networks and Nippon Telegraph and Telephone Corporation.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License version 2.1 as published by the Free Software Foundation.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 
17 #include "summation_storage.hpp"
18 
19 #include <string>
20 #include <vector>
21 #include "../common/exception.hpp"
22 
23 namespace jubatus {
24 namespace core {
25 namespace bandit {
26 
27 summation_storage::summation_storage(bool assume_unrewarded)
28  : assume_unrewarded_(assume_unrewarded) {
29 }
30 
31 bool summation_storage::register_arm(const std::string& arm_id) {
32  if (std::find(arm_ids_.begin(), arm_ids_.end(), arm_id) != arm_ids_.end()) {
33  // arm_id is already in arms_
34  return false;
35  }
36  arm_ids_.push_back(arm_id);
37  const arm_info a0 = {0, 0.0};
38  for (table_t::iterator iter = unmixed_.begin();
39  iter != unmixed_.end(); ++iter) {
40  arm_info_map& as = iter->second;
41  as.insert(std::make_pair(arm_id, a0));
42  }
43  return true;
44 }
45 
46 namespace {
47 void delete_arm_(summation_storage::table_t& t, const std::string& arm_id) {
48  for (summation_storage::table_t::iterator iter = t.begin();
49  iter != t.end(); ++iter) {
50  arm_info_map& as = iter->second;
51  as.erase(arm_id);
52  }
53 }
54 arm_info_map& get_arm_info_map_(summation_storage::table_t& t,
55  const std::vector<std::string>& arm_ids,
56  const std::string& player_id) {
57  summation_storage::table_t::iterator iter = t.find(player_id);
58  if (iter != t.end()) {
59  return iter->second;
60  }
61  arm_info_map& as = t[player_id];
62  const arm_info a0 = {0, 0.0};
63  for (size_t i = 0; i < arm_ids.size(); ++i) {
64  as.insert(std::make_pair(arm_ids[i], a0));
65  }
66  return as;
67 }
68 arm_info& get_arm_info_(summation_storage::table_t& t,
69  const std::vector<std::string>& arm_ids,
70  const std::string& player_id,
71  const std::string& arm_id) {
72  arm_info_map& as = get_arm_info_map_(t, arm_ids, player_id);
73  arm_info_map::iterator iter = as.find(arm_id);
74  if (iter == as.end()) {
75  throw JUBATUS_EXCEPTION(common::exception::runtime_error(
76  "arm_id is not registered: " + arm_id));
77  }
78  return iter->second;
79 }
80 } // namespace
81 
82 bool summation_storage::delete_arm(const std::string& arm_id) {
83  delete_arm_(mixed_, arm_id);
84  delete_arm_(unmixed_, arm_id);
85 
86  std::vector<std::string>::iterator iter =
87  std::remove(arm_ids_.begin(), arm_ids_.end(), arm_id);
88  if (iter == arm_ids_.end()) {
89  return false;
90  }
91  arm_ids_.erase(iter, arm_ids_.end());
92  return true;
93 }
94 
96  const std::string& player_id,
97  const std::string& arm_id) {
98  if (!assume_unrewarded_) {
99  return;
100  }
101  arm_info& a = get_arm_info_(unmixed_, arm_ids_, player_id, arm_id);
102  a.trial_count += 1;
103 }
104 
106  const std::string& player_id,
107  const std::string& arm_id,
108  double reward) {
109  arm_info& a = get_arm_info_(unmixed_, arm_ids_, player_id, arm_id);
110  if (!assume_unrewarded_) {
111  a.trial_count += 1;
112  }
113  a.weight += reward;
114  return true;
115 }
116 
117 namespace {
118 arm_info get_arm_info_(
120  const std::string& player_id,
121  const std::string& arm_id) {
122  summation_storage::table_t::const_iterator iter = t.find(player_id);
123  if (iter == t.end()) {
124  const arm_info a0 = {0, 0.0};
125  return a0;
126  }
127  const arm_info_map& as = iter->second;
128  arm_info_map::const_iterator jter = as.find(arm_id);
129  if (jter == as.end()) {
130  const arm_info a0 = {0, 0.0};
131  return a0;
132  }
133  return jter->second;
134 }
135 } // namespace
136 
138  const std::string& player_id,
139  const std::string& arm_id) const {
140  const arm_info a1 = get_arm_info_(mixed_, player_id, arm_id);
141  const arm_info a2 = get_arm_info_(unmixed_, player_id, arm_id);
142 
143  arm_info result;
144  result.trial_count = a1.trial_count + a2.trial_count;
145  result.weight = a1.weight + a2.weight;
146  return result;
147 }
148 
150  const std::string& player_id,
151  const std::string& arm_id) const {
152  const arm_info a = get_arm_info(player_id, arm_id);
153  if (a.trial_count == 0) {
154  return 0;
155  }
156  return a.weight / a.trial_count;
157 }
158 
160  const std::string& player_id) const {
161  arm_info_map result;
162 
163  for (std::vector<std::string>::const_iterator iter = arm_ids_.begin();
164  iter != arm_ids_.end(); ++iter) {
165  result.insert(std::make_pair(*iter, get_arm_info(player_id, *iter)));
166  }
167 
168  return result;
169 }
170 
// NOTE(review): the line that opens this definition (listing line 171) is
// missing from this extract, so the function's name, return type and
// parameter list cannot be confirmed from here.
// Visible behaviour: assigns a copy of the whole unmixed_ table to `diff`,
// which is presumably a table_t output parameter -- TODO confirm against
// the header.
172  diff = unmixed_;
173 }
174 
// NOTE(review): the line that opens this definition (listing line 175) is
// missing from this extract, so the function's name and exact signature
// cannot be confirmed from here; `diff` is presumably a table_t parameter.
// Visible behaviour: accumulates `diff` into mixed_ via mix(), empties
// unmixed_, and reports success unconditionally.
176  mix(diff, mixed_);
// Everything currently held in unmixed_ is discarded, whether or not it is
// reflected in `diff`.
177  unmixed_.clear();
178  return true;
179 }
180 
181 void summation_storage::mix(const table_t& lhs, table_t& rhs) {
182  for (table_t::const_iterator iter = lhs.begin();
183  iter != lhs.end(); ++iter) {
184  arm_info_map& as0 = rhs[iter->first];
185  const arm_info_map& as1 = iter->second;
186  for (arm_info_map::const_iterator jter = as1.begin();
187  jter != as1.end(); ++jter) {
188  arm_info& a0 = as0[jter->first];
189  const arm_info& a1 = jter->second;
190  a0.trial_count += a1.trial_count;
191  a0.weight += a1.weight;
192  }
193  }
194 }
195 
196 bool summation_storage::reset(const std::string& player_id) {
197  bool result1 = mixed_.erase(player_id) > 0;
198  bool result2 = unmixed_.erase(player_id) > 0;
199  return result1 || result2;
200 }
201 
// NOTE(review): the line that opens this definition (listing line 202) is
// missing from this extract, so the function's name and signature cannot be
// confirmed from here.
// Visible behaviour: empties the list of registered arm ids and both the
// mixed_ and unmixed_ tables.
203  arm_ids_.clear();
204  mixed_.clear();
205  unmixed_.clear();
206 }
207 
208 } // namespace bandit
209 } // namespace core
210 } // namespace jubatus
bool reset(const std::string &player_id)
bool register_reward(const std::string &player_id, const std::string &arm_id, double reward)
#define JUBATUS_EXCEPTION(e)
Definition: exception.hpp:79
double get_expectation(const std::string &player_id, const std::string &arm_id) const
void notify_selected(const std::string &player_id, const std::string &arm_id)
bool register_arm(const std::string &arm_id)
jubatus::util::data::unordered_map< std::string, arm_info > arm_info_map
Definition: arm_info.hpp:36
static void mix(const table_t &lhs, table_t &rhs)
arm_info get_arm_info(const std::string &player_id, const std::string &arm_id) const
arm_info_map get_arm_info_map(const std::string &player_id) const
bool delete_arm(const std::string &arm_id)