jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
epsilon_greedy.cpp
Go to the documentation of this file.
1 // Jubatus: Online machine learning framework for distributed environment
2 // Copyright (C) 2015 Preferred Networks and Nippon Telegraph and Telephone Corporation.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License version 2.1 as published by the Free Software Foundation.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 
17 #include "epsilon_greedy.hpp"
18 
19 #include <string>
20 #include <vector>
21 #include "../common/exception.hpp"
22 #include "../framework/packer.hpp"
23 #include "../common/version.hpp"
24 
25 namespace jubatus {
26 namespace core {
27 namespace bandit {
28 
29 epsilon_greedy::epsilon_greedy(bool assume_unrewarded, double eps)
30  : eps_(eps), s_(assume_unrewarded) {
31  if (eps < 0 || 1 < eps) {
32  throw JUBATUS_EXCEPTION(
33  common::invalid_parameter("0 <= epsilon <= 1"));
34  }
35 }
36 
37 std::string epsilon_greedy::select_arm(const std::string& player_id) {
38  const std::vector<std::string>& arms = s_.get_arm_ids();
39  if (arms.empty()) {
40  throw JUBATUS_EXCEPTION(
41  common::exception::runtime_error("arm is not registered"));
42  }
43 
44  std::string result;
45  if (rand_.next_double() < eps_) {
46  // exploration
47  result = arms[rand_.next_int(arms.size())];
48  } else {
49  // exploitation
50  result = arms[0];
51  double exp_max = s_.get_expectation(player_id, arms[0]);
52  for (size_t i = 1; i < arms.size(); ++i) {
53  double exp = s_.get_expectation(player_id, arms[i]);
54  if (exp > exp_max) {
55  result = arms[i];
56  exp_max = exp;
57  }
58  }
59  }
60  s_.notify_selected(player_id, result);
61  return result;
62 }
63 
64 bool epsilon_greedy::register_arm(const std::string& arm_id) {
65  return s_.register_arm(arm_id);
66 }
67 bool epsilon_greedy::delete_arm(const std::string& arm_id) {
68  return s_.delete_arm(arm_id);
69 }
70 
71 bool epsilon_greedy::register_reward(const std::string& player_id,
72  const std::string& arm_id,
73  double reward) {
74  return s_.register_reward(player_id, arm_id, reward);
75 }
76 
77 arm_info_map epsilon_greedy::get_arm_info(const std::string& player_id) const {
78  return s_.get_arm_info_map(player_id);
79 }
80 
81 bool epsilon_greedy::reset(const std::string& player_id) {
82  return s_.reset(player_id);
83 }
85  s_.clear();
86 }
87 
89  pk.pack(s_);
90 }
91 void epsilon_greedy::unpack(msgpack::object o) {
92  o.convert(&s_);
93 }
94 
95 void epsilon_greedy::get_diff(diff_t& diff) const {
96  s_.get_diff(diff);
97 }
98 bool epsilon_greedy::put_diff(const diff_t& diff) {
99  return s_.put_diff(diff);
100 }
101 void epsilon_greedy::mix(const diff_t& lhs, diff_t& rhs) const {
102  s_.mix(lhs, rhs);
103 }
104 
106  return storage::version();
107 }
108 
109 } // namespace bandit
110 } // namespace core
111 } // namespace jubatus
arm_info_map get_arm_info(const std::string &player_id) const
void mix(const diff_t &lhs, diff_t &rhs) const
epsilon_greedy(bool assume_unrewarded, double eps)
bool reset(const std::string &player_id)
bool delete_arm(const std::string &arm_id)
std::string select_arm(const std::string &player_id)
bool register_reward(const std::string &player_id, const std::string &arm_id, double reward)
bool register_arm(const std::string &arm_id)
#define JUBATUS_EXCEPTION(e)
Definition: exception.hpp:79
jubatus::util::data::unordered_map< std::string, arm_info_map > diff_t
Definition: bandit_base.hpp:65
bool register_reward(const std::string &player_id, const std::string &arm_id, double reward)
double get_expectation(const std::string &player_id, const std::string &arm_id) const
const std::vector< std::string > & get_arm_ids() const
void notify_selected(const std::string &player_id, const std::string &arm_id)
bool register_arm(const std::string &arm_id)
storage::version get_version() const
jubatus::util::data::unordered_map< std::string, arm_info > arm_info_map
Definition: arm_info.hpp:36
jubatus::util::math::random::mtrand rand_
static void mix(const table_t &lhs, table_t &rhs)
void pack(framework::packer &pk) const
arm_info_map get_arm_info_map(const std::string &player_id) const
bool delete_arm(const std::string &arm_id)
void get_diff(diff_t &diff) const
bool reset(const std::string &player_id)