jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
euclid_lsh.cpp
Go to the documentation of this file.
1 // Jubatus: Online machine learning framework for distributed environment
2 // Copyright (C) 2012 Preferred Networks and Nippon Telegraph and Telephone Corporation.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License version 2.1 as published by the Free Software Foundation.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 
17 #include "euclid_lsh.hpp"
18 
19 #include <cmath>
20 #include <queue>
21 #include <utility>
22 #include <string>
23 #include <vector>
24 #include "jubatus/util/data/serialization.h"
25 #include "jubatus/util/lang/cast.h"
26 #include "jubatus/util/math/random.h"
27 #include "jubatus/util/concurrent/lock.h"
28 #include "jubatus/util/concurrent/mutex.h"
29 #include "../common/hash.hpp"
30 #include "../storage/lsh_util.hpp"
31 #include "../storage/lsh_vector.hpp"
32 #include "../storage/lsh_index_storage.hpp"
33 
34 using std::string;
35 using std::vector;
36 using std::pair;
37 using std::ostream;
38 using std::istream;
39 using jubatus::util::math::random::mtrand;
40 using jubatus::util::concurrent::scoped_lock;
41 
42 namespace jubatus {
43 namespace core {
44 namespace recommender {
45 
46 namespace {
47 
// Comparison functor for (id, score) pairs: orders by descending score.
// The id component (pair.first) does not take part in the comparison.
struct greater_second {
  bool operator()(
      const std::pair<std::string, float>& l,
      const std::pair<std::string, float>& r) const {
    // "l comes before r" exactly when l's score is strictly larger.
    return r.second < l.second;
  }
};
55 
56 float calc_norm(const common::sfv_t& sfv) {
57  float sqnorm = 0;
58  for (size_t i = 0; i < sfv.size(); ++i) {
59  sqnorm += sfv[i].second * sfv[i].second;
60  }
61  return std::sqrt(sqnorm);
62 }
63 
64 void calc_projection(uint32_t seed, size_t size, vector<float>& ret) {
65  mtrand rnd(seed);
66  ret.resize(size);
67  for (size_t i = 0; i < size; ++i) {
68  ret[i] = rnd.next_gaussian();
69  }
70 }
71 
72 } // namespace
73 
75  : hash_num(DEFAULT_HASH_NUM),
76  table_num(DEFAULT_TABLE_NUM),
77  bin_width(DEFAULT_BIN_WIDTH),
78  probe_num(DEFAULT_NUM_PROBE),
79  seed(DEFAULT_SEED),
80  retain_projection(DEFAULT_RETAIN_PROJECTION) {
81 }
82 
83 const uint64_t euclid_lsh::DEFAULT_HASH_NUM = 64; // should be in config
84 const uint64_t euclid_lsh::DEFAULT_TABLE_NUM = 4; // should be in config
85 const float euclid_lsh::DEFAULT_BIN_WIDTH = 100; // should be in config
86 const uint32_t euclid_lsh::DEFAULT_NUM_PROBE = 64; // should be in config
87 const uint32_t euclid_lsh::DEFAULT_SEED = 1091; // should be in config
88 const bool euclid_lsh::DEFAULT_RETAIN_PROJECTION = false;
89 
92 
95  model_ptr(new lsh_index_storage(DEFAULT_HASH_NUM,
97  DEFAULT_SEED)))),
101 }
102 
104  : mixable_storage_(),
105  bin_width_(config.bin_width),
106  num_probe_(config.probe_num),
107  retain_projection_(config.retain_projection) {
108 
109  if (!(1 <= config.hash_num)) {
110  throw JUBATUS_EXCEPTION(
111  common::invalid_parameter("1 <= hash_num"));
112  }
113 
114  if (!(1 <= config.table_num)) {
115  throw JUBATUS_EXCEPTION(
116  common::invalid_parameter("1 <= table_num"));
117  }
118 
119  if (!(0.f < config.bin_width)) {
120  throw JUBATUS_EXCEPTION(
121  common::invalid_parameter("0.0 < bin_width"));
122  }
123 
124  if (!(0 <= config.probe_num)) {
125  throw JUBATUS_EXCEPTION(
126  common::invalid_parameter("0 <= probe_num"));
127  }
128 
129  if (!(0 <= config.seed)) {
130  throw JUBATUS_EXCEPTION(
131  common::invalid_parameter("0 <= seed"));
132  }
133 
134  typedef storage::mixable_lsh_index_storage mli_storage;
136  typedef storage::lsh_index_storage li_storage;
137 
138  model_ptr p(new li_storage(config.hash_num, config.table_num, config.seed));
139  mixable_storage_.reset(new mli_storage(p));
140 }
141 
143 }
144 
146  const common::sfv_t& query,
147  vector<pair<string, float> >& ids,
148  size_t ret_num) const {
149  similar_row(query, ids, ret_num);
150  for (size_t i = 0; i < ids.size(); ++i) {
151  ids[i].second = -ids[i].second;
152  }
153 }
154 
156  const string& id,
157  vector<pair<string, float> >& ids,
158  size_t ret_num) const {
159  similar_row(id, ids, ret_num);
160  for (size_t i = 0; i < ids.size(); ++i) {
161  ids[i].second = -ids[i].second;
162  }
163 }
164 
166  const common::sfv_t& query,
167  vector<pair<string, float> >& ids,
168  size_t ret_num) const {
169  storage::lsh_index_storage& lsh_index = *mixable_storage_->get_model();
170  ids.clear();
171 
172  const vector<float> hash = calculate_lsh(query);
173  const float norm = calc_norm(query);
174  lsh_index.similar_row(hash, norm, num_probe_, ret_num, ids);
175 }
176 
178  const string& id,
179  vector<pair<string, float> >& ids,
180  size_t ret_num) const {
181  ids.clear();
182  mixable_storage_->get_model()->similar_row(id, ret_num, ids);
183 }
184 
186  orig_.clear();
187  mixable_storage_->get_model()->clear();
188 
189  // Clear projection cache
190  jubatus::util::data::unordered_map<uint32_t, std::vector<float> >()
192 }
193 
194 void euclid_lsh::clear_row(const string& id) {
195  orig_.remove_row(id);
196  mixable_storage_->get_model()->remove_row(id);
197 }
198 
199 void euclid_lsh::update_row(const string& id, const sfv_diff_t& diff) {
200  storage::lsh_index_storage& lsh_index = *mixable_storage_->get_model();
201  orig_.set_row(id, diff);
202  common::sfv_t row;
203  orig_.get_row(id, row);
204 
205  const vector<float> hash = calculate_lsh(row);
206  const float norm = calc_norm(row);
207  lsh_index.set_row(id, hash, norm);
208 }
209 
210 void euclid_lsh::get_all_row_ids(vector<string>& ids) const {
211  mixable_storage_->get_model()->get_all_row_ids(ids);
212 }
213 
214 string euclid_lsh::type() const {
215  return "euclid_lsh";
216 }
217 
219  return mixable_storage_.get();
220 }
221 
222 vector<float> euclid_lsh::calculate_lsh(const common::sfv_t& query) const {
223  vector<float> hash(mixable_storage_->get_model()->all_lsh_num());
224  for (size_t i = 0; i < query.size(); ++i) {
225  const uint32_t seed = common::hash_util::calc_string_hash(query[i].first);
226  const vector<float> proj = get_projection(seed);
227  for (size_t j = 0; j < hash.size(); ++j) {
228  hash[j] += query[i].second * proj[j];
229  }
230  }
231  for (size_t j = 0; j < hash.size(); ++j) {
232  hash[j] /= bin_width_;
233  }
234  return hash;
235 }
236 
237 vector<float> euclid_lsh::get_projection(uint32_t seed) const {
238  if (retain_projection_) {
239  scoped_lock lk(cache_lock_); // lock is needed only retain_projection
240  vector<float>& proj = projection_cache_[seed];
241  if (!proj.empty()) {
242  return proj;
243  }
244  calc_projection(seed, mixable_storage_->get_model()->all_lsh_num(), proj);
245  return proj;
246  } else {
247  vector<float> proj;
248  calc_projection(seed, mixable_storage_->get_model()->all_lsh_num(), proj);
249  return proj;
250  }
251 }
252 
255  model_ptr p(new storage::lsh_index_storage);
257 }
258 
260  packer.pack_array(2);
261  orig_.pack(packer);
262  mixable_storage_->get_model()->pack(packer);
263 }
264 
265 void euclid_lsh::unpack(msgpack::object o) {
266  if (o.type != msgpack::type::ARRAY || o.via.array.size != 2) {
267  throw msgpack::type_error();
268  }
269  orig_.unpack(o.via.array.ptr[0]);
270  mixable_storage_->get_model()->unpack(o.via.array.ptr[1]);
271 }
272 
273 } // namespace recommender
274 } // namespace core
275 } // namespace jubatus
storage::lsh_index_storage lsh_index_storage
Definition: euclid_lsh.cpp:91
framework::linear_mixable_helper< lsh_index_storage, lsh_master_table_t > mixable_lsh_index_storage
Definition: euclid_lsh.hpp:43
void get_row(const std::string &row, std::vector< std::pair< std::string, float > > &columns) const
static const uint64_t DEFAULT_TABLE_NUM
Definition: euclid_lsh.hpp:53
virtual void similar_row(const common::sfv_t &query, std::vector< std::pair< std::string, float > > &ids, size_t ret_num) const
jubatus::util::data::unordered_map< uint32_t, std::vector< float > > projection_cache_
Definition: euclid_lsh.hpp:127
static const uint64_t DEFAULT_HASH_NUM
Definition: euclid_lsh.hpp:52
virtual void neighbor_row(const common::sfv_t &query, std::vector< std::pair< std::string, float > > &ids, size_t ret_num) const
virtual void get_all_row_ids(std::vector< std::string > &ids) const
Definition: euclid_lsh.cpp:210
jubatus::util::lang::shared_ptr< Model > model_ptr
core::common::sfv_t sfv_diff_t
#define JUBATUS_EXCEPTION(e)
Definition: exception.hpp:79
virtual std::string type() const
Definition: euclid_lsh.cpp:214
void pack(framework::packer &packer) const
void pack(framework::packer &packer) const
Definition: euclid_lsh.cpp:259
static const uint32_t DEFAULT_SEED
Definition: euclid_lsh.hpp:56
void set_row(const std::string &row, const std::vector< float > &hash, float norm)
void swap(weighted_point &p1, weighted_point &p2)
Definition: types.hpp:47
static const bool DEFAULT_RETAIN_PROJECTION
Definition: euclid_lsh.hpp:57
msgpack::packer< jubatus_packer > packer
Definition: bandit_base.hpp:31
std::vector< float > get_projection(uint32_t seed) const
Definition: euclid_lsh.cpp:237
virtual void update_row(const std::string &id, const sfv_diff_t &diff)
Definition: euclid_lsh.cpp:199
static const uint64_t DEFAULT_HASH_NUM
Definition: lsh.cpp:38
core::storage::sparse_matrix_storage orig_
std::vector< std::pair< std::string, float > > sfv_t
Definition: type.hpp:29
static const uint32_t DEFAULT_NUM_PROBE
Definition: euclid_lsh.hpp:55
std::vector< float > calculate_lsh(const common::sfv_t &query) const
Definition: euclid_lsh.cpp:222
virtual void clear_row(const std::string &id)
Definition: euclid_lsh.cpp:194
jubatus::util::concurrent::mutex cache_lock_
Definition: euclid_lsh.hpp:128
void set_row(const std::string &row, const std::vector< std::pair< std::string, float > > &columns)
static uint64_t calc_string_hash(const std::string &s)
Definition: hash.hpp:29
framework::mixable * get_mixable() const
Definition: euclid_lsh.cpp:218
jubatus::util::lang::shared_ptr< storage::mixable_lsh_index_storage > mixable_storage_
Definition: euclid_lsh.hpp:122
storage::mixable_lsh_index_storage::model_ptr model_ptr
Definition: euclid_lsh.cpp:90