jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
Classes | Public Member Functions | Private Member Functions | Private Attributes | List of all members
jubatus::core::classifier::nearest_neighbor_classifier Class Reference

#include <nearest_neighbor_classifier.hpp>

Inheritance diagram for jubatus::core::classifier::nearest_neighbor_classifier:
Inheritance graph
Collaboration diagram for jubatus::core::classifier::nearest_neighbor_classifier:
Collaboration graph

Classes

class  unlearning_callback
 

Public Member Functions

std::string classify (const common::sfv_t &fv) const
 
void classify_with_scores (const common::sfv_t &fv, classify_result &scores) const
 
void clear ()
 
bool delete_label (const std::string &label)
 
std::vector< std::string > get_labels () const
 
framework::mixable * get_mixable ()
 
void get_status (std::map< std::string, std::string > &status) const
 
std::string name () const
 
 nearest_neighbor_classifier (jubatus::util::lang::shared_ptr< nearest_neighbor::nearest_neighbor_base > nearest_neighbor_engine, size_t k, float alpha)
 
void pack (framework::packer &pk) const
 
bool set_label (const std::string &label)
 
void set_label_unlearner (jubatus::util::lang::shared_ptr< unlearner::unlearner_base > label_unlearner)
 
void train (const common::sfv_t &fv, const std::string &label)
 
void unpack (msgpack::object o)
 
- Public Member Functions inherited from jubatus::core::classifier::classifier_base
 classifier_base ()
 
virtual ~classifier_base ()
 

Private Member Functions

void unlearn_id (const std::string &id)
 

Private Attributes

float alpha_
 
size_t k_
 
jubatus::util::data::unordered_set< std::string > labels_
 
jubatus::util::lang::shared_ptr< nearest_neighbor::nearest_neighbor_base > nearest_neighbor_engine_
 
jubatus::util::math::random::mtrand rand_
 
jubatus::util::concurrent::mutex rand_mutex_
 
jubatus::util::lang::shared_ptr< unlearner::unlearner_base > unlearner_
 
jubatus::util::concurrent::mutex unlearner_mutex_
 

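Detailed Description

A classifier that scores labels using a nearest-neighbor engine. For an input feature vector it asks the engine for the k_ nearest stored rows. It then gives each known label a score equal to the sum of exp(-alpha_ * s) over the retrieved rows carrying that label, where s is the value the engine returns with each neighbor (presumably a distance). Labels with no retrieved neighbor score 0. classify() returns the label with the highest score.

Definition at line 40 of file nearest_neighbor_classifier.hpp.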

Constructor & Destructor Documentation

jubatus::core::classifier::nearest_neighbor_classifier::nearest_neighbor_classifier ( jubatus::util::lang::shared_ptr< nearest_neighbor::nearest_neighbor_base > nearest_neighbor_engine,
size_t  k,
float  alpha 
)

Definition at line 75 of file nearest_neighbor_classifier.cpp.

References JUBATUS_EXCEPTION.

79  : nearest_neighbor_engine_(engine), k_(k), alpha_(alpha) {
80  if (!(alpha >= 0)) {
81  throw JUBATUS_EXCEPTION(common::invalid_parameter(
82  "local_sensitivity should >= 0"));
83  }
84 }
#define JUBATUS_EXCEPTION(e)
Definition: exception.hpp:79
jubatus::util::lang::shared_ptr< nearest_neighbor::nearest_neighbor_base > nearest_neighbor_engine_

Member Function Documentation

std::string jubatus::core::classifier::nearest_neighbor_classifier::classify ( const common::sfv_t & fv) const

Definition at line 110 of file nearest_neighbor_classifier.cpp.

References classify_with_scores().

111  {
112  classify_result result;
113  classify_with_scores(fv, result);
114  float max_score = -FLT_MAX;
115  std::string max_class;
116  for (std::vector<classify_result_elem>::const_iterator it = result.begin();
117  it != result.end(); ++it) {
118  if (it == result.begin() || it->score > max_score) {
119  max_score = it->score;
120  max_class = it->label;
121  }
122  }
123  return max_class;
124 }
std::vector< classify_result_elem > classify_result
void classify_with_scores(const common::sfv_t &fv, classify_result &scores) const

Here is the call graph for this function:

void jubatus::core::classifier::nearest_neighbor_classifier::classify_with_scores ( const common::sfv_t & fv,
classify_result & scores 
) const
virtual

Implements jubatus::core::classifier::classifier_base.

Definition at line 126 of file nearest_neighbor_classifier.cpp.

References alpha_, k_, labels_, and nearest_neighbor_engine_.

Referenced by classify().

127  {
128  std::vector<std::pair<std::string, float> > ids;
129  nearest_neighbor_engine_->neighbor_row(fv, ids, k_);
130 
131  std::map<std::string, float> m;
132  for (unordered_set<std::string>::const_iterator iter = labels_.begin();
133  iter != labels_.end(); ++iter) {
134  m.insert(std::make_pair(*iter, 0));
135  }
136  for (size_t i = 0; i < ids.size(); ++i) {
137  std::string label = get_label_from_id(ids[i].first);
138  m[label] += std::exp(-alpha_ * ids[i].second);
139  }
140 
141  scores.clear();
142  for (std::map<std::string, float>::const_iterator iter = m.begin();
143  iter != m.end(); ++iter) {
144  classify_result_elem elem(iter->first, iter->second);
145  scores.push_back(elem);
146  }
147 }
jubatus::util::data::unordered_set< std::string > labels_
jubatus::util::lang::shared_ptr< nearest_neighbor::nearest_neighbor_base > nearest_neighbor_engine_

Here is the caller graph for this function:

void jubatus::core::classifier::nearest_neighbor_classifier::clear ( )
virtual

Implements jubatus::core::classifier::classifier_base.

Definition at line 177 of file nearest_neighbor_classifier.cpp.

References labels_, nearest_neighbor_engine_, and unlearner_.

177  {
178  nearest_neighbor_engine_->clear();
179  labels_.clear();
180  if (unlearner_) {
181  unlearner_->clear();
182  }
183 }
jubatus::util::lang::shared_ptr< unlearner::unlearner_base > unlearner_
jubatus::util::data::unordered_set< std::string > labels_
jubatus::util::lang::shared_ptr< nearest_neighbor::nearest_neighbor_base > nearest_neighbor_engine_
bool jubatus::core::classifier::nearest_neighbor_classifier::delete_label ( const std::string &  label)
virtual

Implements jubatus::core::classifier::classifier_base.

Definition at line 149 of file nearest_neighbor_classifier.cpp.

References labels_, nearest_neighbor_engine_, and unlearner_.

149  {
150  if (labels_.erase(label) == 0) {
151  return false;
152  }
153 
154  shared_ptr<storage::column_table> table =
155  nearest_neighbor_engine_->get_table();
156 
157  std::vector<std::string> ids_to_be_deleted;
158  for (size_t i = 0, n = table->size(); i < n; ++i) {
159  std::string id = table->get_key(i);
160  std::string l = get_label_from_id(id);
161  if (l == label) {
162  ids_to_be_deleted.push_back(id);
163  }
164  }
165 
166  for (size_t i = 0, n = ids_to_be_deleted.size(); i < n; ++i) {
167  const std::string& id = ids_to_be_deleted[i];
168  table->delete_row(id);
169  if (unlearner_) {
170  unlearner_->remove(id);
171  }
172  }
173 
174  return true;
175 }
jubatus::util::lang::shared_ptr< unlearner::unlearner_base > unlearner_
jubatus::util::data::unordered_set< std::string > labels_
jubatus::util::lang::shared_ptr< nearest_neighbor::nearest_neighbor_base > nearest_neighbor_engine_
std::vector< std::string > jubatus::core::classifier::nearest_neighbor_classifier::get_labels ( ) const
virtual

Implements jubatus::core::classifier::classifier_base.

Definition at line 185 of file nearest_neighbor_classifier.cpp.

References labels_.

185  {
186  std::vector<std::string> result;
187  for (unordered_set<std::string>::const_iterator iter = labels_.begin();
188  iter != labels_.end(); ++iter) {
189  result.push_back(*iter);
190  }
191  return result;
192 }
jubatus::util::data::unordered_set< std::string > labels_
framework::mixable * jubatus::core::classifier::nearest_neighbor_classifier::get_mixable ( )
virtual

Implements jubatus::core::classifier::classifier_base.

Definition at line 235 of file nearest_neighbor_classifier.cpp.

References nearest_neighbor_engine_.

235  {
236  return nearest_neighbor_engine_->get_mixable();
237 }
jubatus::util::lang::shared_ptr< nearest_neighbor::nearest_neighbor_base > nearest_neighbor_engine_
void jubatus::core::classifier::nearest_neighbor_classifier::get_status ( std::map< std::string, std::string > &  status) const
virtual

Implements jubatus::core::classifier::classifier_base.

Definition at line 202 of file nearest_neighbor_classifier.cpp.

203  {
204  // unimplemented
205 }
std::string jubatus::core::classifier::nearest_neighbor_classifier::name ( ) const
virtual

Implements jubatus::core::classifier::classifier_base.

Definition at line 198 of file nearest_neighbor_classifier.cpp.

References nearest_neighbor_engine_.

198  {
199  return "nearest_neighbor_classifier:" + nearest_neighbor_engine_->type();
200 }
jubatus::util::lang::shared_ptr< nearest_neighbor::nearest_neighbor_base > nearest_neighbor_engine_
void jubatus::core::classifier::nearest_neighbor_classifier::pack ( framework::packer & pk) const
virtual

Implements jubatus::core::classifier::classifier_base.

Definition at line 207 of file nearest_neighbor_classifier.cpp.

References labels_, and nearest_neighbor_engine_.

207  {
208  pk.pack_array(2);
209  nearest_neighbor_engine_->pack(pk);
210 
211  pk.pack_array(labels_.size());
212  for (unordered_set<std::string>::const_iterator iter = labels_.begin();
213  iter != labels_.end(); ++iter) {
214  pk.pack(*iter);
215  }
216 }
jubatus::util::data::unordered_set< std::string > labels_
jubatus::util::lang::shared_ptr< nearest_neighbor::nearest_neighbor_base > nearest_neighbor_engine_
bool jubatus::core::classifier::nearest_neighbor_classifier::set_label ( const std::string &  label)
virtual

Implements jubatus::core::classifier::classifier_base.

Definition at line 194 of file nearest_neighbor_classifier.cpp.

References labels_.

Referenced by train().

194  {
195  return labels_.insert(label).second;
196 }
jubatus::util::data::unordered_set< std::string > labels_

Here is the caller graph for this function:

void jubatus::core::classifier::nearest_neighbor_classifier::set_label_unlearner ( jubatus::util::lang::shared_ptr< unlearner::unlearner_base > label_unlearner)
virtual

Implements jubatus::core::classifier::classifier_base.

Definition at line 104 of file nearest_neighbor_classifier.cpp.

References unlearner_.

105  {
106  label_unlearner->set_callback(unlearning_callback(this));
107  unlearner_ = label_unlearner;
108 }
jubatus::util::lang::shared_ptr< unlearner::unlearner_base > unlearner_
jubatus::util::lang::function< void(std::string)> unlearning_callback
void jubatus::core::classifier::nearest_neighbor_classifier::train ( const common::sfv_t & fv,
const std::string &  label 
)
virtual

Implements jubatus::core::classifier::classifier_base.

Definition at line 86 of file nearest_neighbor_classifier.cpp.

References JUBATUS_EXCEPTION, nearest_neighbor_engine_, rand_, rand_mutex_, set_label(), unlearner_, and unlearner_mutex_.

87  {
88  std::string id;
89  {
90  util::concurrent::scoped_lock lk(rand_mutex_);
91  id = make_id_from_label(label, rand_);
92  }
93  if (unlearner_) {
94  util::concurrent::scoped_lock unlearner_lk(unlearner_mutex_);
95  if (!unlearner_->touch(id)) {
96  throw JUBATUS_EXCEPTION(common::exception::runtime_error(
97  "no more space available to add new ID: " + id));
98  }
99  }
100  nearest_neighbor_engine_->set_row(id, fv);
101  set_label(label);
102 }
jubatus::util::lang::shared_ptr< unlearner::unlearner_base > unlearner_
#define JUBATUS_EXCEPTION(e)
Definition: exception.hpp:79
jubatus::util::lang::shared_ptr< nearest_neighbor::nearest_neighbor_base > nearest_neighbor_engine_

Here is the call graph for this function:

void jubatus::core::classifier::nearest_neighbor_classifier::unlearn_id ( const std::string &  id)
private

Definition at line 239 of file nearest_neighbor_classifier.cpp.

References nearest_neighbor_engine_.

Referenced by jubatus::core::classifier::nearest_neighbor_classifier::unlearning_callback::operator()().

239  {
240  nearest_neighbor_engine_->get_table()->delete_row(id);
241 }
jubatus::util::lang::shared_ptr< nearest_neighbor::nearest_neighbor_base > nearest_neighbor_engine_

Here is the caller graph for this function:

void jubatus::core::classifier::nearest_neighbor_classifier::unpack ( msgpack::object  o)
virtual

Implements jubatus::core::classifier::classifier_base.

Definition at line 218 of file nearest_neighbor_classifier.cpp.

References labels_, and nearest_neighbor_engine_.

218  {
219  if (o.type != msgpack::type::ARRAY || o.via.array.size != 2) {
220  throw msgpack::type_error();
221  }
222  nearest_neighbor_engine_->unpack(o.via.array.ptr[0]);
223 
224  msgpack::object labels = o.via.array.ptr[1];
225  if (labels.type != msgpack::type::ARRAY) {
226  throw msgpack::type_error();
227  }
228  for (size_t i = 0; i < labels.via.array.size; ++i) {
229  std::string label;
230  labels.via.array.ptr[i].convert(&label);
231  labels_.insert(label);
232  }
233 }
jubatus::util::data::unordered_set< std::string > labels_
jubatus::util::lang::shared_ptr< nearest_neighbor::nearest_neighbor_base > nearest_neighbor_engine_

Member Data Documentation

float jubatus::core::classifier::nearest_neighbor_classifier::alpha_
private

Definition at line 76 of file nearest_neighbor_classifier.hpp.

Referenced by classify_with_scores().

size_t jubatus::core::classifier::nearest_neighbor_classifier::k_
private

Definition at line 75 of file nearest_neighbor_classifier.hpp.

Referenced by classify_with_scores().

jubatus::util::data::unordered_set<std::string> jubatus::core::classifier::nearest_neighbor_classifier::labels_
private
jubatus::util::lang::shared_ptr<nearest_neighbor::nearest_neighbor_base> jubatus::core::classifier::nearest_neighbor_classifier::nearest_neighbor_engine_
private
jubatus::util::math::random::mtrand jubatus::core::classifier::nearest_neighbor_classifier::rand_
private

Definition at line 80 of file nearest_neighbor_classifier.hpp.

Referenced by train().

jubatus::util::concurrent::mutex jubatus::core::classifier::nearest_neighbor_classifier::rand_mutex_
private

Definition at line 79 of file nearest_neighbor_classifier.hpp.

Referenced by train().

jubatus::util::lang::shared_ptr<unlearner::unlearner_base> jubatus::core::classifier::nearest_neighbor_classifier::unlearner_
private

Definition at line 78 of file nearest_neighbor_classifier.hpp.

Referenced by clear(), delete_label(), set_label_unlearner(), and train().

jubatus::util::concurrent::mutex jubatus::core::classifier::nearest_neighbor_classifier::unlearner_mutex_
private

Definition at line 77 of file nearest_neighbor_classifier.hpp.

Referenced by train().


The documentation for this class was generated from the following files: