jubatus_core  0.1.2
Jubatus: Online machine learning framework for distributed environment
classifier_factory.cpp
Go to the documentation of this file.
1 // Jubatus: Online machine learning framework for distributed environment
2 // Copyright (C) 2011 Preferred Networks and Nippon Telegraph and Telephone Corporation.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License version 2.1 as published by the Free Software Foundation.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 
17 #include "classifier_factory.hpp"
18 
19 #include <string>
20 
21 #include "classifier.hpp"
22 #include "../common/exception.hpp"
23 #include "../common/jsonconfig.hpp"
24 #include "../storage/storage_base.hpp"
25 #include "../unlearner/unlearner_factory.hpp"
26 #include "../nearest_neighbor/nearest_neighbor_factory.hpp"
27 
30 using jubatus::util::lang::shared_ptr;
31 
32 namespace jubatus {
33 namespace core {
34 namespace classifier {
35 namespace {
36 
37 struct unlearner_config {
38  jubatus::util::data::optional<std::string> unlearner;
39  jubatus::util::data::optional<config> unlearner_parameter;
40 
41  template<typename Ar>
42  void serialize(Ar& ar) {
43  ar & JUBA_MEMBER(unlearner) & JUBA_MEMBER(unlearner_parameter);
44  }
45 };
46 
47 struct unlearning_classifier_config
48  : public classifier_config, unlearner_config {
49  template<typename Ar>
50  void serialize(Ar& ar) {
53  }
54 };
55 
56 struct nearest_neighbor_classifier_config
57  : public unlearner_config {
58  std::string method;
62 
63  template<typename Ar>
64  void serialize(Ar& ar) {
65  ar & JUBA_MEMBER(method)
66  & JUBA_MEMBER(parameter)
67  & JUBA_MEMBER(nearest_neighbor_num)
68  & JUBA_MEMBER(local_sensitivity);
70  }
71 };
72 
73 jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
74 create_unlearner(const unlearner_config& conf) {
75  if (conf.unlearner) {
76  if (!conf.unlearner_parameter) {
77  throw JUBATUS_EXCEPTION(common::exception::runtime_error(
78  "Unlearner is set but unlearner_parameter is not found"));
79  }
81  *conf.unlearner, *conf.unlearner_parameter);
82  } else {
83  return jubatus::util::lang::shared_ptr<unlearner::unlearner_base>();
84  }
85 }
86 
87 } // namespace
88 
89 shared_ptr<classifier_base> classifier_factory::create_classifier(
90  const std::string& name,
91  const common::jsonconfig::config& param,
92  jubatus::util::lang::shared_ptr<storage::storage_base> storage) {
93  jubatus::util::lang::shared_ptr<unlearner::unlearner_base> unlearner;
94  shared_ptr<classifier_base> res;
95  if (name == "perceptron") {
96  // perceptron doesn't have parameter
97  if (param.type() != jubatus::util::text::json::json::Null) {
98  unlearner_config conf = config_cast_check<unlearner_config>(param);
99  unlearner = create_unlearner(conf);
100  }
101  res.reset(new perceptron(storage));
102  } else if (name == "PA" || name == "passive_aggressive") {
103  // passive_aggressive doesn't have parameter
104  if (param.type() != jubatus::util::text::json::json::Null) {
105  unlearner_config conf = config_cast_check<unlearner_config>(param);
106  unlearner = create_unlearner(conf);
107  }
108  res.reset(new passive_aggressive(storage));
109  } else if (name == "PA1" || name == "passive_aggressive_1") {
110  if (param.type() == jubatus::util::text::json::json::Null) {
111  throw JUBATUS_EXCEPTION(
113  "parameter block is not specified in config"));
114  }
115  unlearning_classifier_config conf
116  = config_cast_check<unlearning_classifier_config>(param);
117  unlearner = create_unlearner(conf);
118  res.reset(new passive_aggressive_1(conf, storage));
119  } else if (name == "PA2" || name == "passive_aggressive_2") {
120  if (param.type() == jubatus::util::text::json::json::Null) {
121  throw JUBATUS_EXCEPTION(
123  "parameter block is not specified in config"));
124  }
125  unlearning_classifier_config conf
126  = config_cast_check<unlearning_classifier_config>(param);
127  unlearner = create_unlearner(conf);
128  res.reset(new passive_aggressive_2(conf, storage));
129  } else if (name == "CW" || name == "confidence_weighted") {
130  if (param.type() == jubatus::util::text::json::json::Null) {
131  throw JUBATUS_EXCEPTION(
133  "parameter block is not specified in config"));
134  }
135  unlearning_classifier_config conf
136  = config_cast_check<unlearning_classifier_config>(param);
137  unlearner = create_unlearner(conf);
138  res.reset(new confidence_weighted(conf, storage));
139  } else if (name == "AROW" || name == "arow") {
140  if (param.type() == jubatus::util::text::json::json::Null) {
141  throw JUBATUS_EXCEPTION(
143  "parameter block is not specified in config"));
144  }
145  unlearning_classifier_config conf
146  = config_cast_check<unlearning_classifier_config>(param);
147  unlearner = create_unlearner(conf);
148  res.reset(new arow(conf, storage));
149  } else if (name == "NHERD" || name == "normal_herd") {
150  if (param.type() == jubatus::util::text::json::json::Null) {
151  throw JUBATUS_EXCEPTION(
153  "parameter block is not specified in config"));
154  }
155  unlearning_classifier_config conf
156  = config_cast_check<unlearning_classifier_config>(param);
157  unlearner = create_unlearner(conf);
158  res.reset(new normal_herd(conf, storage));
159  } else if (name == "NN" || name == "nearest_neighbor") {
160  if (param.type() == jubatus::util::text::json::json::Null) {
161  throw JUBATUS_EXCEPTION(
163  "parameter block is not specified in config"));
164  }
165  nearest_neighbor_classifier_config conf
166  = config_cast_check<nearest_neighbor_classifier_config>(param);
167  unlearner = create_unlearner(conf);
168  shared_ptr<storage::column_table> table(new storage::column_table);
169  shared_ptr<nearest_neighbor::nearest_neighbor_base>
170  nearest_neighbor_engine(nearest_neighbor::create_nearest_neighbor(
171  conf.method, conf.parameter, table, ""));
172  res.reset(
173  new nearest_neighbor_classifier(nearest_neighbor_engine,
174  conf.nearest_neighbor_num,
175  conf.local_sensitivity));
176  } else {
177  throw JUBATUS_EXCEPTION(
178  common::unsupported_method("classifier(" + name + ")"));
179  }
180 
181  if (unlearner) {
182  res->set_label_unlearner(unlearner);
183  }
184  return res;
185 }
186 
187 } // namespace classifier
188 } // namespace core
189 } // namespace jubatus
T config_cast_check(const config &c)
Definition: cast.hpp:311
void serialize(member_collector &mem, jubatus::util::data::serialization::named_value< T > &v)
Definition: cast.hpp:54
static jubatus::util::lang::shared_ptr< classifier_base > create_classifier(const std::string &name, const common::jsonconfig::config &param, jubatus::util::lang::shared_ptr< storage::storage_base > storage)
shared_ptr< unlearner_base > create_unlearner(const std::string &name, const common::jsonconfig::config &config)
jubatus::util::data::optional< std::string > unlearner
config parameter
std::string method
jubatus::util::text::json::json::json_type_t type() const
Definition: config.hpp:84
int nearest_neighbor_num
#define JUBATUS_EXCEPTION(e)
Definition: exception.hpp:79
jubatus::util::data::optional< config > unlearner_parameter
shared_ptr< nearest_neighbor_base > create_nearest_neighbor(const std::string &name, const common::jsonconfig::config &config, shared_ptr< storage::column_table > table, const std::string &id)
float local_sensitivity