jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
kmeans_clustering_method.cpp
Go to the documentation of this file.
1 // Jubatus: Online machine learning framework for distributed environment
2 // Copyright (C) 2013 Preferred Networks and Nippon Telegraph and Telephone Corporation.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License version 2.1 as published by the Free Software Foundation.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 
18 
19 #include <iostream>
20 #include <utility>
21 #include <vector>
22 #include "../common/exception.hpp"
23 #include "util.hpp"
24 
25 using std::pair;
26 using std::vector;
27 
28 namespace jubatus {
29 namespace core {
30 namespace clustering {
31 
33  : k_(k) {
34 }
35 
37 }
38 
40  if (points.empty()) {
41  kcenters_.clear();
42  return;
43  }
44  initialize_centers(points);
45  do_batch_update(points);
46 }
47 
49  if (points.size() < k_) {
50  return;
51  }
52  kcenters_.clear();
53  kcenters_.push_back(points[0].data);
54  vector<double> weights;
55  while (kcenters_.size() < k_) {
56  weights.clear();
57  for (wplist::iterator it = points.begin(); it != points.end(); ++it) {
58  pair<int64_t, double> m = min_dist((*it).data, kcenters_);
59  weights.push_back(m.second * it->weight);
60  }
61  discrete_distribution d(weights.begin(), weights.end());
62  kcenters_.push_back(points[d()].data);
63  }
64 }
65 
67  static jubatus::util::math::random::mtrand r;
68  bool terminated = false;
69  if (points.size() < k_) {
70  return;
71  }
72  while (!terminated) {
73  vector<common::sfv_t> kcenters_new(k_);
74  vector<double> center_count(k_, 0);
75  for (wplist::iterator it = points.begin(); it != points.end(); ++it) {
76  pair<int64_t, double> m = min_dist((*it).data, kcenters_);
77  scalar_mul_and_add(it->data, it->weight, kcenters_new[m.first]);
78  center_count[m.first] += it->weight;
79  }
80  terminated = true;
81  for (size_t i = 0; i < k_; ++i) {
82  if (center_count[i] == 0) {
83  kcenters_new[i] = kcenters_[i];
84  continue;
85  }
86  kcenters_new[i] = scalar_dot(kcenters_new[i], 1.0 / center_count[i]);
87  double d = dist(kcenters_new[i], kcenters_[i]);
88  if (d > 1e-9) {
89  terminated = false;
90  }
91  }
92  kcenters_ = kcenters_new;
93  }
94 }
95 
97 }
98 
99 vector<common::sfv_t> kmeans_clustering_method::get_k_center() const {
100  return kcenters_;
101 }
102 
104  const common::sfv_t& point) const {
105  return min_dist(point, kcenters_).first;
106 }
107 
109  const common::sfv_t& point) const {
110  return kcenters_[get_nearest_center_index(point)];
111 }
112 
114  size_t cluster_id,
115  const wplist& points) const {
116  if (cluster_id >= k_) {
117  return wplist();
118  }
119  return get_clusters(points)[cluster_id];
120 }
121 
123  const wplist& points) const {
124  vector<wplist> ret(k_);
125  for (wplist::const_iterator it = points.begin(); it != points.end(); ++it) {
126  pair<int64_t, double> m = min_dist(it->data, kcenters_);
127  ret[m.first].push_back(*it);
128  }
129  return ret;
130 }
131 
132 } // namespace clustering
133 } // namespace core
134 } // namespace jubatus
std::vector< wplist > get_clusters(const wplist &points) const
void scalar_mul_and_add(const common::sfv_t &left, float s, common::sfv_t &right)
Definition: util.cpp:62
common::sfv_t scalar_dot(const common::sfv_t &p, double s)
Definition: util.cpp:143
double dist(const common::sfv_t &p1, const common::sfv_t &p2)
Definition: util.cpp:151
common::sfv_t get_nearest_center(const common::sfv_t &point) const
wplist get_cluster(size_t cluster_id, const wplist &points) const
pair< size_t, double > min_dist(const common::sfv_t &p, const vector< common::sfv_t > &P)
Definition: util.cpp:182
int64_t get_nearest_center_index(const common::sfv_t &point) const
std::vector< std::pair< std::string, float > > sfv_t
Definition: type.hpp:29
std::vector< weighted_point > wplist
Definition: types.hpp:55