jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
exp3.cpp
Go to the documentation of this file.
1 // Jubatus: Online machine learning framework for distributed environment
2 // Copyright (C) 2015 Preferred Networks and Nippon Telegraph and Telephone Corporation.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License version 2.1 as published by the Free Software Foundation.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 
17 #include "exp3.hpp"
18 
19 #include <string>
20 #include <vector>
21 #include "../common/exception.hpp"
22 #include "../common/version.hpp"
23 #include "../framework/packer.hpp"
24 #include "select_by_weights.hpp"
25 
26 namespace jubatus {
27 namespace core {
28 namespace bandit {
29 
30 exp3::exp3(bool assume_unrewarded, double gamma)
31  : gamma_(gamma), s_(assume_unrewarded) {
32  if (gamma < 0 || 1 < gamma) {
33  throw JUBATUS_EXCEPTION(
34  common::invalid_parameter("0 <= gamma <= 1"));
35  }
36 }
37 
38 void exp3::calc_weights_(const std::string& player_id,
39  std::vector<double>& weights) const {
40  const std::vector<std::string>& arms = s_.get_arm_ids();
41  if (arms.empty()) {
42  throw JUBATUS_EXCEPTION(
43  common::exception::runtime_error("arm is not registered"));
44  }
45 
46  const size_t n = arms.size();
47  weights.clear();
48  weights.reserve(n);
49  double total_weight = 0;
50  for (size_t i = 0; i < n; ++i) {
51  const double weight = std::exp(s_.get_arm_info(player_id, arms[i]).weight);
52  weights.push_back(weight);
53  total_weight += weight;
54  }
55  for (size_t i = 0; i < n; ++i) {
56  weights[i] = (1.0 - gamma_) * weights[i] / total_weight + gamma_ * n;
57  }
58 }
59 
60 std::string exp3::select_arm(const std::string& player_id) {
61  const std::vector<std::string>& arms = s_.get_arm_ids();
62  if (arms.empty()) {
63  throw JUBATUS_EXCEPTION(
64  common::exception::runtime_error("arm is not registered"));
65  }
66 
67  std::vector<double> weights;
68  calc_weights_(player_id, weights);
69  std::string result = arms[select_by_weights(weights, rand_)];
70  s_.notify_selected(player_id, result);
71  return result;
72 }
73 
74 bool exp3::register_arm(const std::string& arm_id) {
75  return s_.register_arm(arm_id);
76 }
77 bool exp3::delete_arm(const std::string& arm_id) {
78  return s_.delete_arm(arm_id);
79 }
80 
81 bool exp3::register_reward(const std::string& player_id,
82  const std::string& arm_id,
83  double reward) {
84  const std::vector<std::string>& arms = s_.get_arm_ids();
85  size_t i = std::find(arms.begin(), arms.end(), arm_id) - arms.begin();
86  if (i >= arms.size()) {
87  return false;
88  }
89  std::vector<double> weights;
90  calc_weights_(player_id, weights);
91  return s_.register_reward(player_id, arm_id,
92  reward * weights[i] * gamma_ / arms.size());
93 }
94 
95 arm_info_map exp3::get_arm_info(const std::string& player_id) const {
96  return s_.get_arm_info_map(player_id);
97 }
98 
99 bool exp3::reset(const std::string& player_id) {
100  return s_.reset(player_id);
101 }
102 void exp3::clear() {
103  s_.clear();
104 }
105 
106 void exp3::pack(framework::packer& pk) const {
107  pk.pack(s_);
108 }
109 void exp3::unpack(msgpack::object o) {
110  o.convert(&s_);
111 }
112 
113 void exp3::get_diff(diff_t& diff) const {
114  s_.get_diff(diff);
115 }
116 bool exp3::put_diff(const diff_t& diff) {
117  return s_.put_diff(diff);
118 }
119 void exp3::mix(const diff_t& lhs, diff_t& rhs) const {
120  s_.mix(lhs, rhs);
121 }
122 
124  return storage::version();
125 }
126 
127 } // namespace bandit
128 } // namespace core
129 } // namespace jubatus
storage::version get_version() const
Definition: exp3.cpp:123
summation_storage s_
Definition: exp3.hpp:64
std::string select_arm(const std::string &player_id)
Definition: exp3.cpp:60
int select_by_weights(const std::vector< double > &weights, mtrand &rand)
bool reset(const std::string &player_id)
arm_info_map get_arm_info(const std::string &player_id) const
Definition: exp3.cpp:95
bool register_reward(const std::string &player_id, const std::string &arm_id, double reward)
Definition: exp3.cpp:81
bool register_reward(const std::string &player_id, const std::string &arm_id, double reward)
bool put_diff(const diff_t &diff)
Definition: exp3.cpp:116
#define JUBATUS_EXCEPTION(e)
Definition: exception.hpp:79
exp3(bool assume_unrewarded, double gamma)
Definition: exp3.cpp:30
bool register_arm(const std::string &arm_id)
Definition: exp3.cpp:74
jubatus::util::data::unordered_map< std::string, arm_info_map > diff_t
Definition: bandit_base.hpp:65
const std::vector< std::string > & get_arm_ids() const
void notify_selected(const std::string &player_id, const std::string &arm_id)
bool register_arm(const std::string &arm_id)
void mix(const diff_t &lhs, diff_t &rhs) const
Definition: exp3.cpp:119
jubatus::util::data::unordered_map< std::string, arm_info > arm_info_map
Definition: arm_info.hpp:36
bool delete_arm(const std::string &arm_id)
Definition: exp3.cpp:77
void unpack(msgpack::object o)
Definition: exp3.cpp:109
bool reset(const std::string &player_id)
Definition: exp3.cpp:99
void get_diff(diff_t &diff) const
Definition: exp3.cpp:113
static void mix(const table_t &lhs, table_t &rhs)
arm_info get_arm_info(const std::string &player_id, const std::string &arm_id) const
arm_info_map get_arm_info_map(const std::string &player_id) const
bool delete_arm(const std::string &arm_id)
void calc_weights_(const std::string &player_id, std::vector< double > &weights) const
Definition: exp3.cpp:38
jubatus::util::math::random::mtrand rand_
Definition: exp3.hpp:63
void pack(framework::packer &pk) const
Definition: exp3.cpp:106