k-means聚类算法C++实现
k-means:一种聚类算法,将样本集data[]分成给定的K个类。经过k-means聚类后,各类别内部的样本会尽可能的紧凑,而各类别之间的样本会尽可能的分开。
k-means思想:将距离最近的样本认为属于同一个类,每一个类有一个“质心”样本。举例,漫天繁星,距离较近的抱团星星,我们认为它们属于一个星团,而每个星团有一颗“质心”恒星作为这个星团的代表。
k-means计算过程:
1)初始化
1.1)初始化值:输入K值,输入data[]全集
1.2)初始化质心:从data[]全集中随机的选取K个样本,作为K个类的质心
1.3)初始化分类:对于随机选取的初始化质心,初始化每个样本的分类,将样本归入离它最近的那个质心那一类
2)迭代运算
2.1)质心变换:对于同一个类的样本集合,重新计算质心
2.2)分类变换:对于变换后的质心,所有样本重新计算分类,计算依据仍是“将样本归入离它最近的那个质心那一类”
2.3)反复的进行迭代运算,直至2.1)质心变换与2.2)分类变换都不再变化为止,理论可以证明,k-means聚类算法一定是收敛的
3)输出结果
k-means关键点:
1)距离:两个样本之间的距离如何定义,是和业务场景紧密相关的。如果样本是二维平面上的点,两个点之间的距离可以定义为二维欧式距离(Euclidean distance),如果样本是天空中的繁星,两颗繁星之间的距离可以定义为三维欧式距离。
2)质心变换:定义了距离之后,初始化分类时,会把样本聚为最近质心那一类。初始化分类后,如何进行质心变换呢?一般使用距离方差法:将同一类中的所有样本都尝试着作为“假定质心”,计算此时该类中所有样本与“假定质心”距离的方差,将方差最小的“假定质心”设为该类的新质心。
工程实现:
工程上对k-means实现,要尽量做到算法步骤与业务场景的解耦:
1)k-means的计算过程是与业务无关的
2)样本,以及样本之间距离的计算是与业务有关的
上述解耦可以利用C++的模版来实现,具体代码如下:
第一部分:算法的抽象
// Interface sketch of the generic k-means driver. CElement is the
// business-defined sample type; this class only declares the algorithm steps.
template< typename CElement>
class CK_means
{
public:
// Step 1: initialization. Takes the number of clusters k and the full sample set.
int32_t Init(uint32_t k, const vector< CElement>& data);
// Step 2: iterative computation.
int32_t Run();
// Step 3: output the result.
int32_t PrintResult();
private:
// The K value (number of clusters).
uint32_t m_k;
// All samples of each of the K clusters.
vector< vector< CElement> > m_result;
// The K centroids.
vector< CElement> m_center;
private:
// Internal helper: distance between two samples; how it is computed is
// defined by the business side.
uint32_t GetDistance(CElement a, CElement b);
// Internal helper: find which centroid a sample is nearest to, i.e. the
// cluster it belongs to.
uint32_t GetElementNearestCenter(CElement a);
// Internal helper: print the current centroids.
void PrintCenters();
};
第二部分:业务的抽象
// Abstract sample type consumed by CK_means. A concrete business class derives
// from it and supplies the distance metric and a debug printer.
class IElement
{
public:
    // Business-defined distance between this sample and b.
    virtual uint32_t GetDistance(IElement* b) = 0;
public:
    // For testing: print this sample.
    virtual void PrintSelf() = 0;
    // Default constructor (containers of derived samples default-construct
    // their elements).
    IElement()
    {
    }
    // Virtual destructor so derived samples are destroyed correctly through
    // a base pointer.
    virtual ~IElement()
    {
    }
    // The interface carries no data, so there is nothing to copy. The operator
    // is defined here (it used to be declared only) so that a derived class
    // relying on its compiler-generated copy assignment still links.
    IElement& operator=(const IElement&)
    {
        return *this;
    }
};
第三部分:算法的实现
// Generic k-means style clustering (centroids are always actual samples).
//
// CElement is the business-defined sample type. It must be default
// constructible, copyable/assignable, and provide:
//     uint32_t GetDistance(CElement* other);   // business-defined distance
//     void     PrintSelf();                    // debug output
//
// Usage: Init() once with the data set, Run() to iterate until the
// classification is stable, PrintResult() to dump the clusters.
template< typename CElement>
class CK_means
{
public:
    // Creates an empty clusterer. m_k starts at 0 so that Run()/PrintResult()
    // are harmless when Init() was never called or was rejected.
    CK_means() : m_k(0)
    {
    }

    // Step 1: initialization.
    //   k    - number of clusters, must satisfy 2 <= k < data.size()
    //   data - the full sample set
    // Returns 0 on success, -1 when k is too small or the sample set is too
    // small (the object is left untouched in that case).
    int32_t Init(uint32_t k, const std::vector< CElement>& data)
    {
        // K too small, or sample set too small
        if(k==0 || k==1 || k>=data.size())
        {
            return -1;
        }
        m_k = k;
        // Start from empty containers so that calling Init() again does not
        // keep samples of a previous data set inside the clusters.
        m_result.clear();
        m_result.resize(k);
        m_center.clear();
        m_center.resize(k);
        // Initial centroids: the first k samples of the data set.
        for(uint32_t i=0; i< m_k; ++i)
        {
            m_center[i] = data[i];
        }
        PrintCenters();
        // Initial classification: every sample joins its nearest centroid.
        for(size_t i=0; i< data.size(); ++i)
        {
            const uint32_t elementClass = GetElementNearestCenter(data[i]);
            m_result[elementClass].push_back(data[i]);
        }
        return 0;
    }

    // Step 2: iterate "centroid update" + "re-classification" until no sample
    // changes its cluster any more.
    // Returns 0 when the iteration finished, -1 when Init() has not succeeded.
    int32_t Run()
    {
        if(m_k == 0)
        {
            return -1;
        }
        const uint64_t maxVariance = static_cast< uint64_t>(-1);
        bool notEnd = true;
        while(notEnd)
        {
            // (1) Centroid update
            for(uint32_t i=0; i< m_k; ++i)
            {
                std::vector< CElement>& cluster = m_result[i];
                const size_t clusterSize = cluster.size();
                // A cluster can be empty (e.g. two identical initial
                // centroids); it has no candidate, so keep its old centroid.
                if(clusterSize == 0)
                {
                    continue;
                }
                // (1.1) Try every sample as the "assumed centroid" and sum the
                // squared distances of the whole cluster to it.
                // (1.2) The first sample with the smallest sum wins.
                size_t newCenterIndex = 0;
                uint64_t minVariance = 0;
                for(size_t j=0; j< clusterSize; ++j)
                {
                    uint64_t variance = 0;
                    for(size_t n=0; n< clusterSize; ++n)
                    {
                        // 64-bit math: the square of a 32-bit distance does
                        // not fit into 32 bits.
                        const uint64_t dis = GetDistance(cluster[j], cluster[n]);
                        const uint64_t squared = dis * dis;
                        // Saturate instead of wrapping around on huge sums.
                        if(squared > maxVariance - variance)
                        {
                            variance = maxVariance;
                        }
                        else
                        {
                            variance += squared;
                        }
                    }
                    if(j == 0 || variance < minVariance)
                    {
                        minVariance = variance;
                        newCenterIndex = j;
                    }
                }
                m_center[i] = cluster[newCenterIndex];
            }
            PrintCenters();
            // (2) Re-classification against the new centroids
            notEnd = false;
            std::vector< std::vector< CElement> > newResult(m_k);
            for(uint32_t i=0; i< m_k; ++i)
            {
                std::vector< CElement>& cluster = m_result[i];
                for(size_t j=0; j< cluster.size(); ++j)
                {
                    const uint32_t newClass = GetElementNearestCenter(cluster[j]);
                    // A sample changed its cluster: another round is needed.
                    if(newClass != i)
                    {
                        notEnd = true;
                    }
                    newResult[newClass].push_back(cluster[j]);
                }
            }
            // (3) Keep the new classification
            m_result.swap(newResult);
            // (4) while(notEnd) starts the next iteration if needed
        }
        return 0;
    }

    // Step 3: print every cluster as "result <index><centroid>:<samples...>".
    // Always returns 0.
    int32_t PrintResult()
    {
        for(uint32_t i=0; i< m_k; ++i)
        {
            printf("result %u",i);
            m_center[i].PrintSelf();
            printf(":");
            std::vector< CElement>& cluster = m_result[i];
            for(size_t j=0; j< cluster.size(); ++j)
            {
                cluster[j].PrintSelf();
            }
            printf("\n");
        }
        return 0;
    }
private:
    // The K value (number of clusters); 0 until Init() succeeds.
    uint32_t m_k;
    // All samples of each of the K clusters.
    std::vector< std::vector< CElement> > m_result;
    // The K centroids.
    std::vector< CElement> m_center;
private:
    // Distance between two samples as defined by the business type.
    // Parameters are taken by value because CElement::GetDistance is called
    // on a mutable object with a mutable pointer.
    uint32_t GetDistance(CElement a, CElement b)
    {
        return a.GetDistance(&b);
    }
    // Index in [0, k) of the centroid nearest to a; the lowest index wins ties.
    // a is taken by value because CElement::GetDistance needs a mutable pointer.
    uint32_t GetElementNearestCenter(CElement a)
    {
        uint32_t minDistance = (uint32_t)(-1);
        uint32_t minCenter = 0;
        for(uint32_t i=0; i< m_k; ++i)
        {
            const uint32_t nowDistance = m_center[i].GetDistance(&a);
            if(nowDistance < minDistance)
            {
                minDistance = nowDistance;
                minCenter = i;
            }
        }
        return minCenter;
    }
    // For testing: print the current centroids, prefixed by a running counter.
    void PrintCenters()
    {
        static uint32_t count = 0;
        printf("running %u",count++);
        for(uint32_t i=0;i< m_k;++i)
        {
            m_center[i].PrintSelf();
        }
        printf("\n");
    }
};
第四部分:业务的实现举例
// 该业务是二维平面上的点,需要继承抽样业务类IElement
class CPoint : public IElement
{
public:
// 必须实现默认无参数构造函数
CPoint()
{
this->x = 0;
this->y = 0;
}
// 也可以定义自己的构造函数
CPoint(uint32_t x, uint32_t y)
{
this->x = x;
this->y = y;
}
// 必须定义虚构造函数
virtual ~CPoint()
{
}
// 为了配合测试,必须实现PrintSelf
void PrintSelf()
{
printf("(%u,%u)",x,y);
}
// 业务的核心,必须实现如何求两个样本之间的距离
uint32_t GetDistance(IElement* b)
{
CPoint* p = dynamic_cast< CPoint*>(b);
int32_t xDif = x - p->x;
int32_t yDif = y - p->y;
return xDif*xDif + yDif*yDif;
}
// 必须重载运算符
CPoint& operator=(const CPoint& in)
{
this->x = in.x;
this->y = in.y;
return *this;
}
private:
// 二维平面上的点
uint32_t x;
uint32_t y;
};
第五部分:测试代码
int main()
{
// 生成测试数据
vector< CPoint> points;
points.push_back(CPoint(0,0));
#define NUM (5)
#define DEF (50)
uint32_t x=0;
uint32_t y=0;
srand((unsigned)time(NULL));
for(uint32_t i=0;i< NUM;++i)
{
x = random()%DEF;
y = random()%DEF;
points.push_back(CPoint(x,y));
points.push_back(CPoint(x+DEF,y+DEF));
points.push_back(CPoint(x+DEF*2,y+DEF*2));
}
// 调用k-means并输出结果
CK_means< CPoint> kmeans;
kmeans.Init(3, points);
kmeans.Run();
kmeans.PrintResult();
return 0;
}
第六部分:结果输出
[shenjian@dev02 k-means]$ g++ k_means.cpp
[shenjian@dev02 k-means]$ ./a.out
running 0(0,0)(5,47)(55,97)
running 1(0,0)(14,39)(92,97)
result 0(0,0):(0,0)
result 1(14,39):(5,47)(14,39)(35,28)(42,47)(13,36)
result 2(92,97):(55,97)(105,147)(64,89)(114,139)(85,78)(135,128)(92,97)(142,147)(63,86)(113,136)
[shenjian@dev02 k-means]$ ./a.out
running 0(0,0)(25,48)(75,98)
running 1(2,7)(47,34)(102,107)
running 2(2,7)(52,57)(125,112)
running 3(2,7)(52,57)(125,112)
result 0(2,7):(0,0)(25,12)(2,7)
result 1(52,57):(25,48)(48,27)(47,34)(52,57)(75,62)(75,98)
result 2(125,112):(125,148)(125,112)(98,77)(148,127)(97,84)(147,134)(102,107)
附录:所有代码
#include< stdlib.h>
#include< stdio.h>
#include< stdint.h>
#include< time.h>
#include< vector>
// by shenjian 20140620
// qq 396009594
using namespace std;
// Abstract sample type consumed by CK_means. A concrete business class derives
// from it and supplies the debug printer and the distance metric.
class IElement
{
public:
    // For testing: print this sample.
    virtual void PrintSelf() = 0;
    // Business-defined distance between this sample and b.
    virtual uint32_t GetDistance(IElement* b) = 0;
    // Default constructor (containers of derived samples default-construct
    // their elements).
    IElement()
    {
    }
    // Virtual destructor so derived samples are destroyed correctly through
    // a base pointer.
    virtual ~IElement()
    {
    }
    // The interface carries no data, so there is nothing to copy. The operator
    // is defined here (it used to be declared only) so that a derived class
    // relying on its compiler-generated copy assignment still links.
    IElement& operator=(const IElement&)
    {
        return *this;
    }
};
class CPoint : public IElement
{
public:
CPoint()
{
this->x = 0;
this->y = 0;
}
CPoint(uint32_t x, uint32_t y)
{
this->x = x;
this->y = y;
}
virtual ~CPoint()
{
}
void PrintSelf()
{
printf("(%u,%u)",x,y);
}
uint32_t GetDistance(IElement* b)
{
CPoint* p = dynamic_cast< CPoint*>(b);
int32_t xDif = x - p->x;
int32_t yDif = y - p->y;
return xDif*xDif + yDif*yDif;
}
CPoint& operator=(const CPoint& in)
{
this->x = in.x;
this->y = in.y;
return *this;
}
private:
uint32_t x;
uint32_t y;
};
// Generic k-means style clustering (centroids are always actual samples).
//
// CElement is the business-defined sample type. It must be default
// constructible, copyable/assignable, and provide:
//     uint32_t GetDistance(CElement* other);   // business-defined distance
//     void     PrintSelf();                    // debug output
//
// Usage: Init() once with the data set, Run() to iterate until the
// classification is stable, PrintResult() to dump the clusters.
template< typename CElement>
class CK_means
{
public:
    // Creates an empty clusterer. m_k starts at 0 so that Run()/PrintResult()
    // are harmless when Init() was never called or was rejected.
    CK_means() : m_k(0)
    {
    }

    // FUNC : set up k clusters from the sample set
    // IN   : k    - number of clusters, must satisfy 2 <= k < data.size()
    //        data - the full sample set
    // OUT  : 0 on success, -1 when k or the sample set is too small (the
    //        object is left untouched in that case)
    int32_t Init(uint32_t k, const std::vector< CElement>& data)
    {
        if(k==0 || k==1 || k>=data.size())
        {
            return -1;
        }
        m_k = k;
        // Start from empty containers so that calling Init() again does not
        // keep samples of a previous data set inside the clusters.
        m_result.clear();
        m_result.resize(k);
        m_center.clear();
        m_center.resize(k);
        // init center: the first k samples of the data set
        for(uint32_t i=0; i< m_k; ++i)
        {
            m_center[i] = data[i];
        }
        PrintCenters();
        // init element classification: every sample joins its nearest center
        for(size_t i=0; i< data.size(); ++i)
        {
            const uint32_t elementClass = GetElementNearestCenter(data[i]);
            m_result[elementClass].push_back(data[i]);
        }
        return 0;
    }

    // FUNC : iterate "center update" + "re-classification" until no sample
    //        changes its classification any more
    // OUT  : 0 when the iteration finished, -1 when Init() has not succeeded
    int32_t Run()
    {
        if(m_k == 0)
        {
            return -1;
        }
        const uint64_t maxVariance = static_cast< uint64_t>(-1);
        bool notEnd = true;
        while(notEnd)
        {
            // (1) calculate new center for each classification
            for(uint32_t i=0; i< m_k; ++i)
            {
                std::vector< CElement>& cluster = m_result[i];
                const size_t clusterSize = cluster.size();
                // A classification can be empty (e.g. two identical initial
                // centers); it has no candidate, so keep its old center.
                if(clusterSize == 0)
                {
                    continue;
                }
                // (1.1) try every sample as center and sum the squared
                //       distances of the whole classification to it
                // (1.2) the first sample with the smallest sum is the new center
                size_t newCenterIndex = 0;
                uint64_t minVariance = 0;
                for(size_t j=0; j< clusterSize; ++j)
                {
                    uint64_t variance = 0;
                    for(size_t n=0; n< clusterSize; ++n)
                    {
                        // 64-bit math: the square of a 32-bit distance does
                        // not fit into 32 bits.
                        const uint64_t dis = GetDistance(cluster[j], cluster[n]);
                        const uint64_t squared = dis * dis;
                        // Saturate instead of wrapping around on huge sums.
                        if(squared > maxVariance - variance)
                        {
                            variance = maxVariance;
                        }
                        else
                        {
                            variance += squared;
                        }
                    }
                    if(j == 0 || variance < minVariance)
                    {
                        minVariance = variance;
                        newCenterIndex = j;
                    }
                }
                m_center[i] = cluster[newCenterIndex];
            }
            PrintCenters();
            // (2) re-classify every element against the new centers
            notEnd = false;
            std::vector< std::vector< CElement> > newResult(m_k);
            for(uint32_t i=0; i< m_k; ++i)
            {
                std::vector< CElement>& cluster = m_result[i];
                for(size_t j=0; j< cluster.size(); ++j)
                {
                    const uint32_t newClass = GetElementNearestCenter(cluster[j]);
                    // classification changed: another round is needed
                    if(newClass != i)
                    {
                        notEnd = true;
                    }
                    newResult[newClass].push_back(cluster[j]);
                }
            }
            // (3) refresh result
            m_result.swap(newResult);
            // (4) start next iteration if notEnd is set
        }
        return 0;
    }

    // FUNC : print every classification as "result <index><center>:<samples...>"
    // OUT  : always 0
    int32_t PrintResult()
    {
        for(uint32_t i=0; i< m_k; ++i)
        {
            printf("result %u",i);
            m_center[i].PrintSelf();
            printf(":");
            std::vector< CElement>& cluster = m_result[i];
            for(size_t j=0; j< cluster.size(); ++j)
            {
                cluster[j].PrintSelf();
            }
            printf("\n");
        }
        return 0;
    }
private:
    // number of classifications; 0 until Init() succeeds
    uint32_t m_k;
    // all samples of each of the k classifications
    std::vector< std::vector< CElement> > m_result;
    // the k centers
    std::vector< CElement> m_center;
private:
    // FUNC : calculate distance between a and b
    // IN : element a+b (by value: CElement::GetDistance needs mutable objects)
    // OUT : distance
    uint32_t GetDistance(CElement a, CElement b)
    {
        return a.GetDistance(&b);
    }
    // FUNC : calculate which classification one element belongs to
    // IN : element a (by value: CElement::GetDistance needs a mutable pointer)
    // OUT : which classification a belongs to, lowest index wins ties
    // return range is [0, k)
    uint32_t GetElementNearestCenter(CElement a)
    {
        uint32_t minDistance = (uint32_t)(-1);
        uint32_t minCenter = 0;
        for(uint32_t i=0; i< m_k; ++i)
        {
            const uint32_t nowDistance = m_center[i].GetDistance(&a);
            if(nowDistance < minDistance)
            {
                minDistance = nowDistance;
                minCenter = i;
            }
        }
        return minCenter;
    }
    // FUNC : for test, print all centers, prefixed by a running counter
    void PrintCenters()
    {
        static uint32_t count = 0;
        printf("running %u",count++);
        for(uint32_t i=0;i< m_k;++i)
        {
            m_center[i].PrintSelf();
        }
        printf("\n");
    }
};
int main()
{
// gen test data
vector< CPoint> points;
points.push_back(CPoint(0,0));
#define NUM (5)
#define DEF (50)
uint32_t x=0;
uint32_t y=0;
srand((unsigned)time(NULL));
for(uint32_t i=0;i< NUM;++i)
{
x = random()%DEF;
y = random()%DEF;
points.push_back(CPoint(x,y));
points.push_back(CPoint(x+DEF,y+DEF));
points.push_back(CPoint(x+DEF*2,y+DEF*2));
}
// k-means
CK_means< CPoint> kmeans;
kmeans.Init(3, points);
kmeans.Run();
kmeans.PrintResult();
return 0;
}
下面提供了源码的下载,wordpress不允许上传cpp文件,故加了一个.txt的后缀哈。
点击k-means.cpp下载源码。
评论关闭。