* green/k_mean/k_mean.hh (get_distance) : New accessor method.
* green/k_mean/k_mean.hh (get_group) : New accessor method.
* green/k_mean/k_mean.hh (get_variance) : New accessor method.
* green/k_mean/k_mean.hh (update_variance) : New compute method.
* green/k_mean/k_mean.hh (div_col) : New tool method.
* green/k_mean/k_mean.hh (sum_row) : New tool method.
* green/k_mean/k_mean.hh (k_mean) : Add variance code.
* green/k_mean/k_mean.hh (~k_mean) : Add variance code.
* green/k_mean/k_mean.hh (update_center) : Divide by group sizes, not n.
* green/k_mean/k_mean.hh (update_distance) : Fix trace names; comment out trace::quiet toggles.
* green/k_mean/k_mean.cc (#define) : New macros.
* green/k_mean/k_mean.cc (fill_image_with_4colors) : Comment out debugging code.
* green/k_mean/k_mean.cc (is_equivalent) : Adapt the signature.
* green/k_mean/k_mean.cc (test_init_point) : Adapt the k_mean call.
* green/k_mean/k_mean.cc (is_equal, is_center_initialized) : Remove old functions.
* green/k_mean/k_mean.cc (set_point) : New assignment function.
* green/k_mean/k_mean.cc (fake_init_point) : New initialization function.
* green/k_mean/k_mean.cc (is_equal) : New predicate.
* green/k_mean/k_mean.cc (test_init_center) : New unit test.
* green/k_mean/k_mean.cc (set_center) : New assignment function.
* green/k_mean/k_mean.cc (fake_init_center) : New initialization function.
* green/k_mean/k_mean.cc (dist) : New distance function.
* green/k_mean/k_mean.cc (test_update_distance): New unit test.
* green/k_mean/k_mean.cc (set_distance) : New assignment function.
* green/k_mean/k_mean.cc (fake_update_distance) : New initialization function.
* green/k_mean/k_mean.cc (test_update_group) : New unit test.
* green/k_mean/k_mean.cc (set_group) : New assignment function.
* green/k_mean/k_mean.cc (fake_update_group) : New initialization function.
* green/k_mean/k_mean.cc (test_update_center) : New unit test.
* green/k_mean/k_mean.cc (test_row) : Remove old function.
* green/k_mean/k_mean.cc (test_col) : Remove old function.
* green/k_mean/k_mean.cc (test_update_var) : New unit test.
* green/k_mean/k_mean.cc (main) : Enable test_instantiation; add new test calls.
---
trunk/milena/sandbox/ChangeLog | 38 +++
trunk/milena/sandbox/green/k_mean/k_mean.cc | 340 ++++++++++++++++++++------
trunk/milena/sandbox/green/k_mean/k_mean.hh | 123 +++++++++-
3 files changed, 414 insertions(+), 87 deletions(-)
diff --git a/trunk/milena/sandbox/ChangeLog b/trunk/milena/sandbox/ChangeLog
index 2375914..4ba03a5 100644
--- a/trunk/milena/sandbox/ChangeLog
+++ b/trunk/milena/sandbox/ChangeLog
@@ -1,3 +1,41 @@
+2009-08-28 Yann Jacquelet <jacquelet(a)lrde.epita.fr>
+
+ Improve and continue testing the kmean clustering code.
+
+ * green/k_mean/k_mean.hh (get_distance) : New accessor method.
+ * green/k_mean/k_mean.hh (get_group) : New accessor method.
+ * green/k_mean/k_mean.hh (get_variance) : New accessor method.
+ * green/k_mean/k_mean.hh (update_variance) : New compute method.
+ * green/k_mean/k_mean.hh (div_col) : New tool method.
+ * green/k_mean/k_mean.hh (sum_row) : New tool method.
+ * green/k_mean/k_mean.hh (k_mean) : Add variance code.
+ * green/k_mean/k_mean.hh (~k_mean) : Add variance code.
+ * green/k_mean/k_mean.hh (update_center) : Divide by group sizes, not n.
+ * green/k_mean/k_mean.hh (update_distance) : Fix trace names; comment out trace::quiet toggles.
+ * green/k_mean/k_mean.cc (#define) : New macros.
+ * green/k_mean/k_mean.cc (fill_image_with_4colors) : Comment out debugging code.
+ * green/k_mean/k_mean.cc (is_equivalent) : Adapt the signature.
+ * green/k_mean/k_mean.cc (test_init_point) : Adapt the k_mean call.
+ * green/k_mean/k_mean.cc (is_equal, is_center_initialized) : Remove old functions.
+ * green/k_mean/k_mean.cc (set_point) : New assignment function.
+ * green/k_mean/k_mean.cc (fake_init_point) : New initialization function.
+ * green/k_mean/k_mean.cc (is_equal) : New predicate.
+ * green/k_mean/k_mean.cc (test_init_center) : New unit test.
+ * green/k_mean/k_mean.cc (set_center) : New assignment function.
+ * green/k_mean/k_mean.cc (fake_init_center) : New initialization function.
+ * green/k_mean/k_mean.cc (dist) : New distance function.
+ * green/k_mean/k_mean.cc (test_update_distance): New unit test.
+ * green/k_mean/k_mean.cc (set_distance) : New assignment function.
+ * green/k_mean/k_mean.cc (fake_update_distance) : New initialization function.
+ * green/k_mean/k_mean.cc (test_update_group) : New unit test.
+ * green/k_mean/k_mean.cc (set_group) : New assignment function.
+ * green/k_mean/k_mean.cc (fake_update_group) : New initialization function.
+ * green/k_mean/k_mean.cc (test_update_center) : New unit test.
+ * green/k_mean/k_mean.cc (test_row) : Remove old function.
+ * green/k_mean/k_mean.cc (test_col) : Remove old function.
+ * green/k_mean/k_mean.cc (test_update_var) : New unit test.
+ * green/k_mean/k_mean.cc (main) : Enable test_instantiation; add new test calls.
+
2009-08-27 Yann Jacquelet <jacquelet(a)lrde.epita.fr>
Improve and test the kmean clustering code.
diff --git a/trunk/milena/sandbox/green/k_mean/k_mean.cc b/trunk/milena/sandbox/green/k_mean/k_mean.cc
index abef129..91c68d5 100644
--- a/trunk/milena/sandbox/green/k_mean/k_mean.cc
+++ b/trunk/milena/sandbox/green/k_mean/k_mean.cc
@@ -29,6 +29,31 @@
#include <mln/trait/value/print.hh>
#include <mln/trait/image/print.hh>
+#define SIZE_IMAGE 512
+#define SIZE_SAMPLE1 (512*512)
+#define SIZE_SAMPLE2 4
+#define NB_CENTER 2
+#define DIM_POINT 3
+#define TYPE_POINT double
+#define MAT_POINT1 mln::algebra::mat<SIZE_SAMPLE1, DIM_POINT, TYPE_POINT>
+#define MAT_POINT2 mln::algebra::mat<SIZE_SAMPLE2, DIM_POINT, TYPE_POINT>
+#define MAT_CENTER mln::algebra::mat<NB_CENTER, DIM_POINT, TYPE_POINT>
+#define MAT_DISTANCE1 mln::algebra::mat<SIZE_SAMPLE1, NB_CENTER, TYPE_POINT>
+#define MAT_DISTANCE2 mln::algebra::mat<SIZE_SAMPLE2, NB_CENTER, TYPE_POINT>
+#define MAT_GROUP1 mln::algebra::mat<SIZE_SAMPLE1, NB_CENTER, TYPE_POINT>
+#define MAT_GROUP2 mln::algebra::mat<SIZE_SAMPLE2, NB_CENTER, TYPE_POINT>
+#define VEC_VAR mln::algebra::vec<NB_CENTER, TYPE_POINT>
+
+
+void test_instantiation()
+{
+ mln::clustering::k_mean<SIZE_SAMPLE2,NB_CENTER, DIM_POINT, TYPE_POINT> kmean;
+
+ // test the compilation
+
+ std::cout << "test instantiation : ok" << std::endl;
+}
+
struct rgb8_to_4colors : mln::Function_v2v<rgb8_to_4colors>
{
typedef mln::value::rgb8 result;
@@ -67,22 +92,14 @@ void fill_image_with_4colors(mln::image2d<mln::value::rgb8>& img)
img = mln::data::transform(img, rgb8_to_4colors());
- print_color(lime);
- print_color(brown);
- print_color(teal);
- print_color(purple);
+ //print_color(lime);
+ //print_color(brown);
+ //print_color(teal);
+ //print_color(purple);
}
-#define SIZE_IMAGE 512
-#define SIZE_SAMPLE (512*512)
-#define NB_CENTER 2
-#define DIM_POINT 3
-#define TYPE_POINT double
-#define MAT_POINT mln::algebra::mat<SIZE_SAMPLE, DIM_POINT, TYPE_POINT>
-#define MAT_CENTER mln::algebra::mat<NB_CENTER, DIM_POINT, TYPE_POINT>
-
bool is_equivalent(const mln::image2d<mln::value::rgb8>& img,
- const MAT_POINT& point)
+ const MAT_POINT1& point)
{
mln_piter_(mln::image2d<mln::value::rgb8>) pi(img.domain());
bool result = true;
@@ -123,7 +140,7 @@ void test_init_point()
typedef mln::value::rgb8 rgb8;
mln::image2d<rgb8> img_ref;
- mln::clustering::k_mean<SIZE_SAMPLE,NB_CENTER, DIM_POINT, TYPE_POINT> kmean;
+ mln::clustering::k_mean<SIZE_SAMPLE1,NB_CENTER, DIM_POINT, TYPE_POINT> kmean;
mln::io::ppm::load(img_ref, "/usr/local/share/olena/images/lena.ppm");
//mln::io::ppm::save(img_ref, "verif.ppm");
@@ -133,117 +150,282 @@ void test_init_point()
mln_assertion(true == is_equivalent(img_ref, kmean.get_point()));
- std::cout << "Test init point : ok" << std::endl;
+ std::cout << "Test init point : ok" << std::endl;
}
-bool is_equal(const mln::value::rgb8& ref, const MAT_CENTER& center, unsigned i)
+void set_point(MAT_POINT2& point,
+ const unsigned index,
+ const mln::value::rgb8& color)
{
- bool result = true;
-
- result = result && (ref.red() == center(i, 0));
- result = result && (ref.green() == center(i, 1));
- result = result && (ref.blue() == center(i, 2));
+ point(index,0) = color.red();
+ point(index,1) = color.green();
+ point(index,2) = color.blue();
+}
- return result;
+void fake_init_point(MAT_POINT2& point,
+ const mln::value::rgb8& point1,
+ const mln::value::rgb8& point2,
+ const mln::value::rgb8& point3,
+ const mln::value::rgb8& point4)
+{
+ set_point(point, 0, point1);
+ set_point(point, 1, point2);
+ set_point(point, 2, point3);
+ set_point(point, 3, point4);
}
-bool is_center_initialized(const MAT_CENTER& center)
+bool is_equal(const mln::value::rgb8& ref,
+ const MAT_CENTER& center,
+ const unsigned i)
{
- typedef mln::value::rgb8 rgb8;
- const rgb8 lime = mln::literal::lime;
- const rgb8 brown = mln::literal::brown;
- const rgb8 teal = mln::literal::teal;
- const rgb8 purple = mln::literal::purple;
- bool result = false;
+ bool result = true;
- for (unsigned i = 0; i < NB_CENTER; ++i)
- {
- result = result || is_equal(lime, center, i);
- result = result || is_equal(brown, center, i);
- result = result || is_equal(teal, center, i);
- result = result || is_equal(purple, center, i);
- }
+ result = result && (center(i, 0) - ref.red() < 1.0);
+ result = result && (center(i, 1) - ref.green() < 1.0);
+ result = result && (center(i, 2) - ref.blue() < 1.0);
return result;
}
void test_init_center()
{
- typedef mln::value::rgb8 rgb8;
- mln::image2d<rgb8> img_ref;
+ mln::clustering::k_mean<SIZE_SAMPLE2, NB_CENTER, DIM_POINT, TYPE_POINT> kmean;
- mln::clustering::k_mean<SIZE_SAMPLE,NB_CENTER, DIM_POINT, TYPE_POINT> kmean;
+ const mln::value::rgb8 lime = mln::literal::lime;
+ const mln::value::rgb8 brown = mln::literal::brown;
+ const mln::value::rgb8 teal = mln::literal::teal;
+ const mln::value::rgb8 purple = mln::literal::purple;
+
+ fake_init_point(kmean.get_point(), lime, brown, teal, purple);
+ kmean.init_center();
+
+ mln_assertion(is_equal(lime, kmean.get_center(), 0) ||
+ is_equal(brown, kmean.get_center(), 0) ||
+ is_equal(teal, kmean.get_center(), 0) ||
+ is_equal(purple, kmean.get_center(), 0));
- mln::io::ppm::load(img_ref, "/usr/local/share/olena/images/lena.ppm");
+ mln_assertion(is_equal(lime, kmean.get_center(), 1) ||
+ is_equal(brown, kmean.get_center(), 1) ||
+ is_equal(teal, kmean.get_center(), 1) ||
+ is_equal(purple, kmean.get_center(), 1));
- fill_image_with_4colors(img_ref);
- kmean.init_point(img_ref);
- kmean.init_center();
+ std::cout << "Test init center : ok" << std::endl;
+}
- mln_assertion(true == is_center_initialized(kmean.get_center()));
- std::cout << "Test init center : ok" << std::endl;
+void set_center(MAT_CENTER& center,
+ const unsigned index,
+ const mln::value::rgb8& color)
+{
+ center(index,0) = color.red();
+ center(index,1) = color.green();
+ center(index,2) = color.blue();
}
-void test_instantiation()
+void fake_init_center(MAT_CENTER& center,
+ const mln::value::rgb8 center1,
+ const mln::value::rgb8 center2)
{
- std::cout << "test_instantiation" << std::endl;
-
- mln::trace::entering("safe");
- // typedef mln::value::int_u8 int_u8;
- typedef mln::value::rgb8 rgb8;
- const unsigned SIZE = 512*512;
+ set_center(center, 0, center1);
+ set_center(center, 1, center2);
+}
- //mln::image2d<int_u8> img_ref;
- mln::image2d<rgb8> img_ref;
- mln::io::ppm::load(img_ref, "/usr/local/share/olena/images/lena.ppm");
+double dist(mln::value::rgb8 color1, mln::value::rgb8 color2)
+{
+ double red = color1.red() - color2.red();
+ double green = color1.green() - color2.green();
+ double blue = color1.blue() - color2.blue();
+ double result= red * red + green * green + blue * blue;
+
+ return result;
+}
- mln::trace::exiting("safe");
- mln::trace::entering("clustering");
- mln::clustering::k_mean<SIZE,2,3,double> kmean;
+void test_update_distance()
+{
+ mln::clustering::k_mean<SIZE_SAMPLE2, NB_CENTER, DIM_POINT, TYPE_POINT> kmean;
- std::cout << img_ref.domain() << std::endl;
- kmean.init_point(img_ref);
- kmean.init_center();
+ const mln::value::rgb8 lime = mln::literal::lime;
+ const mln::value::rgb8 brown = mln::literal::brown;
+ const mln::value::rgb8 teal = mln::literal::teal;
+ const mln::value::rgb8 purple = mln::literal::purple;
+ const mln::value::rgb8 c1 = lime;
+ const mln::value::rgb8 c2 = purple;
+ const MAT_DISTANCE2& dist_v = kmean.get_distance();
+
+ fake_init_point(kmean.get_point(), lime, brown, teal, purple);
+ fake_init_center(kmean.get_center(), c1, c2);
kmean.update_distance();
- kmean.update_group();
- kmean.update_center();
- mln::trace::exiting("clustering");
+
+ mln_assertion(dist(lime, c1) == dist_v(0,0));
+ mln_assertion(dist(lime, c2) == dist_v(0,1));
+ mln_assertion(dist(brown, c1) == dist_v(1,0));
+ mln_assertion(dist(brown, c2) == dist_v(1,1));
+ mln_assertion(dist(teal, c1) == dist_v(2,0));
+ mln_assertion(dist(teal, c2) == dist_v(2,1));
+ mln_assertion(dist(purple, c1) == dist_v(3,0));
+ mln_assertion(dist(purple, c2) == dist_v(3,1));
+
+ std::cout << "Test update distance : ok" << std::endl;
+}
+
+void set_distance(MAT_DISTANCE2& distance,
+ const unsigned index,
+ const double d1,
+ const double d2)
+{
+ distance(index,0) = d1;
+ distance(index,1) = d2;
+}
+
+void fake_update_distance(MAT_DISTANCE2& distance,
+ const mln::value::rgb8& point1,
+ const mln::value::rgb8& point2,
+ const mln::value::rgb8& point3,
+ const mln::value::rgb8& point4,
+ const mln::value::rgb8& center1,
+ const mln::value::rgb8& center2)
+{
+ set_distance(distance, 0, dist(point1, center1), dist(point1, center2));
+ set_distance(distance, 1, dist(point2, center1), dist(point2, center2));
+ set_distance(distance, 2, dist(point3, center1), dist(point3, center2));
+ set_distance(distance, 3, dist(point4, center1), dist(point4, center2));
+
/*
- int_u8 val;
+ for (int i = 0; i < SIZE_SAMPLE2; ++i)
+ for (int j = 0; j < NB_CENTER; ++j)
+ std::cout << "d(" << i << "," << j
<< ") = " << distance(i,j) <<std::endl;
*/
- /*
- mln::trait::value::print<mln::value::int_u8>(std::cout);
- mln::trait::image::print<mln::image2d<int_u8> >();
+}
- mln::image2d<int_u8> img_out;
+void test_update_group()
+{
+ mln::clustering::k_mean<SIZE_SAMPLE2, NB_CENTER, DIM_POINT, TYPE_POINT> kmean;
+
+ const mln::value::rgb8 lime = mln::literal::lime;
+ const mln::value::rgb8 brown = mln::literal::brown;
+ const mln::value::rgb8 teal = mln::literal::teal;
+ const mln::value::rgb8 purple = mln::literal::purple;
+ const mln::value::rgb8 c1 = lime;
+ const mln::value::rgb8 c2 = purple;
+ const unsigned point1_min= 0; // lime near lime (c1)
+ const unsigned point2_min= 1; // brown near purple (c2)
+ const unsigned point3_min= 1; // teal near purple (c2)
+ const unsigned point4_min= 1; // purple near purple (c2)
+ const MAT_GROUP2& group = kmean.get_group();
+
+ fake_init_point(kmean.get_point(), lime, brown, teal, purple);
+ fake_init_center(kmean.get_center(), c1, c2);
+ fake_update_distance(kmean.get_distance(), lime, brown, teal, purple, c1, c2);
+ kmean.update_group();
+ mln_assertion(0.0 == group(0, 1 - point1_min));
+ mln_assertion(1.0 == group(0, point1_min));
+ mln_assertion(0.0 == group(1, 1 - point2_min));
+ mln_assertion(1.0 == group(1, point2_min));
+ mln_assertion(0.0 == group(2, 1 - point3_min));
+ mln_assertion(1.0 == group(2, point3_min));
+ mln_assertion(0.0 == group(3, 1 - point4_min));
+ mln_assertion(1.0 == group(3, point4_min));
- //mln::io::pgm::load(img_ref, "mp00082c_50p.pgm");
- mln::trace::exiting("image");
-
+ std::cout << "Test update group : ok" << std::endl;
+}
- mln::trace::entering("test");
+void set_group(MAT_GROUP2& group,
+ const unsigned index,
+ const unsigned min)
+{
+ group(index, min) = 1.0;
+ group(index, 1-min) = 0.0;
+}
- mln::trace::exiting("test");
- // k_mean.distance();
- */
+
+void fake_update_group(MAT_GROUP2& group,
+ const unsigned& point1_min,
+ const unsigned& point2_min,
+ const unsigned& point3_min,
+ const unsigned& point4_min)
+{
+ set_group(group, 0, point1_min);
+ set_group(group, 1, point2_min);
+ set_group(group, 2, point3_min);
+ set_group(group, 3, point4_min);
}
-void test_col()
+void test_update_center()
{
+ mln::clustering::k_mean<SIZE_SAMPLE2, NB_CENTER, DIM_POINT, TYPE_POINT> kmean;
+
+ const mln::value::rgb8 lime = mln::literal::lime;
+ const mln::value::rgb8 brown = mln::literal::brown;
+ const mln::value::rgb8 teal = mln::literal::teal;
+ const mln::value::rgb8 purple = mln::literal::purple;
+ const mln::value::rgb8 c1 = lime;
+ const mln::value::rgb8 c2 = purple;
+ const unsigned pt1_min = 0; // lime near lime (c1)
+ const unsigned pt2_min = 1; // brown near purple (c2)
+ const unsigned pt3_min = 1; // teal near purple (c2)
+ const unsigned pt4_min = 1; // purple near purple (c2)
+ const mln::value::rgb8 mean_c1 = lime;
+ const mln::value::rgb8 mean_c2 = (brown+teal+purple)/3;
+
+ fake_init_point(kmean.get_point(), lime, brown, teal, purple);
+ fake_init_center(kmean.get_center(), c1, c2);
+ fake_update_distance(kmean.get_distance(), lime, brown, teal, purple, c1, c2);
+ fake_update_group(kmean.get_group(), pt1_min, pt2_min, pt3_min, pt4_min);
+ kmean.update_center();
+
+ mln_assertion(is_equal(mean_c1, kmean.get_center(), 0));
+ mln_assertion(is_equal(mean_c2, kmean.get_center(), 1));
+
+ std::cout << "Test update center : ok" << std::endl;
}
-void test_row()
+void test_update_var()
{
+ mln::clustering::k_mean<SIZE_SAMPLE2, NB_CENTER, DIM_POINT, TYPE_POINT> kmean;
+
+ const mln::value::rgb8 lime = mln::literal::lime;
+ const mln::value::rgb8 brown = mln::literal::brown;
+ const mln::value::rgb8 teal = mln::literal::teal;
+ const mln::value::rgb8 purple = mln::literal::purple;
+ const mln::value::rgb8 c1 = lime;
+ const mln::value::rgb8 c2 = purple;
+ const unsigned pt1_min = 0; // lime near lime (c1)
+ const unsigned pt2_min = 1; // brown near purple (c2)
+ const unsigned pt3_min = 1; // teal near purple (c2)
+ const unsigned pt4_min = 1; // purple near purple (c2)
+ const double v1 = 0;
+ const double v2 = dist(purple, brown) + dist(purple, teal);
+ const mln::value::rgb8 mean_c2 = (brown+teal+purple)/3;
+ const VEC_VAR& var = kmean.get_variance();
+
+ fake_init_point(kmean.get_point(), lime, brown, teal, purple);
+ fake_init_center(kmean.get_center(), c1, c2);
+ fake_update_distance(kmean.get_distance(), lime, brown, teal, purple, c1, c2);
+ fake_update_group(kmean.get_group(), pt1_min, pt2_min, pt3_min, pt4_min);
+ kmean.update_variance();
+
+ mln_assertion(v1 == var[0]);
+ mln_assertion(v2 == var[1]);
+
+ std::cout << "Test update variance: ok" << std::endl;
}
+
int main()
{
- //test_instantiation();
+ test_instantiation();
test_init_point();
test_init_center();
+ test_update_distance();
+ test_update_group();
+ test_update_center();
+ test_update_var();
+
+ // mln::trace::quiet = false;
+
+ test_update_center();
return 0;
}
diff --git a/trunk/milena/sandbox/green/k_mean/k_mean.hh b/trunk/milena/sandbox/green/k_mean/k_mean.hh
index 89547ce..7e295bf 100644
--- a/trunk/milena/sandbox/green/k_mean/k_mean.hh
+++ b/trunk/milena/sandbox/green/k_mean/k_mean.hh
@@ -52,6 +52,7 @@
#include <mln/trace/exiting.hh>
#include <mln/core/contract.hh>
+#include <mln/trait/value_.hh>
#include <mln/algebra/mat.hh>
#include <mln/algebra/vec.hh>
@@ -99,6 +100,9 @@ namespace mln
algebra::mat<n, p, T>& get_point();
algebra::mat<k, p, T>& get_center();
+ algebra::mat<n, k, T>& get_distance();
+ algebra::mat<n, k, T>& get_group();
+ algebra::vec<k, T>& get_variance();
k_mean();
~k_mean();
@@ -106,6 +110,7 @@ namespace mln
void update_distance();
void update_group();
void update_center();
+ void update_variance();
template <unsigned q, typename M>
M min(const algebra::vec<q, M>& x) const;
@@ -138,6 +143,15 @@ namespace mln
algebra::vec<q,T> col(const algebra::mat<r, q, T>& m,
const unsigned _col) const;
+ template <unsigned q, unsigned r>
+ void div_col(algebra::mat<r, q, T>& m,
+ const unsigned _col,
+ const T value);
+
+ template <unsigned q, unsigned r>
+ mln_sum(T) sum_row(const algebra::mat<r, q, T>& m,
+ const unsigned _row) const;
+
private:
/// \brief _points contains the concatenation of every data points.
///
@@ -163,6 +177,8 @@ namespace mln
/// _center is a matrix KxP where K is the number of centers and P is the
/// number of attributes.
algebra::mat<k, p, T>* _center;
+
+ algebra::vec<k, T>* _variance;
};
#ifndef MLN_INCLUDE_ONLY
@@ -190,6 +206,39 @@ namespace mln
}
template <unsigned n, unsigned k, unsigned p, typename T>
+ inline
+ algebra::mat<n, k, T>&
+ k_mean<n,k,p,T>::get_distance()
+ {
+ trace::entering("mln::clustering::k_mean::get_distance");
+ trace::exiting("mln::clustering::k_mean::get_distance");
+
+ return *_distance;
+ }
+
+ template <unsigned n, unsigned k, unsigned p, typename T>
+ inline
+ algebra::mat<n, k, T>&
+ k_mean<n,k,p,T>::get_group()
+ {
+ trace::entering("mln::clustering::k_mean::get_group");
+ trace::exiting("mln::clustering::k_mean::get_group");
+
+ return *_group;
+ }
+
+ template <unsigned n, unsigned k, unsigned p, typename T>
+ inline
+ algebra::vec<k, T>&
+ k_mean<n,k,p,T>::get_variance()
+ {
+ trace::entering("mln::clustering::k_mean::get_variance");
+ trace::exiting("mln::clustering::k_mean::get_variance");
+
+ return *_variance;
+ }
+
+ template <unsigned n, unsigned k, unsigned p, typename T>
k_mean<n,k,p,T>::k_mean()
{
trace::entering("mln::clustering::k_mean::k_mean");
@@ -198,11 +247,13 @@ namespace mln
_distance = new algebra::mat<n, k, mln_sum_product(T,T)>();
_group = new algebra::mat<n, k, T>();
_center = new algebra::mat<k, p, T>();
+ _variance = new algebra::vec<k,T>();
mln_postcondition(_point != 0);
mln_postcondition(_distance != 0);
mln_postcondition(_group != 0);
mln_postcondition(_center != 0);
+ mln_postcondition(_variance != 0);
trace::exiting("mln::clustering::k_mean::k_mean");
}
@@ -216,6 +267,7 @@ namespace mln
delete _distance;
delete _group;
delete _center;
+ delete _variance;
trace::exiting("mln::clustering::k_mean::~k_mean");
}
@@ -273,7 +325,7 @@ namespace mln
center(i,j) = point(random, j);
}
- std::cout << "center(" << i << ")" <<
col(center, i) << std::endl;
+ //std::cout << "center(" << i << ")" <<
col(center, i) << std::endl;
}
trace::exiting("mln::clustering::k_mean<n,k,p,T>::init_center");
@@ -329,12 +381,33 @@ namespace mln
algebra::mat<k, p, T>& center = *_center;
algebra::mat<n, k, T>& group = *_group;
- center = (group.t() * point) / n;
+ center = (group.t() * point);
+
+ for (unsigned i = 0; i < k; ++i)
+ div_col(center, i, sum_row(group, i));
- // mln_postcondition(sum(col(distance(i,j)) == 1) Vi
trace::exiting("mln::clustering::k_mean<n,k,p,T>::update_center");
}
+ template <unsigned n, unsigned k, unsigned p, typename T>
+ inline
+ void k_mean<n,k,p,T>::update_variance()
+ {
+
trace::entering("mln::clustering::k_mean<n,k,p,T>::update_variance");
+
+ algebra::vec<k, T>& variance = *_variance;
+ algebra::mat<n, k, T>& distance = *_distance;
+ algebra::mat<n, k, T>& group = *_group;
+
+ // BUG HERE
+ // separate the n group in n vectors
+ // separate the n distance in n vectors
+ // compute the n scalar product (group, distance)
+ // sum the n scalar product to obtain the within variance
+ variance = (group.t() * distance).t();
+
+
trace::exiting("mln::clustering::k_mean<n,k,p,T>::update_variance");
+ }
/*
@@ -352,7 +425,40 @@ namespace mln
}
*/
- //vec<q, T> k_mean<n,k,p,T>::col(const mat<r, q, T>& m,
unsigned _col) const
+ template <unsigned n, unsigned k, unsigned p, typename T>
+ template <unsigned q, unsigned r>
+ inline
+ mln_sum(T) k_mean<n,k,p,T>::sum_row(const algebra::mat<r, q, T>& m,
+ const unsigned _row) const
+ {
+ trace::entering("mln::clustering::k_mean::sum_row");
+ mln_precondition(q > _row);
+
+ mln_sum(T) result;
+
+ for (unsigned j = 0; j < r; ++j)
+ result += m(j, _row);
+
+ trace::exiting("mln::clustering::k_mean::sum_row");
+ return result;
+ }
+
+ template <unsigned n, unsigned k, unsigned p, typename T>
+ template <unsigned q, unsigned r>
+ inline
+ void k_mean<n,k,p,T>::div_col(algebra::mat<r, q, T>& m,
+ const unsigned _col,
+ const T value)
+ {
+ trace::entering("mln::clustering::k_mean::div_col");
+ mln_precondition(r > _col);
+
+ for (unsigned j = 0; j < q; ++j)
+ m(_col, j) /= value;
+
+ trace::exiting("mln::clustering::k_mean::div_col");
+ }
+
template <unsigned n, unsigned k, unsigned p, typename T>
template <unsigned q, unsigned r>
inline
@@ -390,8 +496,9 @@ namespace mln
inline
void k_mean<n,k,p,T>::update_distance()
{
- trace::entering("mln::clustering::k_mean::distance");
- mln::trace::quiet = true;
+ trace::entering("mln::clustering::k_mean::update_distance");
+ //mln::trace::quiet = true;
+
// the result is stored in _distance matrix.
algebra::mat<n, p, T>& point = *_point;
algebra::mat<n, k, T>& distance = *_distance;
@@ -406,8 +513,8 @@ namespace mln
distance(i,j) = euclidian_distance(col(point,i),col(center,j));
}
}
- mln::trace::quiet = false;
- trace::exiting("mln::clustering::k_mean::distance");
+ //mln::trace::quiet = false;
+ trace::exiting("mln::clustering::k_mean::update_distance");
}
#endif // ! MLN_INCLUDE_ONLY
--
1.5.6.5