jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
Public Member Functions | Private Member Functions | Private Attributes | List of all members
jubatus::core::clustering::gmm Class Reference

#include <gmm.hpp>

Collaboration diagram for jubatus::core::clustering::gmm:
Collaboration graph

Public Member Functions

void batch (const eigen_wsvec_list_t &data, int d, int k)
 
eigen_svec_list_t get_centers ()
 
eigen_smat_list_t get_covs ()
 
eigen_svec_t get_nearest_center (const eigen_svec_t &p) const
 
int64_t get_nearest_center_index (const eigen_svec_t &p) const
 

Private Member Functions

eigen_svec_t cluster_probs (const eigen_svec_t &x, const eigen_svec_list_t &mean, const eigen_smat_list_t &cov, const eigen_solver_list_t &solvers) const
 
void initialize (const eigen_wsvec_list_t &data, int d, int k)
 
bool is_converged (int64_t niter, const eigen_svec_list_t &means, const eigen_svec_list_t &old_means, double obj, double old_obj)
 

Private Attributes

eigen_solver_list_t cov_solvers_
 
eigen_smat_list_t covs_
 
int d_
 
eigen_smat_t eye_
 
int k_
 
eigen_svec_list_t means_
 

Detailed Description

Definition at line 26 of file gmm.hpp.

Member Function Documentation

void jubatus::core::clustering::gmm::batch ( const eigen_wsvec_list_t &data,
int  d,
int  k 
)

Definition at line 40 of file gmm.cpp.

References cluster_probs(), cov_solvers_, covs_, eye_, initialize(), is_converged(), and means_.

Referenced by jubatus::core::clustering::gmm_clustering_method::batch_update().

40  {
41  if (data.empty()) {
42  *this = gmm();
43  return;
44  }
45 
46  typedef eigen_wsvec_list_t::const_iterator data_iter;
47  initialize(data, d, k);
48 
49  eigen_svec_list_t old_means;
50  eigen_smat_list_t old_covs;
51  eigen_solver_list_t old_solvers;
52  double old_obj = 0, obj = 0;
53  vector<double> weights(k);
54 
55  bool converged = false;
56  int64_t niter = 1;
57  while (!converged) {
58  old_covs = covs_;
59  old_means = means_;
60  old_solvers = cov_solvers_;
61  old_obj = obj;
62  obj = 0;
63  fill(weights.begin(), weights.end(), 0);
64  fill(means_.begin(), means_.end(), eigen_svec_t(d));
65  fill(covs_.begin(), covs_.end(), eigen_smat_t(d, d));
66 
67  for (data_iter i = data.begin(); i != data.end(); ++i) {
68  eigen_svec_t cps =
69  cluster_probs(i->data, old_means, old_covs, old_solvers);
70  for (int c = 0; c < k; ++c) {
71  double cp = i->weight * cps.coeff(c);
72  means_[c] += cp * i->data;
73  covs_[c] += i->data * (i->data.transpose()) * cp;
74  weights[c] += cp;
75  obj -= std::log(std::max(cp, std::numeric_limits<double>::min()));
76  }
77  }
78  for (int c = 0; c < k; ++c) {
79  means_[c] /= weights[c];
80  covs_[c] /= weights[c];
81  double eps = 0.1;
82  covs_[c] -= means_[c] * means_[c].transpose();
83  covs_[c] += eps * eye_;
84  cov_solvers_[c] =
85  shared_ptr<eigen_solver_t>(new eigen_solver_t(covs_[c]));
86  }
87  converged = is_converged(niter++, means_, old_means, obj, old_obj);
88  }
89 }
Eigen::SimplicialCholesky< eigen_smat_t > eigen_solver_t
Definition: gmm_types.hpp:31
std::vector< jubatus::util::lang::shared_ptr< eigen_solver_t > > eigen_solver_list_t
Definition: gmm_types.hpp:35
eigen_smat_list_t covs_
Definition: gmm.hpp:52
void initialize(const eigen_wsvec_list_t &data, int d, int k)
Definition: gmm.cpp:114
Eigen::SparseMatrix< double > eigen_smat_t
Definition: gmm_types.hpp:30
bool is_converged(int64_t niter, const eigen_svec_list_t &means, const eigen_svec_list_t &old_means, double obj, double old_obj)
Definition: gmm.cpp:136
eigen_svec_t cluster_probs(const eigen_svec_t &x, const eigen_svec_list_t &mean, const eigen_smat_list_t &cov, const eigen_solver_list_t &solvers) const
Definition: gmm.cpp:149
std::vector< eigen_svec_t > eigen_svec_list_t
Definition: gmm_types.hpp:32
eigen_solver_list_t cov_solvers_
Definition: gmm.hpp:54
Eigen::SparseVector< double > eigen_svec_t
Definition: gmm_types.hpp:29
eigen_svec_list_t means_
Definition: gmm.hpp:51
std::vector< eigen_smat_t > eigen_smat_list_t
Definition: gmm_types.hpp:33

Here is the call graph for this function:

Here is the caller graph for this function:

eigen_svec_t jubatus::core::clustering::gmm::cluster_probs ( const eigen_svec_t &x,
const eigen_svec_list_t &mean,
const eigen_smat_list_t &cov,
const eigen_solver_list_t &solvers
) const
private

Definition at line 149 of file gmm.cpp.

References k_, and jubatus::core::clustering::sum().

Referenced by batch(), and get_nearest_center_index().

153  {
154  double den = DBL_MIN;
155  eigen_svec_t ret(k_);
156  for (int i = 0; i < k_; ++i) {
157  eigen_svec_t dif = x - means[i];
158  double det = std::abs(cov_solvers[i]->determinant());
159  double quad = (dif.transpose() * cov_solvers[i]->solve(dif)).sum();
160  double lp = -1 / 2. * (std::log(det) + quad);
161  ret.coeffRef(i) = lp;
162  den = (den == DBL_MIN) ?
163  lp : std::max(den, lp) +
164  std::log(1 + std::exp(min(den, lp) - std::max(den, lp)));
165  }
166  for (int i = 0; i < k_; ++i) {
167  ret.coeffRef(i) = std::exp(ret.coeff(i) - den);
168  }
169  return ret;
170 }
double sum(const common::sfv_t &p)
Definition: util.cpp:47
Eigen::SparseVector< double > eigen_svec_t
Definition: gmm_types.hpp:29

Here is the call graph for this function:

Here is the caller graph for this function:

eigen_svec_list_t jubatus::core::clustering::gmm::get_centers ( )
inline

Definition at line 29 of file gmm.hpp.

References means_.

Referenced by jubatus::core::clustering::gmm_clustering_method::batch_update().

29  {
30  return means_;
31  }
eigen_svec_list_t means_
Definition: gmm.hpp:51

Here is the caller graph for this function:

eigen_smat_list_t jubatus::core::clustering::gmm::get_covs ( )
inline

Definition at line 32 of file gmm.hpp.

References covs_.

32  {
33  return covs_;
34  }
eigen_smat_list_t covs_
Definition: gmm.hpp:52
eigen_svec_t jubatus::core::clustering::gmm::get_nearest_center ( const eigen_svec_t &p ) const

Definition at line 91 of file gmm.cpp.

References get_nearest_center_index(), and means_.

Referenced by jubatus::core::clustering::gmm_clustering_method::get_nearest_center().

91  {
93 }
int64_t get_nearest_center_index(const eigen_svec_t &p) const
Definition: gmm.cpp:95
eigen_svec_list_t means_
Definition: gmm.hpp:51

Here is the call graph for this function:

Here is the caller graph for this function:

int64_t jubatus::core::clustering::gmm::get_nearest_center_index ( const eigen_svec_t &p ) const

Definition at line 95 of file gmm.cpp.

References cluster_probs(), cov_solvers_, covs_, JUBATUS_EXCEPTION, k_, and means_.

Referenced by get_nearest_center(), and jubatus::core::clustering::gmm_clustering_method::get_nearest_center_index().

95  {
96  if (means_.empty()) {
97  throw JUBATUS_EXCEPTION(common::exception::runtime_error(
98  "clustering is not performed yet"));
99  }
100 
101  double max_prob = 0;
102  int64_t max_idx = 0;
104  for (int c = 0; c < k_; ++c) {
105  double cp = cps.coeff(c);
106  if (cp > max_prob) {
107  max_prob = cp;
108  max_idx = c;
109  }
110  }
111  return max_idx;
112 }
eigen_smat_list_t covs_
Definition: gmm.hpp:52
#define JUBATUS_EXCEPTION(e)
Definition: exception.hpp:79
eigen_svec_t cluster_probs(const eigen_svec_t &x, const eigen_svec_list_t &mean, const eigen_smat_list_t &cov, const eigen_solver_list_t &solvers) const
Definition: gmm.cpp:149
eigen_solver_list_t cov_solvers_
Definition: gmm.hpp:54
Eigen::SparseVector< double > eigen_svec_t
Definition: gmm_types.hpp:29
eigen_svec_list_t means_
Definition: gmm.hpp:51

Here is the call graph for this function:

Here is the caller graph for this function:

void jubatus::core::clustering::gmm::initialize ( const eigen_wsvec_list_t &data,
int  d,
int  k 
)
private

Definition at line 114 of file gmm.cpp.

References cov_solvers_, covs_, d_, eye_, k_, and means_.

Referenced by batch().

114  {
115  d_ = d;
116  k_ = k;
120  eye_ = eigen_smat_t(d, d);
121 
122  for (int i = 0; i < d; ++i) {
123  eye_.insert(i, i) = 1;
124  }
125 
126  jubatus::util::math::random::mtrand r(time(NULL));
127  for (int c = 0; c < k; ++c) {
128  means_[c] = data[r.next_int(0, data.size()-1)].data;
129  for (int i = 0; i < d; ++i) {
130  covs_[c].insert(i, i) = 1;
131  }
132  cov_solvers_[c] = shared_ptr<eigen_solver_t>(new eigen_solver_t(covs_[c]));
133  }
134 }
Eigen::SimplicialCholesky< eigen_smat_t > eigen_solver_t
Definition: gmm_types.hpp:31
std::vector< jubatus::util::lang::shared_ptr< eigen_solver_t > > eigen_solver_list_t
Definition: gmm_types.hpp:35
eigen_smat_list_t covs_
Definition: gmm.hpp:52
Eigen::SparseMatrix< double > eigen_smat_t
Definition: gmm_types.hpp:30
std::vector< eigen_svec_t > eigen_svec_list_t
Definition: gmm_types.hpp:32
eigen_solver_list_t cov_solvers_
Definition: gmm.hpp:54
eigen_svec_list_t means_
Definition: gmm.hpp:51
std::vector< eigen_smat_t > eigen_smat_list_t
Definition: gmm_types.hpp:33

Here is the caller graph for this function:

bool jubatus::core::clustering::gmm::is_converged ( int64_t  niter,
const eigen_svec_list_t &means,
const eigen_svec_list_t &old_means,
double  obj,
double  old_obj 
)
private

Definition at line 136 of file gmm.cpp.

References k_.

Referenced by batch().

141  {
142  double max_dist = 0;
143  for (int c = 0; c < k_; ++c) {
144  max_dist = max(max_dist, (means[c] - old_means[c]).norm());
145  }
146  return (max_dist < 1e-09 || niter > 1e05);
147 }

Here is the caller graph for this function:

Member Data Documentation

eigen_solver_list_t jubatus::core::clustering::gmm::cov_solvers_
private

Definition at line 54 of file gmm.hpp.

Referenced by batch(), get_nearest_center_index(), and initialize().

eigen_smat_list_t jubatus::core::clustering::gmm::covs_
private

Definition at line 52 of file gmm.hpp.

Referenced by batch(), get_covs(), get_nearest_center_index(), and initialize().

int jubatus::core::clustering::gmm::d_
private

Definition at line 55 of file gmm.hpp.

Referenced by initialize().

eigen_smat_t jubatus::core::clustering::gmm::eye_
private

Definition at line 53 of file gmm.hpp.

Referenced by batch(), and initialize().

int jubatus::core::clustering::gmm::k_
private

Definition at line 56 of file gmm.hpp.

Referenced by cluster_probs(), get_nearest_center_index(), initialize(), and is_converged().

eigen_svec_list_t jubatus::core::clustering::gmm::means_
private

Definition at line 51 of file gmm.hpp.

Referenced by batch(), get_centers(), get_nearest_center(), get_nearest_center_index(), and initialize().


The documentation for this class was generated from the following files: