jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
lsh_index_storage.cpp
Go to the documentation of this file.
1 // Jubatus: Online machine learning framework for distributed environment
2 // Copyright (C) 2012 Preferred Networks and Nippon Telegraph and Telephone Corporation.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License version 2.1 as published by the Free Software Foundation.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 
17 #include "lsh_index_storage.hpp"
18 #include <cmath>
19 #include <algorithm>
20 #include <utility>
21 #include <sstream>
22 #include <string>
23 #include <vector>
24 #include "jubatus/util/data/unordered_map.h"
25 #include "jubatus/util/data/unordered_set.h"
26 #include "jubatus/util/lang/cast.h"
27 #include "jubatus/util/math/random.h"
28 #include "lsh_util.hpp"
29 
30 using std::copy;
31 using std::ostream;
32 using std::ostringstream;
33 using std::istream;
34 using std::istringstream;
35 using std::make_pair;
36 using std::pair;
37 using std::string;
38 using std::vector;
39 using std::sort;
40 using std::partial_sort;
41 using std::lower_bound;
42 using jubatus::util::data::unordered_map;
43 using jubatus::util::data::unordered_set;
44 using jubatus::util::math::random::mtrand;
45 
46 namespace jubatus {
47 namespace core {
48 namespace storage {
49 
50 namespace {
51 
// Comparator that orders pairs by descending `.second`, so that sorting a
// sequence of (id, score) pairs with it puts the highest score first.
struct greater_second {
  template <typename Pair>
  bool operator()(const Pair& lhs, const Pair& rhs) const {
    return rhs.second < lhs.second;
  }
};
58 
59 uint64_t hash_lv(const lsh_vector& lv) {
60  uint64_t hash = 14695981039346656037LLU;
61  for (size_t i = 0; i < lv.size(); ++i) {
62  for (int j = 0; j < 32; j += 8) {
63  hash *= 1099511628211LLU;
64  hash ^= (static_cast<uint32_t>(lv.get(i)) >> j) & 0xff;
65  }
66  }
67  return hash;
68 }
69 
70 void initialize_shift(uint32_t seed, vector<float>& shift) {
71  mtrand rnd(seed);
72  for (size_t i = 0; i < shift.size(); ++i) {
73  shift[i] = rnd.next_double();
74  }
75 }
76 
// Returns a copy of `hash` with `shift` added element-wise.
// `shift` must have at least hash.size() elements.
vector<float> shift_hash(
    const vector<float>& hash,
    const vector<float>& shift) {
  vector<float> result;
  result.reserve(hash.size());
  for (size_t k = 0; k < hash.size(); ++k) {
    result.push_back(hash[k] + shift[k]);
  }
  return result;
}
86 
87 bit_vector binarize(const vector<float>& hash) {
88  bit_vector bv;
89  bv.resize_and_clear(hash.size());
90  for (size_t i = 0; i < hash.size(); ++i) {
91  if (hash[i] > 0) {
92  bv.set_bit(i);
93  }
94  }
95  return bv;
96 }
97 
98 float calc_euclidean_distance(
99  const lsh_entry& entry,
100  const bit_vector& bv,
101  float norm) {
102  const uint64_t hamm = bv.calc_hamming_similarity(entry.simhash_bv);
103  if (hamm == bv.bit_num()) {
104  // Avoid NaN caused by arithmetic error
105  return std::fabs(norm - entry.norm);
106  }
107  const float angle = (1 - static_cast<float>(hamm) / bv.bit_num()) * M_PI;
108  const float dot = entry.norm * norm * std::cos(angle);
109  return std::sqrt(norm * norm + entry.norm * entry.norm - 2 * dot);
110 }
111 
112 void retrieve_hit_rows_from_table(
113  uint64_t hash,
114  const lsh_table_t& table,
115  unordered_set<uint64_t>& cands) {
116  lsh_table_t::const_iterator it = table.find(hash);
117  if (it != table.end()) {
118  const vector<uint64_t>& range = it->second;
119  for (size_t j = 0; j < range.size(); ++j) {
120  cands.insert(range[j]);
121  }
122  }
123 }
124 
125 } // namespace
126 
128 }
129 
131  size_t lsh_num,
132  size_t table_num,
133  uint32_t seed)
134  : shift_(lsh_num * table_num),
135  table_num_(table_num) {
136  initialize_shift(seed, shift_);
137 }
138 
140  size_t table_num,
141  const vector<float>& shift)
142  : shift_(shift),
143  table_num_(table_num) {
144 }
145 
147 }
148 
150  const string& row,
151  const vector<float>& hash,
152  float norm) {
153  lsh_master_table_t::iterator it = remove_and_get_row(row);
154  if (it == master_table_diff_.end()) {
155  it = master_table_diff_.insert(make_pair(row, lsh_entry())).first;
156  }
157  make_entry(hash, norm, it->second);
158 
159  const uint64_t id = key_manager_.get_id(row);
160  const vector<uint64_t>& lsh_hash = it->second.lsh_hash;
161  for (size_t i = 0; i < lsh_hash.size(); ++i) {
162  vector<uint64_t>& range = lsh_table_diff_[lsh_hash[i]];
163  vector<uint64_t>::iterator it = lower_bound(range.begin(), range.end(), id);
164  if (it == range.end() || id != *it) {
165  range.insert(it, id);
166  }
167  }
168 }
169 
170 void lsh_index_storage::remove_row(const string& row) {
171  const uint64_t row_id = key_manager_.get_id_const(row);
172  if (row_id == common::key_manager::NOTFOUND) {
173  // Non-existence row
174  return;
175  }
176 
177  lsh_master_table_t::iterator entry_it = master_table_.find(row);
178  if (entry_it == master_table_.end()) {
179  // Since the row is not yet mixed, it can be immediately erased.
180  master_table_diff_.erase(row);
181  return;
182  }
183 
184  // Otherwise, keep the row with empty entry until next MIX.
185  master_table_diff_.insert(make_pair(row, lsh_entry()));
186  lsh_entry& entry = entry_it->second;
187  put_empty_entry(row_id, entry);
188 
189  return;
190 }
191 
195  lsh_table_t().swap(lsh_table_);
198 }
199 
200 void lsh_index_storage::get_all_row_ids(vector<string>& ids) const {
201  const size_t size_upper_bound = master_table_.size()
202  + master_table_diff_.size();
203 
204  unordered_set<std::string> id_set;
205  // equivalent to id_set.reserve(size_upper_bound) in C++11
206  id_set.rehash(std::ceil(size_upper_bound / id_set.max_load_factor()));
207 
208  for (lsh_master_table_t::const_iterator it = master_table_.begin();
209  it != master_table_.end(); ++it) {
210  if (!it->second.lsh_hash.empty()) {
211  id_set.insert(it->first);
212  }
213  }
214  for (lsh_master_table_t::const_iterator it = master_table_diff_.begin();
215  it != master_table_diff_.end(); ++it) {
216  if (!it->second.lsh_hash.empty()) {
217  id_set.insert(it->first);
218  }
219  }
220 
221  vector<string> ret(id_set.size());
222  copy(id_set.begin(), id_set.end(), ret.begin());
223  ids.swap(ret);
224 }
225 
227  const vector<float>& hash,
228  float norm,
229  uint64_t probe_num,
230  uint64_t ret_num,
231  vector<pair<string, float> >& ids) const {
232  const vector<float> shifted = shift_hash(hash, shift_);
233  const bit_vector bv = binarize(hash);
234 
235  lsh_probe_generator gen(shifted, table_num_);
236  unordered_set<uint64_t> cands;
237 
238  for (uint64_t i = 0; i < table_num_; ++i) {
239  lsh_vector key = gen.base(i);
240  key.push_back(i);
241  if (retrieve_hit_rows(hash_lv(key), ret_num, cands)) {
242  get_sorted_similar_rows(cands, bv, norm, ret_num, ids);
243  return;
244  }
245  }
246 
247  gen.init();
248  for (uint64_t i = 0; i < probe_num; ++i) {
249  pair<size_t, lsh_vector> p = gen.get_next_table_and_vector();
250  p.second.push_back(p.first);
251  if (retrieve_hit_rows(hash_lv(p.second), ret_num, cands)) {
252  break;
253  }
254  }
255  get_sorted_similar_rows(cands, bv, norm, ret_num, ids);
256 }
257 
259  const string& id,
260  uint64_t ret_num,
261  vector<pair<string, float> >& ids) const {
262  lsh_master_table_t::const_iterator it = master_table_diff_.find(id);
263  if (it == master_table_diff_.end()) {
264  it = master_table_.find(id);
265  if (it == master_table_.end()) {
266  return;
267  }
268  }
269 
270  unordered_set<uint64_t> cands;
271  for (size_t i = 0; i < it->second.lsh_hash.size(); ++i) {
272  if (retrieve_hit_rows(it->second.lsh_hash[i], ret_num, cands)) {
273  break;
274  }
275  }
276 
278  it->second.simhash_bv,
279  it->second.norm,
280  ret_num, ids);
281 }
282 
283 string lsh_index_storage::name() const {
284  return "lsh_index_storage";
285 }
286 
288  packer.pack(*this);
289 }
290 
// Restores the state of this storage from a msgpack object (presumably one
// produced by pack(); verify against the serialization caller). Delegates
// entirely to msgpack's convert() on this object.
291 void lsh_index_storage::unpack(msgpack::object o) {
292  o.convert(this);
293 }
294 
296  diff = master_table_diff_;
297 }
298 
300  const lsh_master_table_t& diff) {
301  for (lsh_master_table_t::const_iterator it = diff.begin(); it != diff.end();
302  ++it) {
303  if (it->second.lsh_hash.empty()) {
304  remove_model_row(it->first);
305  master_table_.erase(it->first);
306  } else {
307  remove_model_row(it->first);
308  set_mixed_row(it->first, it->second);
309  }
310  }
311 
312  master_table_diff_.clear();
313 
314  // lsh_table_diff_ is actually not MIXed, but must be cleared as well as diff
315  // of usual model.
316  lsh_table_diff_.clear();
317  return true;
318 }
319 
321  const lsh_master_table_t& lhs,
322  lsh_master_table_t& rhs) const {
323  for (lsh_master_table_t::const_iterator it = lhs.begin(); it != lhs.end();
324  ++it) {
325  rhs[it->first] = it->second;
326  }
327 }
328 
329 // private
330 
331 lsh_master_table_t::iterator lsh_index_storage::remove_and_get_row(
332  const string& row) {
333  const uint64_t row_id = key_manager_.get_id_const(row);
334  if (row_id == common::key_manager::NOTFOUND) {
335  return master_table_diff_.end();
336  }
337 
338  lsh_master_table_t::iterator entry_it = master_table_diff_.find(row);
339  lsh_master_table_t::iterator ret_it = entry_it;
340  if (entry_it == master_table_diff_.end()) {
341  ret_it = master_table_diff_.insert(make_pair(row, lsh_entry())).first;
342  entry_it = master_table_.find(row);
343  if (entry_it == master_table_.end()) {
344  return ret_it;
345  }
346  }
347  lsh_entry& entry = entry_it->second;
348  put_empty_entry(row_id, entry);
349 
350  return ret_it;
351 }
352 
354  uint64_t row_id,
355  const lsh_entry& entry) {
356  for (size_t i = 0; i < entry.lsh_hash.size(); ++i) {
357  lsh_table_t::iterator it = lsh_table_diff_.find(entry.lsh_hash[i]);
358  if (it != lsh_table_diff_.end()) {
359  vector<uint64_t>& range = it->second;
360  vector<uint64_t>::iterator jt = lower_bound(range.begin(),
361  range.end(),
362  row_id);
363  if (jt != range.end() && row_id == *jt) {
364  range.erase(jt);
365  if (range.empty()) {
366  lsh_table_diff_.erase(it);
367  }
368  }
369  }
370  }
371 }
372 
374  const vector<float>& hash,
375  float norm,
376  lsh_entry& entry) const {
377  const vector<float> shifted = shift_hash(hash, shift_);
378  lsh_probe_generator gen(shifted, table_num_);
379 
380  entry.lsh_hash.resize(table_num_);
381  for (uint64_t i = 0; i < table_num_; ++i) {
382  lsh_vector key = gen.base(i);
383  key.push_back(i);
384  entry.lsh_hash[i] = hash_lv(key);
385  }
386 
387  entry.simhash_bv = binarize(hash);
388  entry.norm = norm;
389 
390  return shifted;
391 }
392 
393 // TODO(unknown): Separate implementation detail of processing
394 // lsh_table_ into another class
395 void lsh_index_storage::remove_model_row(const std::string& row) {
396  const lsh_entry* entry = get_lsh_entry(row);
397  if (!entry) {
398  return;
399  }
400 
401  const uint64_t row_id = key_manager_.get_id_const(row);
402  for (size_t i = 0; i < entry->lsh_hash.size(); ++i) {
403  lsh_table_t::iterator it = lsh_table_.find(entry->lsh_hash[i]);
404  if (it != lsh_table_.end()) {
405  vector<uint64_t>& range = it->second;
406  vector<uint64_t>::iterator jt = find(range.begin(), range.end(), row_id);
407  if (jt != range.end()) {
408  range.erase(jt);
409  if (range.empty()) {
410  lsh_table_.erase(it);
411  }
412  }
413  }
414  }
415 }
416 
418  const string& row,
419  const lsh_entry& entry) {
420  const uint64_t row_id = key_manager_.get_id(row);
421  master_table_[row] = entry;
422  for (size_t i = 0; i < entry.lsh_hash.size(); ++i) {
423  lsh_table_[entry.lsh_hash[i]].push_back(row_id);
424  }
425 }
426 
428  uint64_t hash,
429  size_t ret_num,
430  unordered_set<uint64_t>& cands) const {
431  retrieve_hit_rows_from_table(hash, lsh_table_diff_, cands);
432  retrieve_hit_rows_from_table(hash, lsh_table_, cands);
433  return cands.size() >= static_cast<uint64_t>(ret_num);
434 }
435 
437  const unordered_set<uint64_t>& cands,
438  const bit_vector& query_simhash,
439  float query_norm,
440  uint64_t ret_num,
441  vector<pair<string, float> >& ids) const {
442  // Avoid string copy as far as possible
443  vector<pair<uint64_t, float> > scored;
444  scored.reserve(cands.size());
445  for (unordered_set<uint64_t>::const_iterator it = cands.begin();
446  it != cands.end(); ++it) {
447  const lsh_entry* entry = get_lsh_entry(key_manager_.get_key(*it));
448  if (!entry || entry->lsh_hash.empty()) {
449  continue;
450  }
451  const float dist = calc_euclidean_distance(*entry, query_simhash,
452  query_norm);
453  scored.push_back(make_pair(*it, -dist));
454  }
455 
456  if (scored.size() <= ret_num) {
457  sort(scored.begin(), scored.end(), greater_second());
458  } else {
459  partial_sort(scored.begin(),
460  scored.begin() + ret_num, scored.end(),
461  greater_second());
462  scored.resize(ret_num);
463  }
464 
465  ids.resize(scored.size());
466  for (size_t i = 0; i < scored.size(); ++i) {
467  ids[i].first = key_manager_.get_key(scored[i].first);
468  ids[i].second = scored[i].second;
469  }
470 }
471 
472 const lsh_entry* lsh_index_storage::get_lsh_entry(const string& row) const {
473  lsh_master_table_t::const_iterator it = master_table_diff_.find(row);
474  if (it == master_table_diff_.end()) {
475  it = master_table_.find(row);
476  if (it == master_table_.end()) {
477  return 0;
478  }
479  }
480  return &it->second;
481 }
482 
483 } // namespace storage
484 } // namespace core
485 } // namespace jubatus
uint64_t get_id_const(const std::string &key) const
Definition: key_manager.cpp:67
void get_sorted_similar_rows(const jubatus::util::data::unordered_set< uint64_t > &cands, const bit_vector &query_simhash, float query_norm, uint64_t ret_num, std::vector< std::pair< std::string, float > > &ids) const
bool put_diff(const lsh_master_table_t &mixed_diff)
bit_vector simhash_bv
void get_diff(lsh_master_table_t &diff) const
double dist(const common::sfv_t &p1, const common::sfv_t &p2)
Definition: util.cpp:151
void resize_and_clear(uint64_t bit_num)
Definition: bit_vector.hpp:174
void get_all_row_ids(std::vector< std::string > &ids) const
void pack(framework::packer &packer) const
bool retrieve_hit_rows(uint64_t hash, size_t ret_num, jubatus::util::data::unordered_set< uint64_t > &cands) const
bit_vector binarize(const vector< float > &proj)
float norm
void mix(const lsh_master_table_t &lhs, lsh_master_table_t &rhs) const
jubatus::util::data::unordered_map< std::string, lsh_entry > lsh_master_table_t
Definition: euclid_lsh.hpp:39
std::vector< float > make_entry(const std::vector< float > &hash, float norm, lsh_entry &entry) const
const lsh_entry * get_lsh_entry(const std::string &row) const
const std::string & get_key(const uint64_t id) const
Definition: key_manager.cpp:78
void set_row(const std::string &row, const std::vector< float > &hash, float norm)
msgpack::packer< jubatus_packer > packer
Definition: bandit_base.hpp:31
bit_vector_base< uint64_t > bit_vector
const lsh_vector & base(size_t i) const
Definition: lsh_util.hpp:37
lsh_master_table_t::iterator remove_and_get_row(const std::string &row)
uint64_t get_id(const std::string &key)
Definition: key_manager.cpp:48
void set_mixed_row(const std::string &row, const lsh_entry &entry)
jubatus::util::data::unordered_map< uint64_t, std::vector< uint64_t > > lsh_table_t
void put_empty_entry(uint64_t row_id, const lsh_entry &entry)
std::vector< uint64_t > lsh_hash
void similar_row(const std::vector< float > &hash, float norm, uint64_t probe_num, uint64_t ret_num, std::vector< std::pair< std::string, float > > &ids) const
void remove_model_row(const std::string &row)