
* green/k_mean/k_mean.hh (get_distance) : New accessor method. * green/k_mean/k_mean.hh (get_group) : New accessor method. * green/k_mean/k_mean.hh (get_variance) : New accessor method. * green/k_mean/k_mean.hh (update_variance) : New compute method. * green/k_mean/k_mean.hh (div_col) : New tool method. * green/k_mean/k_mean.hh (sum_row) : New tool method. * green/k_mean/k_mean.hh (k_mean) : Add variance code. * green/k_mean/k_mean.hh (~k_mean) : Add variance code. * green/k_mean/k_mean.hh (update_center) : Fix bugs. * green/k_mean/k_mean.hh (update_distance) : New debugging code. * green/k_mean/k_mean.cc (#define) : New macros. * green/k_mean/k_mean.cc (fill_image_with_4colors) : Comment out debugging code. * green/k_mean/k_mean.cc (is_equivalent) : Adapt the signature. * green/k_mean/k_mean.cc (test_init_point) : Adapt the k_mean call. * green/k_mean/k_mean.cc (is_equal) : Remove old function. * green/k_mean/k_mean.cc (set_point) : New assignment func. * green/k_mean/k_mean.cc (fake_init_point) : New initialization func. * green/k_mean/k_mean.cc (is_equal) : New predicate. * green/k_mean/k_mean.cc (test_init_center) : New unit test. * green/k_mean/k_mean.cc (set_center) : New assignment func. * green/k_mean/k_mean.cc (fake_init_center) : New initialization func. * green/k_mean/k_mean.cc (dist) : New distance function. * green/k_mean/k_mean.cc (test_update_distance): New unit test. * green/k_mean/k_mean.cc (set_distance) : New assignment func. * green/k_mean/k_mean.cc (fake_update_distance) : New initialization func. * green/k_mean/k_mean.cc (test_update_group) : New unit test. * green/k_mean/k_mean.cc (set_group) : New assignment func. * green/k_mean/k_mean.cc (fake_update_group) : New initialization func. * green/k_mean/k_mean.cc (test_update_center) : New unit test. * green/k_mean/k_mean.cc (test_row) : Remove old function. * green/k_mean/k_mean.cc (test_col) : Remove old function. * green/k_mean/k_mean.cc (test_update_var) : New unit test. 
* green/k_mean/k_mean.cc (main) : New test calls. --- trunk/milena/sandbox/ChangeLog | 38 +++ trunk/milena/sandbox/green/k_mean/k_mean.cc | 340 ++++++++++++++++++++------ trunk/milena/sandbox/green/k_mean/k_mean.hh | 123 +++++++++- 3 files changed, 414 insertions(+), 87 deletions(-) diff --git a/trunk/milena/sandbox/ChangeLog b/trunk/milena/sandbox/ChangeLog index 2375914..4ba03a5 100644 --- a/trunk/milena/sandbox/ChangeLog +++ b/trunk/milena/sandbox/ChangeLog @@ -1,3 +1,41 @@ +2009-08-28 Yann Jacquelet <jacquelet@lrde.epita.fr> + + Improve and continue testing the kmean clustering code. + + * green/k_mean/k_mean.hh (get_distance) : New accessor method. + * green/k_mean/k_mean.hh (get_group) : New accessor method. + * green/k_mean/k_mean.hh (get_variance) : New accessor method. + * green/k_mean/k_mean.hh (update_variance) : New compute method. + * green/k_mean/k_mean.hh (div_col) : New tool method. + * green/k_mean/k_mean.hh (sum_row) : New tool method. + * green/k_mean/k_mean.hh (k_mean) : Add variance code. + * green/k_mean/k_mean.hh (~k_mean) : Add variance code. + * green/k_mean/k_mean.hh (update_center) : Fix bugs. + * green/k_mean/k_mean.hh (update_distance) : New debugging code. + * green/k_mean/k_mean.cc (#define) : New macros. + * green/k_mean/k_mean.cc (fill_image_with_4colors) : Comment out debugging code. + * green/k_mean/k_mean.cc (is_equivalent) : Adapt the signature. + * green/k_mean/k_mean.cc (test_init_point) : Adapt the k_mean call. + * green/k_mean/k_mean.cc (is_equal) : Remove old function. + * green/k_mean/k_mean.cc (set_point) : New assignment func. + * green/k_mean/k_mean.cc (fake_init_point) : New initialization func. + * green/k_mean/k_mean.cc (is_equal) : New predicate. + * green/k_mean/k_mean.cc (test_init_center) : New unit test. + * green/k_mean/k_mean.cc (set_center) : New assignment func. + * green/k_mean/k_mean.cc (fake_init_center) : New initialization func. + * green/k_mean/k_mean.cc (dist) : New distance function. 
+ * green/k_mean/k_mean.cc (test_update_distance): New unit test. + * green/k_mean/k_mean.cc (set_distance) : New assignment func. + * green/k_mean/k_mean.cc (fake_update_distance) : New initialization func. + * green/k_mean/k_mean.cc (test_update_group) : New unit test. + * green/k_mean/k_mean.cc (set_group) : New assignment func. + * green/k_mean/k_mean.cc (fake_update_group) : New initialization func. + * green/k_mean/k_mean.cc (test_update_center) : New unit test. + * green/k_mean/k_mean.cc (test_row) : Remove old function. + * green/k_mean/k_mean.cc (test_col) : Remove old function. + * green/k_mean/k_mean.cc (test_update_var) : New unit test. + * green/k_mean/k_mean.cc (main) : New test calls. + 2009-08-27 Yann Jacquelet <jacquelet@lrde.epita.fr> Improve and test the kmean clustering code. diff --git a/trunk/milena/sandbox/green/k_mean/k_mean.cc b/trunk/milena/sandbox/green/k_mean/k_mean.cc index abef129..91c68d5 100644 --- a/trunk/milena/sandbox/green/k_mean/k_mean.cc +++ b/trunk/milena/sandbox/green/k_mean/k_mean.cc @@ -29,6 +29,31 @@ #include <mln/trait/value/print.hh> #include <mln/trait/image/print.hh> +#define SIZE_IMAGE 512 +#define SIZE_SAMPLE1 (512*512) +#define SIZE_SAMPLE2 4 +#define NB_CENTER 2 +#define DIM_POINT 3 +#define TYPE_POINT double +#define MAT_POINT1 mln::algebra::mat<SIZE_SAMPLE1, DIM_POINT, TYPE_POINT> +#define MAT_POINT2 mln::algebra::mat<SIZE_SAMPLE2, DIM_POINT, TYPE_POINT> +#define MAT_CENTER mln::algebra::mat<NB_CENTER, DIM_POINT, TYPE_POINT> +#define MAT_DISTANCE1 mln::algebra::mat<SIZE_SAMPLE1, NB_CENTER, TYPE_POINT> +#define MAT_DISTANCE2 mln::algebra::mat<SIZE_SAMPLE2, NB_CENTER, TYPE_POINT> +#define MAT_GROUP1 mln::algebra::mat<SIZE_SAMPLE1, NB_CENTER, TYPE_POINT> +#define MAT_GROUP2 mln::algebra::mat<SIZE_SAMPLE2, NB_CENTER, TYPE_POINT> +#define VEC_VAR mln::algebra::vec<NB_CENTER, TYPE_POINT> + + +void test_instantiation() +{ + mln::clustering::k_mean<SIZE_SAMPLE2,NB_CENTER, DIM_POINT, TYPE_POINT> kmean; + + // test 
the compilation + + std::cout << "test instantiation : ok" << std::endl; +} + struct rgb8_to_4colors : mln::Function_v2v<rgb8_to_4colors> { typedef mln::value::rgb8 result; @@ -67,22 +92,14 @@ void fill_image_with_4colors(mln::image2d<mln::value::rgb8>& img) img = mln::data::transform(img, rgb8_to_4colors()); - print_color(lime); - print_color(brown); - print_color(teal); - print_color(purple); + //print_color(lime); + //print_color(brown); + //print_color(teal); + //print_color(purple); } -#define SIZE_IMAGE 512 -#define SIZE_SAMPLE (512*512) -#define NB_CENTER 2 -#define DIM_POINT 3 -#define TYPE_POINT double -#define MAT_POINT mln::algebra::mat<SIZE_SAMPLE, DIM_POINT, TYPE_POINT> -#define MAT_CENTER mln::algebra::mat<NB_CENTER, DIM_POINT, TYPE_POINT> - bool is_equivalent(const mln::image2d<mln::value::rgb8>& img, - const MAT_POINT& point) + const MAT_POINT1& point) { mln_piter_(mln::image2d<mln::value::rgb8>) pi(img.domain()); bool result = true; @@ -123,7 +140,7 @@ void test_init_point() typedef mln::value::rgb8 rgb8; mln::image2d<rgb8> img_ref; - mln::clustering::k_mean<SIZE_SAMPLE,NB_CENTER, DIM_POINT, TYPE_POINT> kmean; + mln::clustering::k_mean<SIZE_SAMPLE1,NB_CENTER, DIM_POINT, TYPE_POINT> kmean; mln::io::ppm::load(img_ref, "/usr/local/share/olena/images/lena.ppm"); //mln::io::ppm::save(img_ref, "verif.ppm"); @@ -133,117 +150,282 @@ void test_init_point() mln_assertion(true == is_equivalent(img_ref, kmean.get_point())); - std::cout << "Test init point : ok" << std::endl; + std::cout << "Test init point : ok" << std::endl; } -bool is_equal(const mln::value::rgb8& ref, const MAT_CENTER& center, unsigned i) +void set_point(MAT_POINT2& point, + const unsigned index, + const mln::value::rgb8& color) { - bool result = true; - - result = result && (ref.red() == center(i, 0)); - result = result && (ref.green() == center(i, 1)); - result = result && (ref.blue() == center(i, 2)); + point(index,0) = color.red(); + point(index,1) = color.green(); + point(index,2) = 
color.blue(); +} - return result; +void fake_init_point(MAT_POINT2& point, + const mln::value::rgb8& point1, + const mln::value::rgb8& point2, + const mln::value::rgb8& point3, + const mln::value::rgb8& point4) +{ + set_point(point, 0, point1); + set_point(point, 1, point2); + set_point(point, 2, point3); + set_point(point, 3, point4); } -bool is_center_initialized(const MAT_CENTER& center) +bool is_equal(const mln::value::rgb8& ref, + const MAT_CENTER& center, + const unsigned i) { - typedef mln::value::rgb8 rgb8; - const rgb8 lime = mln::literal::lime; - const rgb8 brown = mln::literal::brown; - const rgb8 teal = mln::literal::teal; - const rgb8 purple = mln::literal::purple; - bool result = false; + bool result = true; - for (unsigned i = 0; i < NB_CENTER; ++i) - { - result = result || is_equal(lime, center, i); - result = result || is_equal(brown, center, i); - result = result || is_equal(teal, center, i); - result = result || is_equal(purple, center, i); - } + result = result && (center(i, 0) - ref.red() < 1.0); + result = result && (center(i, 1) - ref.green() < 1.0); + result = result && (center(i, 2) - ref.blue() < 1.0); return result; } void test_init_center() { - typedef mln::value::rgb8 rgb8; - mln::image2d<rgb8> img_ref; + mln::clustering::k_mean<SIZE_SAMPLE2, NB_CENTER, DIM_POINT, TYPE_POINT> kmean; - mln::clustering::k_mean<SIZE_SAMPLE,NB_CENTER, DIM_POINT, TYPE_POINT> kmean; + const mln::value::rgb8 lime = mln::literal::lime; + const mln::value::rgb8 brown = mln::literal::brown; + const mln::value::rgb8 teal = mln::literal::teal; + const mln::value::rgb8 purple = mln::literal::purple; + + fake_init_point(kmean.get_point(), lime, brown, teal, purple); + kmean.init_center(); + + mln_assertion(is_equal(lime, kmean.get_center(), 0) || + is_equal(brown, kmean.get_center(), 0) || + is_equal(teal, kmean.get_center(), 0) || + is_equal(purple, kmean.get_center(), 0)); - mln::io::ppm::load(img_ref, "/usr/local/share/olena/images/lena.ppm"); + 
mln_assertion(is_equal(lime, kmean.get_center(), 1) || + is_equal(brown, kmean.get_center(), 1) || + is_equal(teal, kmean.get_center(), 1) || + is_equal(purple, kmean.get_center(), 1)); - fill_image_with_4colors(img_ref); - kmean.init_point(img_ref); - kmean.init_center(); + std::cout << "Test init center : ok" << std::endl; +} - mln_assertion(true == is_center_initialized(kmean.get_center())); - std::cout << "Test init center : ok" << std::endl; +void set_center(MAT_CENTER& center, + const unsigned index, + const mln::value::rgb8& color) +{ + center(index,0) = color.red(); + center(index,1) = color.green(); + center(index,2) = color.blue(); } -void test_instantiation() +void fake_init_center(MAT_CENTER& center, + const mln::value::rgb8 center1, + const mln::value::rgb8 center2) { - std::cout << "test_instantiation" << std::endl; - - mln::trace::entering("safe"); - // typedef mln::value::int_u8 int_u8; - typedef mln::value::rgb8 rgb8; - const unsigned SIZE = 512*512; + set_center(center, 0, center1); + set_center(center, 1, center2); +} - //mln::image2d<int_u8> img_ref; - mln::image2d<rgb8> img_ref; - mln::io::ppm::load(img_ref, "/usr/local/share/olena/images/lena.ppm"); +double dist(mln::value::rgb8 color1, mln::value::rgb8 color2) +{ + double red = color1.red() - color2.red(); + double green = color1.green() - color2.green(); + double blue = color1.blue() - color2.blue(); + double result= red * red + green * green + blue * blue; + + return result; +} - mln::trace::exiting("safe"); - mln::trace::entering("clustering"); - mln::clustering::k_mean<SIZE,2,3,double> kmean; +void test_update_distance() +{ + mln::clustering::k_mean<SIZE_SAMPLE2, NB_CENTER, DIM_POINT, TYPE_POINT> kmean; - std::cout << img_ref.domain() << std::endl; - kmean.init_point(img_ref); - kmean.init_center(); + const mln::value::rgb8 lime = mln::literal::lime; + const mln::value::rgb8 brown = mln::literal::brown; + const mln::value::rgb8 teal = mln::literal::teal; + const mln::value::rgb8 purple = 
mln::literal::purple; + const mln::value::rgb8 c1 = lime; + const mln::value::rgb8 c2 = purple; + const MAT_DISTANCE2& dist_v = kmean.get_distance(); + + fake_init_point(kmean.get_point(), lime, brown, teal, purple); + fake_init_center(kmean.get_center(), c1, c2); kmean.update_distance(); - kmean.update_group(); - kmean.update_center(); - mln::trace::exiting("clustering"); + + mln_assertion(dist(lime, c1) == dist_v(0,0)); + mln_assertion(dist(lime, c2) == dist_v(0,1)); + mln_assertion(dist(brown, c1) == dist_v(1,0)); + mln_assertion(dist(brown, c2) == dist_v(1,1)); + mln_assertion(dist(teal, c1) == dist_v(2,0)); + mln_assertion(dist(teal, c2) == dist_v(2,1)); + mln_assertion(dist(purple, c1) == dist_v(3,0)); + mln_assertion(dist(purple, c2) == dist_v(3,1)); + + std::cout << "Test update distance : ok" << std::endl; +} + +void set_distance(MAT_DISTANCE2& distance, + const unsigned index, + const double d1, + const double d2) +{ + distance(index,0) = d1; + distance(index,1) = d2; +} + +void fake_update_distance(MAT_DISTANCE2& distance, + const mln::value::rgb8& point1, + const mln::value::rgb8& point2, + const mln::value::rgb8& point3, + const mln::value::rgb8& point4, + const mln::value::rgb8& center1, + const mln::value::rgb8& center2) +{ + set_distance(distance, 0, dist(point1, center1), dist(point1, center2)); + set_distance(distance, 1, dist(point2, center1), dist(point2, center2)); + set_distance(distance, 2, dist(point3, center1), dist(point3, center2)); + set_distance(distance, 3, dist(point4, center1), dist(point4, center2)); + /* - int_u8 val; + for (int i = 0; i < SIZE_SAMPLE2; ++i) + for (int j = 0; j < NB_CENTER; ++j) + std::cout << "d(" << i << "," << j << ") = " << distance(i,j) <<std::endl; */ - /* - mln::trait::value::print<mln::value::int_u8>(std::cout); - mln::trait::image::print<mln::image2d<int_u8> >(); +} - mln::image2d<int_u8> img_out; +void test_update_group() +{ + mln::clustering::k_mean<SIZE_SAMPLE2, NB_CENTER, DIM_POINT, TYPE_POINT> kmean; 
+ + const mln::value::rgb8 lime = mln::literal::lime; + const mln::value::rgb8 brown = mln::literal::brown; + const mln::value::rgb8 teal = mln::literal::teal; + const mln::value::rgb8 purple = mln::literal::purple; + const mln::value::rgb8 c1 = lime; + const mln::value::rgb8 c2 = purple; + const unsigned point1_min= 0; // lime near lime (c1) + const unsigned point2_min= 1; // brown near purple (c2) + const unsigned point3_min= 1; // teal near purple (c2) + const unsigned point4_min= 1; // purple near purple (c2) + const MAT_GROUP2& group = kmean.get_group(); + + fake_init_point(kmean.get_point(), lime, brown, teal, purple); + fake_init_center(kmean.get_center(), c1, c2); + fake_update_distance(kmean.get_distance(), lime, brown, teal, purple, c1, c2); + kmean.update_group(); + mln_assertion(0.0 == group(0, 1 - point1_min)); + mln_assertion(1.0 == group(0, point1_min)); + mln_assertion(0.0 == group(1, 1 - point2_min)); + mln_assertion(1.0 == group(1, point2_min)); + mln_assertion(0.0 == group(2, 1 - point3_min)); + mln_assertion(1.0 == group(2, point3_min)); + mln_assertion(0.0 == group(3, 1 - point4_min)); + mln_assertion(1.0 == group(3, point4_min)); - //mln::io::pgm::load(img_ref, "mp00082c_50p.pgm"); - mln::trace::exiting("image"); - + std::cout << "Test update group : ok" << std::endl; +} - mln::trace::entering("test"); +void set_group(MAT_GROUP2& group, + const unsigned index, + const unsigned min) +{ + group(index, min) = 1.0; + group(index, 1-min) = 0.0; +} - mln::trace::exiting("test"); - // k_mean.distance(); - */ + +void fake_update_group(MAT_GROUP2& group, + const unsigned& point1_min, + const unsigned& point2_min, + const unsigned& point3_min, + const unsigned& point4_min) +{ + set_group(group, 0, point1_min); + set_group(group, 1, point2_min); + set_group(group, 2, point3_min); + set_group(group, 3, point4_min); } -void test_col() +void test_update_center() { + mln::clustering::k_mean<SIZE_SAMPLE2, NB_CENTER, DIM_POINT, TYPE_POINT> kmean; + + const 
mln::value::rgb8 lime = mln::literal::lime; + const mln::value::rgb8 brown = mln::literal::brown; + const mln::value::rgb8 teal = mln::literal::teal; + const mln::value::rgb8 purple = mln::literal::purple; + const mln::value::rgb8 c1 = lime; + const mln::value::rgb8 c2 = purple; + const unsigned pt1_min = 0; // lime near lime (c1) + const unsigned pt2_min = 1; // brown near purple (c2) + const unsigned pt3_min = 1; // teal near purple (c2) + const unsigned pt4_min = 1; // purple near purple (c2) + const mln::value::rgb8 mean_c1 = lime; + const mln::value::rgb8 mean_c2 = (brown+teal+purple)/3; + + fake_init_point(kmean.get_point(), lime, brown, teal, purple); + fake_init_center(kmean.get_center(), c1, c2); + fake_update_distance(kmean.get_distance(), lime, brown, teal, purple, c1, c2); + fake_update_group(kmean.get_group(), pt1_min, pt2_min, pt3_min, pt4_min); + kmean.update_center(); + + mln_assertion(is_equal(mean_c1, kmean.get_center(), 0)); + mln_assertion(is_equal(mean_c2, kmean.get_center(), 1)); + + std::cout << "Test update center : ok" << std::endl; } -void test_row() +void test_update_var() { + mln::clustering::k_mean<SIZE_SAMPLE2, NB_CENTER, DIM_POINT, TYPE_POINT> kmean; + + const mln::value::rgb8 lime = mln::literal::lime; + const mln::value::rgb8 brown = mln::literal::brown; + const mln::value::rgb8 teal = mln::literal::teal; + const mln::value::rgb8 purple = mln::literal::purple; + const mln::value::rgb8 c1 = lime; + const mln::value::rgb8 c2 = purple; + const unsigned pt1_min = 0; // lime near lime (c1) + const unsigned pt2_min = 1; // brown near purple (c2) + const unsigned pt3_min = 1; // teal near purple (c2) + const unsigned pt4_min = 1; // purple near purple (c2) + const double v1 = 0; + const double v2 = dist(purple, brown) + dist(purple, teal); + const mln::value::rgb8 mean_c2 = (brown+teal+purple)/3; + const VEC_VAR& var = kmean.get_variance(); + + fake_init_point(kmean.get_point(), lime, brown, teal, purple); + 
fake_init_center(kmean.get_center(), c1, c2); + fake_update_distance(kmean.get_distance(), lime, brown, teal, purple, c1, c2); + fake_update_group(kmean.get_group(), pt1_min, pt2_min, pt3_min, pt4_min); + kmean.update_variance(); + + mln_assertion(v1 == var[0]); + mln_assertion(v2 == var[1]); + + std::cout << "Test update variance: ok" << std::endl; } + int main() { - //test_instantiation(); + test_instantiation(); test_init_point(); test_init_center(); + test_update_distance(); + test_update_group(); + test_update_center(); + test_update_var(); + + // mln::trace::quiet = false; + + test_update_center(); return 0; } diff --git a/trunk/milena/sandbox/green/k_mean/k_mean.hh b/trunk/milena/sandbox/green/k_mean/k_mean.hh index 89547ce..7e295bf 100644 --- a/trunk/milena/sandbox/green/k_mean/k_mean.hh +++ b/trunk/milena/sandbox/green/k_mean/k_mean.hh @@ -52,6 +52,7 @@ #include <mln/trace/exiting.hh> #include <mln/core/contract.hh> +#include <mln/trait/value_.hh> #include <mln/algebra/mat.hh> #include <mln/algebra/vec.hh> @@ -99,6 +100,9 @@ namespace mln algebra::mat<n, p, T>& get_point(); algebra::mat<k, p, T>& get_center(); + algebra::mat<n, k, T>& get_distance(); + algebra::mat<n, k, T>& get_group(); + algebra::vec<k, T>& get_variance(); k_mean(); ~k_mean(); @@ -106,6 +110,7 @@ namespace mln void update_distance(); void update_group(); void update_center(); + void update_variance(); template <unsigned q, typename M> M min(const algebra::vec<q, M>& x) const; @@ -138,6 +143,15 @@ namespace mln algebra::vec<q,T> col(const algebra::mat<r, q, T>& m, const unsigned _col) const; + template <unsigned q, unsigned r> + void div_col(algebra::mat<r, q, T>& m, + const unsigned _col, + const T value); + + template <unsigned q, unsigned r> + mln_sum(T) sum_row(const algebra::mat<r, q, T>& m, + const unsigned _row) const; + private: /// \brief _points contains the concatenation of every data points. 
/// @@ -163,6 +177,8 @@ namespace mln /// _center is a matrix KxP where K is the number of centers and P is the /// number of attributes. algebra::mat<k, p, T>* _center; + + algebra::vec<k, T>* _variance; }; #ifndef MLN_INCLUDE_ONLY @@ -190,6 +206,39 @@ namespace mln } template <unsigned n, unsigned k, unsigned p, typename T> + inline + algebra::mat<n, k, T>& + k_mean<n,k,p,T>::get_distance() + { + trace::entering("mln::clustering::k_mean::get_distance"); + trace::exiting("mln::clustering::k_mean::get_distance"); + + return *_distance; + } + + template <unsigned n, unsigned k, unsigned p, typename T> + inline + algebra::mat<n, k, T>& + k_mean<n,k,p,T>::get_group() + { + trace::entering("mln::clustering::k_mean::get_group"); + trace::exiting("mln::clustering::k_mean::get_group"); + + return *_group; + } + + template <unsigned n, unsigned k, unsigned p, typename T> + inline + algebra::vec<k, T>& + k_mean<n,k,p,T>::get_variance() + { + trace::entering("mln::clustering::k_mean::get_variance"); + trace::exiting("mln::clustering::k_mean::get_variance"); + + return *_variance; + } + + template <unsigned n, unsigned k, unsigned p, typename T> k_mean<n,k,p,T>::k_mean() { trace::entering("mln::clustering::k_mean::k_mean"); @@ -198,11 +247,13 @@ namespace mln _distance = new algebra::mat<n, k, mln_sum_product(T,T)>(); _group = new algebra::mat<n, k, T>(); _center = new algebra::mat<k, p, T>(); + _variance = new algebra::vec<k,T>(); mln_postcondition(_point != 0); mln_postcondition(_distance != 0); mln_postcondition(_group != 0); mln_postcondition(_center != 0); + mln_postcondition(_variance != 0); trace::exiting("mln::clustering::k_mean::k_mean"); } @@ -216,6 +267,7 @@ namespace mln delete _distance; delete _group; delete _center; + delete _variance; trace::exiting("mln::clustering::k_mean::~k_mean"); } @@ -273,7 +325,7 @@ namespace mln center(i,j) = point(random, j); } - std::cout << "center(" << i << ")" << col(center, i) << std::endl; + //std::cout << "center(" << i << ")" 
<< col(center, i) << std::endl; } trace::exiting("mln::clustering::k_mean<n,k,p,T>::init_center"); @@ -329,12 +381,33 @@ namespace mln algebra::mat<k, p, T>& center = *_center; algebra::mat<n, k, T>& group = *_group; - center = (group.t() * point) / n; + center = (group.t() * point); + + for (unsigned i = 0; i < k; ++i) + div_col(center, i, sum_row(group, i)); - // mln_postcondition(sum(col(distance(i,j)) == 1) Vi trace::exiting("mln::clustering::k_mean<n,k,p,T>::update_center"); } + template <unsigned n, unsigned k, unsigned p, typename T> + inline + void k_mean<n,k,p,T>::update_variance() + { + trace::entering("mln::clustering::k_mean<n,k,p,T>::update_variance"); + + algebra::vec<k, T>& variance = *_variance; + algebra::mat<n, k, T>& distance = *_distance; + algebra::mat<n, k, T>& group = *_group; + + // BUG HERE + // separate the n group in n vectors + // separate the n distance in n vectors + // compute the n scalar product (group, distance) + // sum the n scalar product to obtain the within variance + variance = (group.t() * distance).t(); + + trace::exiting("mln::clustering::k_mean<n,k,p,T>::update_variance"); + } /* @@ -352,7 +425,40 @@ namespace mln } */ - //vec<q, T> k_mean<n,k,p,T>::col(const mat<r, q, T>& m, unsigned _col) const + template <unsigned n, unsigned k, unsigned p, typename T> + template <unsigned q, unsigned r> + inline + mln_sum(T) k_mean<n,k,p,T>::sum_row(const algebra::mat<r, q, T>& m, + const unsigned _row) const + { + trace::entering("mln::clustering::k_mean::sum_row"); + mln_precondition(q > _row); + + mln_sum(T) result; + + for (unsigned j = 0; j < r; ++j) + result += m(j, _row); + + trace::exiting("mln::clustering::k_mean::sum_row"); + return result; + } + + template <unsigned n, unsigned k, unsigned p, typename T> + template <unsigned q, unsigned r> + inline + void k_mean<n,k,p,T>::div_col(algebra::mat<r, q, T>& m, + const unsigned _col, + const T value) + { + trace::entering("mln::clustering::k_mean::div_col"); + mln_precondition(r 
> _col); + + for (unsigned j = 0; j < q; ++j) + m(_col, j) /= value; + + trace::exiting("mln::clustering::k_mean::div_col"); + } + template <unsigned n, unsigned k, unsigned p, typename T> template <unsigned q, unsigned r> inline @@ -390,8 +496,9 @@ namespace mln inline void k_mean<n,k,p,T>::update_distance() { - trace::entering("mln::clustering::k_mean::distance"); - mln::trace::quiet = true; + trace::entering("mln::clustering::k_mean::update_distance"); + //mln::trace::quiet = true; + // the result is stored in _distance matrix. algebra::mat<n, p, T>& point = *_point; algebra::mat<n, k, T>& distance = *_distance; @@ -406,8 +513,8 @@ namespace mln distance(i,j) = euclidian_distance(col(point,i),col(center,j)); } } - mln::trace::quiet = false; - trace::exiting("mln::clustering::k_mean::distance"); + //mln::trace::quiet = false; + trace::exiting("mln::clustering::k_mean::update_distance"); } #endif // ! MLN_INCLUDE_ONLY -- 1.5.6.5