k-means聚类算法C++实现
k-means:一种聚类算法,将样本集data[]分成给定的K个类。经过k-means聚类后,各类别内部的样本会尽可能地紧凑,而各类别之间的样本会尽可能地分开。
k-means思想:将距离最近的样本认为属于同一个类,每一个类有一个“质心”样本。举例,漫天繁星,距离较近的抱团星星,我们认为它们属于一个星团,而每个星团有一颗“质心”恒星作为这个星团的代表。
k-means计算过程:
1)初始化
1.1)初始化值:输入K值,输入data[]全集
1.2)初始化质心:从data[]全集中选取K个样本,作为K个类的质心(一般随机选取;下文的实现为简单起见,直接取data[]的前K个样本)
1.3)初始化分类:对于随机选取的初始化质心,初始化每个样本的分类,将样本归入离它最近的那个质心那一类
2)迭代运算
2.1)质心变换:对于同一个类的样本集合,重新计算质心
2.2)分类变换:对于变换后的质心,所有样本重新计算分类,计算依据仍是“将样本归入离它最近的那个质心那一类”
2.3)反复地进行迭代运算,直至2.1)质心变换与2.2)分类变换都不再变化为止。对于标准k-means(以类内样本的均值为质心、以样本到质心的欧式距离平方和为目标),理论可以证明迭代一定收敛,但通常只收敛到局部最优,结果与初始质心的选取有关
3)输出结果
k-means关键点:
1)距离:两个样本之间的距离如何定义,是和业务场景紧密相关的。如果样本是二维平面上的点,两个点之间的距离可以定义为二维欧式距离(Euclidean distance);如果样本是天空中的繁星,两颗繁星之间的距离可以定义为三维欧式距离。(下文的示例为了避免开方,使用的是欧式距离的平方。)
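2)质心变换:定义了距离之后,初始化分类时,会把样本聚为最近质心那一类。初始化分类后,如何进行质心变换呢?本文使用距离方差法:将同一类中的所有样本都尝试着作为“假定质心”,计算此时该类中所有样本与“假定质心”距离的方差,将方差最小的“假定质心”设为该类的新质心。(标准k-means直接取类内样本的均值作为新质心;这里要求新质心必须是某个真实样本,做法上更接近k-medoids。下文实现中计算的是距离的平方和,它与上述方差只差一个常数因子(类内样本数),不影响最小值的选取。)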
工程实现:
工程上对k-means实现,要尽量做到算法步骤与业务场景的解耦:
1)k-means的计算过程是与业务无关的
2)样本,以及样本之间距离的计算是与业务有关的
上述解耦可以利用C++的模板来实现,具体代码如下:
第一部分:算法的抽象
// Part one: the algorithm's interface (declarations only).
// CElement is the business-specific sample type; the algorithm never looks
// inside a sample and only asks it for distances and for debug printing.
template <typename CElement>
class CK_means
{
public:
    // Step 1: store k and the data set, choose the initial centers and put
    // every sample into the cluster of its nearest center.
    int32_t Init(uint32_t k, const std::vector<CElement>& data);

    // Step 2: alternate "recompute centers" and "reassign samples".
    int32_t Run();

    // Step 3: print every cluster together with its center.
    int32_t PrintResult();

private:
    uint32_t m_k;                                   // number of clusters (k)
    std::vector<std::vector<CElement> > m_result;   // members of each of the k clusters
    std::vector<CElement> m_center;                 // current center of each cluster

private:
    // Distance between two samples; how it is measured is up to the business type.
    uint32_t GetDistance(CElement a, CElement b);

    // Index of the center nearest to a, i.e. the cluster a belongs to.
    uint32_t GetElementNearestCenter(CElement a);

    // Debug helper: print the current centers.
    void PrintCenters();
};
第二部分:业务的抽象
// Part two: the business-side abstraction of one sample.
// k-means only needs two things from a sample: the distance to another
// sample, and (for debugging) the ability to print itself.
class IElement
{
public:
    // Distance between this sample and *b. Its meaning and unit are defined
    // by the concrete business type.
    virtual uint32_t GetDistance(IElement* b) = 0;

public:
    // For testing: print this sample.
    virtual void PrintSelf() = 0;

    // Default constructor (not a language requirement; provided so derived
    // samples can be default-constructed).
    IElement() { }

    // Virtual destructor: a derived sample may be destroyed through an IElement*.
    virtual ~IElement() { }

    // The interface carries no state, so assignment has nothing to copy.
    // It must be defined, not merely declared: a declaration without a body
    // turns any assignment through an IElement& into a link error.
    IElement& operator=(const IElement&) { return *this; }
};
第三部分:算法的实现
// Part three: the algorithm's implementation.
// CElement is the business sample type. This template needs it to be
// copyable and to provide
//   uint32_t GetDistance(CElement* other)   -- distance between two samples
//   void     PrintSelf()                    -- debug printing
// The center of a cluster is always one of its own samples: the member whose
// sum of squared distances to all members of the cluster is smallest.
template <typename CElement>
class CK_means
{
public:
    CK_means() : m_k(0) { }

    // Step 1: initialization.
    // Uses the first k samples of data as the initial centers and puts every
    // sample into the cluster of its nearest center. Anything left over from
    // an earlier Init()/Run() is discarded.
    // Returns 0 on success, -1 if k < 2 or k >= data.size() (state unchanged).
    int32_t Init(uint32_t k, const std::vector<CElement>& data)
    {
        if (k < 2 || k >= data.size())
        {
            return -1;
        }
        m_k = k;

        // Copy-construct the initial centers from the first k samples.
        m_center.assign(data.begin(), data.begin() + k);
        PrintCenters();

        // assign() rather than resize(): resize() would keep members of a previous run.
        m_result.assign(k, std::vector<CElement>());
        for (size_t i = 0; i < data.size(); ++i)
        {
            m_result[GetElementNearestCenter(data[i])].push_back(data[i]);
        }
        return 0;
    }

    // Step 2: iterate until no sample changes cluster.
    // Each round recomputes every center, prints the centers, then reassigns
    // every sample to its nearest center. Returns 0 once the assignment is stable.
    int32_t Run()
    {
        bool changed = true;
        while (changed)
        {
            UpdateCenters();
            PrintCenters();
            changed = Reassign();
        }
        // An explicit return is required: flowing off the end of a non-void
        // function is undefined behavior.
        return 0;
    }

    // Step 3: print every cluster as "result <i><center>:<members...>".
    int32_t PrintResult()
    {
        for (uint32_t i = 0; i < m_k; ++i)
        {
            printf("result %u", i);
            m_center[i].PrintSelf();
            printf(":");
            std::vector<CElement>& members = m_result[i];
            for (size_t j = 0; j < members.size(); ++j)
            {
                members[j].PrintSelf();
            }
            printf("\n");
        }
        return 0;
    }

    // Read-only access to the current centers, one per cluster.
    const std::vector<CElement>& GetCenters() const { return m_center; }

    // Read-only access to the members of every cluster.
    const std::vector<std::vector<CElement> >& GetResult() const { return m_result; }

private:
    uint32_t m_k;                                   // number of clusters (0 before Init)
    std::vector<std::vector<CElement> > m_result;   // members of each cluster
    std::vector<CElement> m_center;                 // current center of each cluster

private:
    // Distance between two samples, as defined by the business type.
    // Takes copies because CElement::GetDistance is not const.
    uint32_t GetDistance(CElement a, CElement b)
    {
        return a.GetDistance(&b);
    }

    // Index in [0, m_k) of the center closest to a; ties go to the lowest index.
    uint32_t GetElementNearestCenter(CElement a)
    {
        uint32_t minDistance = static_cast<uint32_t>(-1);
        uint32_t minCenter = 0;
        for (uint32_t i = 0; i < m_k; ++i)
        {
            const uint32_t nowDistance = m_center[i].GetDistance(&a);
            if (nowDistance < minDistance)
            {
                minDistance = nowDistance;
                minCenter = i;
            }
        }
        return minCenter;
    }

    // a + b, clamped to the largest uint64_t instead of wrapping around.
    static uint64_t SaturatingAdd(uint64_t a, uint64_t b)
    {
        const uint64_t kMax = static_cast<uint64_t>(-1);
        return (b > kMax - a) ? kMax : a + b;
    }

    // Center update: for every non-empty cluster, the member with the smallest
    // sum of squared distances to all members becomes the new center (first
    // such member on ties). An empty cluster keeps its previous center.
    void UpdateCenters()
    {
        for (uint32_t i = 0; i < m_k; ++i)
        {
            std::vector<CElement>& members = m_result[i];
            if (members.empty())
            {
                // No candidate center; reading members[0] would be out of bounds.
                continue;
            }
            size_t best = 0;
            uint64_t bestCost = 0;
            for (size_t j = 0; j < members.size(); ++j)
            {
                // Accumulate in 64 bits: squares of 32-bit distances overflow uint32_t.
                uint64_t cost = 0;
                for (size_t m = 0; m < members.size(); ++m)
                {
                    const uint64_t d = GetDistance(members[j], members[m]);
                    cost = SaturatingAdd(cost, d * d);
                }
                if (j == 0 || cost < bestCost)
                {
                    bestCost = cost;
                    best = j;
                }
            }
            m_center[i] = members[best];
        }
    }

    // Reassignment: move every sample into the cluster of its nearest center,
    // preserving relative order. Returns true if any sample changed cluster.
    bool Reassign()
    {
        bool changed = false;
        std::vector<std::vector<CElement> > newResult(m_k);
        for (uint32_t i = 0; i < m_k; ++i)
        {
            std::vector<CElement>& members = m_result[i];
            for (size_t j = 0; j < members.size(); ++j)
            {
                const uint32_t newClass = GetElementNearestCenter(members[j]);
                if (newClass != i)
                {
                    changed = true;
                }
                newResult[newClass].push_back(members[j]);
            }
        }
        m_result.swap(newResult);
        return changed;
    }

    // Debug helper: print the current centers on one line, prefixed by a
    // round counter. The counter is a function-local static, so it is shared
    // by every CK_means object with the same CElement and is never reset.
    void PrintCenters()
    {
        static uint32_t count = 0;
        printf("running %u", count++);
        for (uint32_t i = 0; i < m_k; ++i)
        {
            m_center[i].PrintSelf();
        }
        printf("\n");
    }
};
第四部分:业务的实现举例
// 该业务是二维平面上的点,需要继承抽样业务类IElement class CPoint : public IElement { public: // 必须实现默认无参数构造函数 CPoint() { this->x = 0; this->y = 0; } // 也可以定义自己的构造函数 CPoint(uint32_t x, uint32_t y) { this->x = x; this->y = y; } // 必须定义虚构造函数 virtual ~CPoint() { } // 为了配合测试,必须实现PrintSelf void PrintSelf() { printf("(%u,%u)",x,y); } // 业务的核心,必须实现如何求两个样本之间的距离 uint32_t GetDistance(IElement* b) { CPoint* p = dynamic_cast< CPoint*>(b); int32_t xDif = x - p->x; int32_t yDif = y - p->y; return xDif*xDif + yDif*yDif; } // 必须重载运算符 CPoint& operator=(const CPoint& in) { this->x = in.x; this->y = in.y; return *this; } private: // 二维平面上的点 uint32_t x; uint32_t y; };
第五部分:测试代码
int main() { // 生成测试数据 vector< CPoint> points; points.push_back(CPoint(0,0)); #define NUM (5) #define DEF (50) uint32_t x=0; uint32_t y=0; srand((unsigned)time(NULL)); for(uint32_t i=0;i< NUM;++i) { x = random()%DEF; y = random()%DEF; points.push_back(CPoint(x,y)); points.push_back(CPoint(x+DEF,y+DEF)); points.push_back(CPoint(x+DEF*2,y+DEF*2)); } // 调用k-means并输出结果 CK_means< CPoint> kmeans; kmeans.Init(3, points); kmeans.Run(); kmeans.PrintResult(); return 0; }
第六部分:结果输出
[shenjian@dev02 k-means]$ g++ k_means.cpp
[shenjian@dev02 k-means]$ ./a.out
running 0(0,0)(5,47)(55,97)
running 1(0,0)(14,39)(92,97)
result 0(0,0):(0,0)
result 1(14,39):(5,47)(14,39)(35,28)(42,47)(13,36)
result 2(92,97):(55,97)(105,147)(64,89)(114,139)(85,78)(135,128)(92,97)(142,147)(63,86)(113,136)
[shenjian@dev02 k-means]$ ./a.out
running 0(0,0)(25,48)(75,98)
running 1(2,7)(47,34)(102,107)
running 2(2,7)(52,57)(125,112)
running 3(2,7)(52,57)(125,112)
result 0(2,7):(0,0)(25,12)(2,7)
result 1(52,57):(25,48)(48,27)(47,34)(52,57)(75,62)(75,98)
result 2(125,112):(125,148)(125,112)(98,77)(148,127)(97,84)(147,134)(102,107)
附录:所有代码
#include <stdlib.h>
#include <stdio.h>
#include <stdint.h>
#include <time.h>
#include <vector>
// by shenjian 20140620
// qq 396009594

// Business-side abstraction of one sample: k-means only asks samples for
// their mutual distance and, for debugging, to print themselves.
class IElement
{
public:
    virtual void PrintSelf() = 0;
    virtual uint32_t GetDistance(IElement* b) = 0;
    IElement() { }
    // Derived samples may be destroyed through an IElement*.
    virtual ~IElement() { }
    // Stateless interface: nothing to copy. Defined (not merely declared) so
    // that assignment through an IElement& links.
    IElement& operator=(const IElement&) { return *this; }
};

// Example business type: a point on the 2-D integer plane.
class CPoint : public IElement
{
public:
    CPoint() : x(0), y(0) { }
    CPoint(uint32_t px, uint32_t py) : x(px), y(py) { }
    virtual ~CPoint() { }

    // Print as "(x,y)".
    void PrintSelf()
    {
        printf("(%u,%u)", x, y);
    }

    // Squared Euclidean distance to *b. Returns the largest uint32_t when b
    // is not a CPoint (including NULL) or the exact value does not fit.
    uint32_t GetDistance(IElement* b)
    {
        const uint32_t kFar = static_cast<uint32_t>(-1);
        CPoint* p = dynamic_cast<CPoint*>(b);
        if (p == NULL)
        {
            return kFar;
        }
        const uint64_t dx = (x > p->x) ? (x - p->x) : (p->x - x);
        const uint64_t dy = (y > p->y) ? (y - p->y) : (p->y - y);
        const uint64_t sqx = dx * dx;
        const uint64_t sqy = dy * dy;
        // Check each square first so the sum cannot overflow.
        if (sqx > kFar || sqy > kFar || sqx + sqy > kFar)
        {
            return kFar;
        }
        return static_cast<uint32_t>(sqx + sqy);
    }

    CPoint& operator=(const CPoint& in)
    {
        x = in.x;
        y = in.y;
        return *this;
    }

private:
    uint32_t x;
    uint32_t y;
};

// Generic k-means driver. CElement must be copyable and provide
//   uint32_t GetDistance(CElement*)  and  void PrintSelf().
// A cluster's center is always one of its members: the one with the smallest
// sum of squared distances to all members of the cluster.
template <typename CElement>
class CK_means
{
public:
    CK_means() : m_k(0) { }

    // Uses the first k samples as initial centers and assigns every sample
    // to its nearest center, discarding any previous result.
    // Returns 0 on success, -1 if k < 2 or k >= data.size() (state unchanged).
    int32_t Init(uint32_t k, const std::vector<CElement>& data)
    {
        if (k < 2 || k >= data.size())
        {
            return -1;
        }
        m_k = k;
        // init center
        m_center.assign(data.begin(), data.begin() + k);
        PrintCenters();
        // init element classification; assign() drops members of a previous run
        m_result.assign(k, std::vector<CElement>());
        for (size_t i = 0; i < data.size(); ++i)
        {
            m_result[GetElementNearestCenter(data[i])].push_back(data[i]);
        }
        return 0;
    }

    // Alternates center update (followed by printing the centers) and
    // reassignment until no sample changes cluster. Returns 0 when stable.
    int32_t Run()
    {
        bool changed = true;
        while (changed)
        {
            UpdateCenters();
            PrintCenters();
            changed = Reassign();
        }
        // Flowing off the end of a non-void function is undefined behavior.
        return 0;
    }

    // Prints every cluster as "result <i><center>:<members...>".
    int32_t PrintResult()
    {
        for (uint32_t i = 0; i < m_k; ++i)
        {
            printf("result %u", i);
            m_center[i].PrintSelf();
            printf(":");
            std::vector<CElement>& members = m_result[i];
            for (size_t j = 0; j < members.size(); ++j)
            {
                members[j].PrintSelf();
            }
            printf("\n");
        }
        return 0;
    }

    // Read-only access to the current centers and cluster members.
    const std::vector<CElement>& GetCenters() const { return m_center; }
    const std::vector<std::vector<CElement> >& GetResult() const { return m_result; }

private:
    uint32_t m_k;                                   // number of clusters (0 before Init)
    std::vector<std::vector<CElement> > m_result;   // members of each cluster
    std::vector<CElement> m_center;                 // current center of each cluster

private:
    // Distance between a and b (copies, because CElement::GetDistance is not const).
    uint32_t GetDistance(CElement a, CElement b)
    {
        return a.GetDistance(&b);
    }

    // Index in [0, m_k) of the center nearest to a; ties go to the lowest index.
    uint32_t GetElementNearestCenter(CElement a)
    {
        uint32_t minDistance = static_cast<uint32_t>(-1);
        uint32_t minCenter = 0;
        for (uint32_t i = 0; i < m_k; ++i)
        {
            const uint32_t nowDistance = m_center[i].GetDistance(&a);
            if (nowDistance < minDistance)
            {
                minDistance = nowDistance;
                minCenter = i;
            }
        }
        return minCenter;
    }

    // a + b, clamped to the largest uint64_t instead of wrapping around.
    static uint64_t SaturatingAdd(uint64_t a, uint64_t b)
    {
        const uint64_t kMax = static_cast<uint64_t>(-1);
        return (b > kMax - a) ? kMax : a + b;
    }

    // For each non-empty cluster, the member with the smallest sum of squared
    // distances (first one on ties) becomes the new center. Empty clusters
    // keep their previous center.
    void UpdateCenters()
    {
        for (uint32_t i = 0; i < m_k; ++i)
        {
            std::vector<CElement>& members = m_result[i];
            if (members.empty())
            {
                continue;   // no candidate; members[0] would be out of bounds
            }
            size_t best = 0;
            uint64_t bestCost = 0;
            for (size_t j = 0; j < members.size(); ++j)
            {
                // 64-bit accumulation: squared 32-bit distances overflow uint32_t.
                uint64_t cost = 0;
                for (size_t m = 0; m < members.size(); ++m)
                {
                    const uint64_t d = GetDistance(members[j], members[m]);
                    cost = SaturatingAdd(cost, d * d);
                }
                if (j == 0 || cost < bestCost)
                {
                    bestCost = cost;
                    best = j;
                }
            }
            m_center[i] = members[best];
        }
    }

    // Moves every sample into the cluster of its nearest center, preserving
    // relative order. Returns true if any sample changed cluster.
    bool Reassign()
    {
        bool changed = false;
        std::vector<std::vector<CElement> > newResult(m_k);
        for (uint32_t i = 0; i < m_k; ++i)
        {
            std::vector<CElement>& members = m_result[i];
            for (size_t j = 0; j < members.size(); ++j)
            {
                const uint32_t newClass = GetElementNearestCenter(members[j]);
                if (newClass != i)
                {
                    changed = true;
                }
                newResult[newClass].push_back(members[j]);
            }
        }
        m_result.swap(newResult);
        return changed;
    }

    // For test: print all centers, prefixed by a round counter that is shared
    // by all CK_means objects with the same CElement (function-local static).
    void PrintCenters()
    {
        static uint32_t count = 0;
        printf("running %u", count++);
        for (uint32_t i = 0; i < m_k; ++i)
        {
            m_center[i].PrintSelf();
        }
        printf("\n");
    }
};

// Test driver: the origin plus kNum groups of three points, each shifted by
// (kDef, kDef) from the previous one, clustered into 3 clusters.
int main()
{
    const uint32_t kNum = 5;
    const uint32_t kDef = 50;

    // gen test data
    std::vector<CPoint> points;
    points.push_back(CPoint(0, 0));
    // rand() is the generator seeded by srand(); random() is seeded by srandom().
    srand(static_cast<unsigned>(time(NULL)));
    for (uint32_t i = 0; i < kNum; ++i)
    {
        const uint32_t x = static_cast<uint32_t>(rand()) % kDef;
        const uint32_t y = static_cast<uint32_t>(rand()) % kDef;
        points.push_back(CPoint(x, y));
        points.push_back(CPoint(x + kDef, y + kDef));
        points.push_back(CPoint(x + kDef * 2, y + kDef * 2));
    }

    // k-means; do not Run() if Init() rejected the input
    CK_means<CPoint> kmeans;
    if (kmeans.Init(3, points) != 0)
    {
        fprintf(stderr, "k-means: Init failed\n");
        return 1;
    }
    kmeans.Run();
    kmeans.PrintResult();
    return 0;
}
下面提供了源码的下载,wordpress不允许上传cpp文件,故加了一个.txt的后缀哈。
点击k-means.cpp下载源码。
评论关闭。