jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
ucb1.cpp
Go to the documentation of this file.
1 // Jubatus: Online machine learning framework for distributed environment
2 // Copyright (C) 2015 Preferred Networks and Nippon Telegraph and Telephone Corporation.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License version 2.1 as published by the Free Software Foundation.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 
#include "ucb1.hpp"

#include <cfloat>
#include <cmath>
#include <string>
#include <vector>
#include "../common/exception.hpp"
#include "../framework/packer.hpp"
#include "../common/version.hpp"
25 
26 namespace jubatus {
27 namespace core {
28 namespace bandit {
29 
30 ucb1::ucb1(bool assume_unrewarded)
31  : s_(assume_unrewarded) {
32 }
33 
34 std::string ucb1::select_arm(const std::string& player_id) {
35  const std::vector<std::string>& arms = s_.get_arm_ids();
36  if (arms.empty()) {
37  throw JUBATUS_EXCEPTION(
38  common::exception::runtime_error("arm is not registered"));
39  }
40 
41  int total_trial = 0;
42  for (size_t i = 0; i < arms.size(); ++i) {
43  const arm_info& a = s_.get_arm_info(player_id, arms[i]);
44  if (a.trial_count == 0) {
45  return arms[i];
46  }
47  total_trial += a.trial_count;
48  }
49  double log_total_trial = std::log(total_trial);
50 
51  double score_max = -DBL_MAX;
52  std::string result;
53  for (size_t i = 0; i < arms.size(); ++i) {
54  const arm_info& a = s_.get_arm_info(player_id, arms[i]);
55  double exp = a.weight / a.trial_count;
56  double score = exp + std::sqrt(2 * log_total_trial / a.trial_count);
57  if (score > score_max) {
58  score_max = score;
59  result = arms[i];
60  }
61  }
62  s_.notify_selected(player_id, result);
63  return result;
64 }
65 
66 bool ucb1::register_arm(const std::string& arm_id) {
67  return s_.register_arm(arm_id);
68 }
69 bool ucb1::delete_arm(const std::string& arm_id) {
70  return s_.delete_arm(arm_id);
71 }
72 
73 bool ucb1::register_reward(const std::string& player_id,
74  const std::string& arm_id,
75  double reward) {
76  return s_.register_reward(player_id, arm_id, reward);
77 }
78 
79 arm_info_map ucb1::get_arm_info(const std::string& player_id) const {
80  return s_.get_arm_info_map(player_id);
81 }
82 
83 bool ucb1::reset(const std::string& player_id) {
84  return s_.reset(player_id);
85 }
86 void ucb1::clear() {
87  s_.clear();
88 }
89 
90 void ucb1::pack(framework::packer& pk) const {
91  pk.pack(s_);
92 }
93 void ucb1::unpack(msgpack::object o) {
94  o.convert(&s_);
95 }
96 
97 void ucb1::get_diff(diff_t& diff) const {
98  s_.get_diff(diff);
99 }
100 bool ucb1::put_diff(const diff_t& diff) {
101  return s_.put_diff(diff);
102 }
103 void ucb1::mix(const diff_t& lhs, diff_t& rhs) const {
104  s_.mix(lhs, rhs);
105 }
107  return storage::version();
108 }
109 
110 } // namespace bandit
111 } // namespace core
112 } // namespace jubatus
void unpack(msgpack::object o)
Definition: ucb1.cpp:93
bool reset(const std::string &player_id)
bool delete_arm(const std::string &arm_id)
Definition: ucb1.cpp:69
storage::version get_version() const
Definition: ucb1.cpp:106
bool register_reward(const std::string &player_id, const std::string &arm_id, double reward)
bool put_diff(const diff_t &diff)
Definition: ucb1.cpp:100
ucb1(bool assume_unrewarded)
Definition: ucb1.cpp:30
arm_info_map get_arm_info(const std::string &player_id) const
Definition: ucb1.cpp:79
#define JUBATUS_EXCEPTION(e)
Definition: exception.hpp:79
std::string select_arm(const std::string &player_id)
Definition: ucb1.cpp:34
bool register_reward(const std::string &player_id, const std::string &arm_id, double reward)
Definition: ucb1.cpp:73
jubatus::util::data::unordered_map< std::string, arm_info_map > diff_t
Definition: bandit_base.hpp:65
bool reset(const std::string &player_id)
Definition: ucb1.cpp:83
const std::vector< std::string > & get_arm_ids() const
void notify_selected(const std::string &player_id, const std::string &arm_id)
bool register_arm(const std::string &arm_id)
jubatus::util::data::unordered_map< std::string, arm_info > arm_info_map
Definition: arm_info.hpp:36
void get_diff(diff_t &diff) const
Definition: ucb1.cpp:97
static void mix(const table_t &lhs, table_t &rhs)
arm_info get_arm_info(const std::string &player_id, const std::string &arm_id) const
arm_info_map get_arm_info_map(const std::string &player_id) const
bool delete_arm(const std::string &arm_id)
void mix(const diff_t &lhs, diff_t &rhs) const
Definition: ucb1.cpp:103
summation_storage s_
Definition: ucb1.hpp:60
void pack(framework::packer &pk) const
Definition: ucb1.cpp:90
bool register_arm(const std::string &arm_id)
Definition: ucb1.cpp:66