jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
softmax.cpp
Go to the documentation of this file.
1 // Jubatus: Online machine learning framework for distributed environment
2 // Copyright (C) 2015 Preferred Networks and Nippon Telegraph and Telephone Corporation.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License version 2.1 as published by the Free Software Foundation.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 
17 #include "softmax.hpp"
18 
19 #include <string>
20 #include <vector>
21 #include <cmath>
22 #include <numeric>
23 #include "../common/exception.hpp"
24 #include "../framework/packer.hpp"
25 #include "select_by_weights.hpp"
26 #include "../common/version.hpp"
27 
28 namespace jubatus {
29 namespace core {
30 namespace bandit {
31 
32 softmax::softmax(bool assume_unrewarded, double tau)
33  : tau_(tau), s_(assume_unrewarded) {
34  if (tau <= 0) {
35  throw JUBATUS_EXCEPTION(
36  common::invalid_parameter("0 < tau"));
37  }
38 }
39 
40 std::string softmax::select_arm(const std::string& player_id) {
41  const std::vector<std::string>& arms = s_.get_arm_ids();
42  if (arms.empty()) {
43  throw JUBATUS_EXCEPTION(
44  common::exception::runtime_error("arm is not registered"));
45  }
46 
47  std::vector<double> weights;
48  weights.reserve(arms.size());
49 
50  for (size_t i = 0; i < arms.size(); ++i) {
51  double expectation = s_.get_expectation(player_id, arms[i]);
52  weights.push_back(std::exp(expectation / tau_));
53  }
54  std::string result = arms[select_by_weights(weights, rand_)];
55  s_.notify_selected(player_id, result);
56  return result;
57 }
58 
59 bool softmax::register_arm(const std::string& arm_id) {
60  return s_.register_arm(arm_id);
61 }
62 bool softmax::delete_arm(const std::string& arm_id) {
63  return s_.delete_arm(arm_id);
64 }
65 
66 bool softmax::register_reward(const std::string& player_id,
67  const std::string& arm_id,
68  double reward) {
69  return s_.register_reward(player_id, arm_id, reward);
70 }
71 
72 arm_info_map softmax::get_arm_info(const std::string& player_id) const {
73  return s_.get_arm_info_map(player_id);
74 }
75 
76 bool softmax::reset(const std::string& player_id) {
77  return s_.reset(player_id);
78 }
80  s_.clear();
81 }
82 
84  pk.pack(s_);
85 }
86 void softmax::unpack(msgpack::object o) {
87  o.convert(&s_);
88 }
89 
90 void softmax::get_diff(diff_t& diff) const {
91  s_.get_diff(diff);
92 }
93 bool softmax::put_diff(const diff_t& diff) {
94  return s_.put_diff(diff);
95 }
96 void softmax::mix(const diff_t& lhs, diff_t& rhs) const {
97  s_.mix(lhs, rhs);
98 }
99 
101  return storage::version();
102 }
103 
104 } // namespace bandit
105 } // namespace core
106 } // namespace jubatus
bool put_diff(const diff_t &diff)
Definition: softmax.cpp:93
jubatus::util::math::random::mtrand rand_
Definition: softmax.hpp:62
bool reset(const std::string &player_id)
Definition: softmax.cpp:76
void get_diff(diff_t &diff) const
Definition: softmax.cpp:90
int select_by_weights(const std::vector< double > &weights, mtrand &rand)
bool reset(const std::string &player_id)
bool register_arm(const std::string &arm_id)
Definition: softmax.cpp:59
bool register_reward(const std::string &player_id, const std::string &arm_id, double reward)
void pack(framework::packer &pk) const
Definition: softmax.cpp:83
#define JUBATUS_EXCEPTION(e)
Definition: exception.hpp:79
jubatus::util::data::unordered_map< std::string, arm_info_map > diff_t
Definition: bandit_base.hpp:65
double get_expectation(const std::string &player_id, const std::string &arm_id) const
storage::version get_version() const
Definition: softmax.cpp:100
const std::vector< std::string > & get_arm_ids() const
void notify_selected(const std::string &player_id, const std::string &arm_id)
softmax(bool assume_unrewarded, double tau)
Definition: softmax.cpp:32
bool register_arm(const std::string &arm_id)
jubatus::util::data::unordered_map< std::string, arm_info > arm_info_map
Definition: arm_info.hpp:36
void unpack(msgpack::object o)
Definition: softmax.cpp:86
std::string select_arm(const std::string &player_id)
Definition: softmax.cpp:40
bool delete_arm(const std::string &arm_id)
Definition: softmax.cpp:62
void mix(const diff_t &lhs, diff_t &rhs) const
Definition: softmax.cpp:96
static void mix(const table_t &lhs, table_t &rhs)
arm_info_map get_arm_info_map(const std::string &player_id) const
summation_storage s_
Definition: softmax.hpp:63
bool delete_arm(const std::string &arm_id)
bool register_reward(const std::string &player_id, const std::string &arm_id, double reward)
Definition: softmax.cpp:66
arm_info_map get_arm_info(const std::string &player_id) const
Definition: softmax.cpp:72