jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
nearest_neighbor_classifier.cpp
Go to the documentation of this file.
1 // Jubatus: Online machine learning framework for distributed environment
2 // Copyright (C) 2014 Preferred Networks and Nippon Telegraph and Telephone Corporation.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License version 2.1 as published by the Free Software Foundation.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 
18 
19 #include <cfloat>
20 #include <string>
21 #include <vector>
22 #include <map>
23 #include <utility>
24 #include "../storage/column_table.hpp"
25 #include "jubatus/util/concurrent/lock.h"
26 
27 using jubatus::util::lang::shared_ptr;
28 using jubatus::util::data::unordered_set;
29 using jubatus::util::concurrent::scoped_lock;
30 
31 namespace jubatus {
32 namespace core {
33 namespace classifier {
34 
35 namespace {
36 std::string make_id_from_label(const std::string& label,
37  jubatus::util::math::random::mtrand& rand) {
38  const size_t n = 8;
39  std::string result = label;
40  result.reserve(label.size() + 1 + n);
41  result.push_back('_');
42  for (size_t i = 0; i < n; ++i) {
43  int r = rand.next_int(26 * 2 + 10);
44  if (r < 26) {
45  result.push_back('a' + r);
46  } else if (r < 26 * 2) {
47  result.push_back('A' + (r - 26));
48  } else {
49  result.push_back('0' + (r - 26 * 2));
50  }
51  }
52  return result;
53 }
54 
// Recovers the label part of a row ID by dropping the last '_' and everything
// after it. An ID that contains no '_' is returned unchanged.
std::string get_label_from_id(const std::string& id) {
  const std::string::size_type separator = id.rfind('_');
  // When no '_' exists, separator is npos and the whole ID is copied.
  return std::string(id, 0, separator);
}
59 } // namespace
60 
// NOTE(review): several lines of this definition are absent from this
// listing (its opening line, the constructor signature, the body of
// operator() and whatever is declared under "private:"). What remains shows
// a function object whose constructor stores its argument in classifier_ and
// whose call operator receives a row ID. Presumably the call operator
// forwards that ID to the classifier's unlearn_id() -- TODO confirm against
// the real source.
62  public:
64  : classifier_(classifier) {
65  }
66 
67  void operator()(const std::string& id) {
69  }
70 
71  private:
73 };
74 
// Constructor (the line carrying its name is absent from this listing).
// Stores the nearest-neighbor engine, the neighbor count k and the
// coefficient alpha in nearest_neighbor_engine_, k_ and alpha_.
76  shared_ptr<nearest_neighbor::nearest_neighbor_base> engine,
77  size_t k,
78  float alpha)
79  : nearest_neighbor_engine_(engine), k_(k), alpha_(alpha) {
// Written as !(alpha >= 0) rather than (alpha < 0), so a NaN alpha also
// enters this branch. The statement that raises the error is absent from this
// listing; only its message is visible (presumably an exception is thrown).
80  if (!(alpha >= 0)) {
82  "local_sensitivity should >= 0"));
83  }
84 }
85 
// Body of train(fv, label); the line carrying the function name is absent
// from this listing. Stores fv as a new row under a freshly generated ID
// derived from label, then registers the label.
87  const common::sfv_t& fv, const std::string& label) {
88  std::string id;
89  {
// The random generator rand_ is only used while holding rand_mutex_.
90  util::concurrent::scoped_lock lk(rand_mutex_);
91  id = make_id_from_label(label, rand_);
92  }
93  if (unlearner_) {
// unlearner_ is only touched while holding unlearner_mutex_.
94  util::concurrent::scoped_lock unlearner_lk(unlearner_mutex_);
// A false return from touch() is treated as an error. The statement that
// raises it is absent from this listing; only its message is visible.
95  if (!unlearner_->touch(id)) {
97  "no more space available to add new ID: " + id));
98  }
99  }
100  nearest_neighbor_engine_->set_row(id, fv);
101  set_label(label);
102 }
103 
// Body of set_label_unlearner(label_unlearner); the line carrying the function
// name is absent from this listing. Installs a callback constructed from
// `this` on the given unlearner, then keeps the unlearner in unlearner_.
// NOTE(review): label_unlearner is dereferenced without a null check.
105  shared_ptr<unlearner::unlearner_base> label_unlearner) {
106  label_unlearner->set_callback(unlearning_callback(this));
107  unlearner_ = label_unlearner;
108 }
109 
// Returns the label with the highest score among the results produced by
// classify_with_scores(fv). Returns an empty string when there are no
// results. (The line carrying this function's name is absent from this
// listing.)
111  const common::sfv_t& fv) const {
112  classify_result result;
113  classify_with_scores(fv, result);
114  float max_score = -FLT_MAX;
115  std::string max_class;
116  for (std::vector<classify_result_elem>::const_iterator it = result.begin();
117  it != result.end(); ++it) {
// The first element is always accepted, whatever its score; later elements
// replace it only on a strictly greater score, so ties keep the earlier one.
118  if (it == result.begin() || it->score > max_score) {
119  max_score = it->score;
120  max_class = it->label;
121  }
122  }
123  return max_class;
124 }
125 
// Body of classify_with_scores(fv, scores); the line carrying the function
// name is absent from this listing. Replaces the contents of `scores` with one
// entry per label, where a label's score is the sum of exp(-alpha_ * value)
// over the neighbor rows returned for fv (the value is presumably a distance
// -- TODO confirm).
// NOTE(review): std::exp is used below but <cmath> is not among the includes
// visible in this listing; presumably it arrives through another header.
127  const common::sfv_t& fv, classify_result& scores) const {
128  std::vector<std::pair<std::string, float> > ids;
129  nearest_neighbor_engine_->neighbor_row(fv, ids, k_);
130 
// Seed every registered label with a score of 0 so that labels without any
// returned neighbor still appear in the output.
131  std::map<std::string, float> m;
132  for (unordered_set<std::string>::const_iterator iter = labels_.begin();
133  iter != labels_.end(); ++iter) {
134  m.insert(std::make_pair(*iter, 0));
135  }
// operator[] creates an entry for a label that is not in labels_, so such a
// label is reported as well.
136  for (size_t i = 0; i < ids.size(); ++i) {
137  std::string label = get_label_from_id(ids[i].first);
138  m[label] += std::exp(-alpha_ * ids[i].second);
139  }
140 
// Entries are emitted in std::map key order, i.e. sorted by label.
141  scores.clear();
142  for (std::map<std::string, float>::const_iterator iter = m.begin();
143  iter != m.end(); ++iter) {
144  classify_result_elem elem(iter->first, iter->second);
145  scores.push_back(elem);
146  }
147 }
148 
149 bool nearest_neighbor_classifier::delete_label(const std::string& label) {
150  if (labels_.erase(label) == 0) {
151  return false;
152  }
153 
154  shared_ptr<storage::column_table> table =
155  nearest_neighbor_engine_->get_table();
156 
157  std::vector<std::string> ids_to_be_deleted;
158  for (size_t i = 0, n = table->size(); i < n; ++i) {
159  std::string id = table->get_key(i);
160  std::string l = get_label_from_id(id);
161  if (l == label) {
162  ids_to_be_deleted.push_back(id);
163  }
164  }
165 
166  for (size_t i = 0, n = ids_to_be_deleted.size(); i < n; ++i) {
167  const std::string& id = ids_to_be_deleted[i];
168  table->delete_row(id);
169  if (unlearner_) {
170  unlearner_->remove(id);
171  }
172  }
173 
174  return true;
175 }
176 
// Resets the classifier: clears the nearest-neighbor engine, forgets all
// registered labels and, when an unlearner is attached, clears it too.
// (The line carrying this function's signature is absent from this listing.)
178  nearest_neighbor_engine_->clear();
179  labels_.clear();
180  if (unlearner_) {
181  unlearner_->clear();
182  }
183 }
184 
185 std::vector<std::string> nearest_neighbor_classifier::get_labels() const {
186  std::vector<std::string> result;
187  for (unordered_set<std::string>::const_iterator iter = labels_.begin();
188  iter != labels_.end(); ++iter) {
189  result.push_back(*iter);
190  }
191  return result;
192 }
193 
194 bool nearest_neighbor_classifier::set_label(const std::string& label) {
195  return labels_.insert(label).second;
196 }
197 
// Returns the fixed prefix "nearest_neighbor_classifier:" followed by the
// engine's type() string. (The line carrying this function's signature is
// absent from this listing.)
199  return "nearest_neighbor_classifier:" + nearest_neighbor_engine_->type();
200 }
201 
// Body of get_status(status); the line carrying the function name is absent
// from this listing. Currently a no-op: `status` is left untouched.
203  std::map<std::string, std::string>& status) const {
204  // unimplemented
205 }
206 
// Serializes the classifier into `pk` as a two-element array: first the
// nearest-neighbor engine's own serialization, then an array holding every
// registered label. unpack() expects exactly this layout. (The line carrying
// this function's signature is absent from this listing.)
208  pk.pack_array(2);
209  nearest_neighbor_engine_->pack(pk);
210 
// Labels are written in the iteration order of labels_.
211  pk.pack_array(labels_.size());
212  for (unordered_set<std::string>::const_iterator iter = labels_.begin();
213  iter != labels_.end(); ++iter) {
214  pk.pack(*iter);
215  }
216 }
217 
218 void nearest_neighbor_classifier::unpack(msgpack::object o) {
219  if (o.type != msgpack::type::ARRAY || o.via.array.size != 2) {
220  throw msgpack::type_error();
221  }
222  nearest_neighbor_engine_->unpack(o.via.array.ptr[0]);
223 
224  msgpack::object labels = o.via.array.ptr[1];
225  if (labels.type != msgpack::type::ARRAY) {
226  throw msgpack::type_error();
227  }
228  for (size_t i = 0; i < labels.via.array.size; ++i) {
229  std::string label;
230  labels.via.array.ptr[i].convert(&label);
231  labels_.insert(label);
232  }
233 }
234 
// Delegates to the nearest-neighbor engine and returns its mixable object.
// (The line carrying this function's signature is absent from this listing.)
236  return nearest_neighbor_engine_->get_mixable();
237 }
238 
239 void nearest_neighbor_classifier::unlearn_id(const std::string& id) {
240  nearest_neighbor_engine_->get_table()->delete_row(id);
241 }
242 
243 } // namespace classifier
244 } // namespace core
245 } // namespace jubatus
std::vector< classify_result_elem > classify_result
jubatus::util::lang::shared_ptr< unlearner::unlearner_base > unlearner_
nearest_neighbor_classifier(jubatus::util::lang::shared_ptr< nearest_neighbor::nearest_neighbor_base > nearest_neighbor_engine, size_t k, float alpha)
#define JUBATUS_EXCEPTION(e)
Definition: exception.hpp:79
void get_status(std::map< std::string, std::string > &status) const
void train(const common::sfv_t &fv, const std::string &label)
void set_label_unlearner(jubatus::util::lang::shared_ptr< unlearner::unlearner_base > label_unlearner)
jubatus::util::data::unordered_set< std::string > labels_
void classify_with_scores(const common::sfv_t &fv, classify_result &scores) const
std::vector< std::pair< std::string, float > > sfv_t
Definition: type.hpp:29
jubatus::util::lang::function< void(std::string)> unlearning_callback
jubatus::util::lang::shared_ptr< nearest_neighbor::nearest_neighbor_base > nearest_neighbor_engine_