k-means聚类算法C++实现

k-means聚类算法C++实现

k-means:一种聚类算法,将样本集data[]分成给定的K个类。经过k-means聚类后,各类别内部的样本会尽可能的紧凑,而各类别之间的样本会尽可能的分开。
k-means思想:将距离最近的样本认为属于同一个类,每一个类有一个“质心”样本。举例,漫天繁星,距离较近的抱团星星,我们认为它们属于一个星团,而每个星团有一颗“质心”恒星作为这个星团的代表。
k-means计算过程
1)初始化
1.1)初始化值:输入K值,输入data[]全集
1.2)初始化质心:从data[]全集中随机的选取K个样本,作为K个类的质心
1.3)初始化分类:对于随机选取的初始化质心,初始化每个样本的分类,将样本归入离它最近的那个质心那一类
2)迭代运算
2.1)质心变换:对于同一个类的样本集合,重新计算质心
2.2)分类变换:对于变换后的质心,所有样本重新计算分类,计算依据仍是“将样本归入离它最近的那个质心那一类”
2.3)反复的进行迭代运算,直至2.1)质心变换与2.2)分类变换都不再变化为止,理论可以证明,k-means聚类算法一定是收敛的
3)输出结果
k-means关键点
1)距离:两个样本之间的距离如何定义,是和业务场景紧密相关的。如果样本是二维平面上的点,两个点之间的距离可以定义为二维欧式距离(Euclidean distance),如果样本是天空中的繁星,两颗繁星之间的距离可以定义为三维欧式距离。
2)质心变换:定义了距离之后,初始化分类时,会把样本聚为最近质心那一类。初始化分类后,如何进行质心变换呢?一般使用距离方差法:将同一类中的所有样本都尝试着作为“假定质心”,计算此时该类中所有样本与“假定质心”距离的方差,将方差最小的“假定质心”设为该类的新质心。

工程实现
工程上对k-means实现,要尽量做到算法步骤与业务场景的解耦:
1)k-means的计算过程是与业务无关的
2)样本,以及样本之间距离的计算是与业务有关的
上述解耦可以利用C++的模版来实现,具体代码如下:
第一部分:算法的抽象

// Interface sketch of the generic k-means driver.
// CElement is the sample type supplied by the business side; the algorithm
// steps themselves are independent of what a sample actually is.
template <typename CElement>
class CK_means
{
public:
    // Step 1: initialise with the number of clusters k and the full sample set.
    int32_t Init(uint32_t k, const std::vector<CElement>& data);

    // Step 2: run the iterative computation.
    int32_t Run();

    // Step 3: print the result.
    int32_t PrintResult();

private:
    // Distance between two samples; how it is measured is defined by the
    // business side.
    uint32_t GetDistance(CElement a, CElement b);

    // Index of the center closest to `a`, i.e. the cluster `a` belongs to.
    uint32_t GetElementNearestCenter(CElement a);

    // Prints the current centers.
    void PrintCenters();

    // The value of k.
    uint32_t m_k;

    // All samples of each of the k clusters.
    std::vector<std::vector<CElement> > m_result;

    // The k centers.
    std::vector<CElement> m_center;
};

第二部分:业务的抽象

// Abstract sample interface: the business side derives from it and supplies
// the distance metric (plus a print routine used by the test output).
class IElement
{
public:
    // Distance between this sample and `b`; defined by the business side.
    virtual uint32_t GetDistance(IElement* b) = 0;

    // For testing: print this sample.
    virtual void PrintSelf() = 0;

    // Default constructor; the interface itself holds no data.
    IElement()
    {
    }

    // Copy constructor, spelled out to pair with the user-provided
    // assignment operator below; there is nothing to copy.
    IElement(const IElement&)
    {
    }

    // Virtual destructor so derived objects can be destroyed through the base.
    virtual ~IElement()
    {
    }

    // Assignment. It used to be declared without a definition, which turns
    // any use of the base-class assignment (including the one generated
    // implicitly for a derived class) into a link error. The interface has
    // no state, so assigning is a no-op.
    IElement& operator=(const IElement&)
    {
        return *this;
    }
};

第三部分:算法的实现

// Generic k-means style clustering driver.
//
// CElement is the sample type. The code below relies on it being
// default-constructible, copyable and assignable, and on two members:
//   uint32_t GetDistance(&other)  - distance to another sample
//   void     PrintSelf()          - prints the sample (debug/result output)
//
// Centers are always actual samples: each round, the member with the smallest
// sum of squared distances to the other members of its cluster becomes the
// new center of that cluster.
template <typename CElement>
class CK_means
{
public:
    // Starts uninitialised (k == 0) so Run()/PrintResult() never loop over an
    // indeterminate cluster count.
    CK_means() : m_k(0), m_round(0)
    {
    }

    // Step 1: initialisation.
    // Uses the first k samples of `data` as the initial centers and assigns
    // every sample to its nearest center.
    // Returns -1 when k is 0 or 1, or when k >= data.size(); 0 on success.
    int32_t Init(uint32_t k, const std::vector<CElement>& data)
    {
        // k too small, or sample set too small.
        if (k < 2 || k >= data.size())
        {
            return -1;
        }

        m_k = k;
        m_round = 0;
        // assign() instead of resize(): memberships left over from an earlier
        // Init() call must not leak into this run.
        m_result.assign(k, std::vector<CElement>());
        m_center.resize(k);

        // Initial centers: the first k samples.
        for (uint32_t i = 0; i < m_k; ++i)
        {
            m_center[i] = data[i];
        }
        PrintCenters();

        // Initial classification: each sample joins its nearest center.
        for (size_t i = 0; i < data.size(); ++i)
        {
            const uint32_t cls = GetElementNearestCenter(data[i]);
            m_result[cls].push_back(data[i]);
        }

        return 0;
    }

    // Step 2: iterate until no sample changes cluster.
    // Returns -1 if Init() has not succeeded yet, 0 once the assignment is
    // stable. The loop has no iteration cap; it ends only when a full pass
    // moves nothing.
    int32_t Run()
    {
        if (m_k == 0)
        {
            return -1;
        }

        bool changed = true;
        while (changed)
        {
            UpdateCenters();
            PrintCenters();
            changed = Reassign();
        }
        return 0;
    }

    // Step 3: print every cluster as "result <i><center>:<members...>".
    int32_t PrintResult()
    {
        for (uint32_t i = 0; i < m_k; ++i)
        {
            printf("result %u", i);
            m_center[i].PrintSelf();
            printf(":");
            std::vector<CElement>& members = m_result[i];
            for (size_t j = 0; j < members.size(); ++j)
            {
                members[j].PrintSelf();
            }
            printf("\n");
        }
        return 0;
    }

private:
    // Number of clusters; 0 until Init() succeeds.
    uint32_t m_k;
    // Rounds printed so far. Kept per instance (it used to be a function
    // static, which is shared by every object of the same instantiation).
    uint32_t m_round;
    // Members of each of the k clusters.
    std::vector<std::vector<CElement> > m_result;
    // The k centers.
    std::vector<CElement> m_center;

private:
    // Center update: for every non-empty cluster pick the member whose summed
    // squared distance to all members is smallest (first one wins on ties).
    void UpdateCenters()
    {
        for (uint32_t c = 0; c < m_k; ++c)
        {
            std::vector<CElement>& members = m_result[c];
            // A cluster can be empty (e.g. duplicate seed samples); keep its
            // previous center instead of indexing into an empty vector.
            if (members.empty())
            {
                continue;
            }

            size_t best = 0;
            uint64_t bestCost = 0;
            for (size_t j = 0; j < members.size(); ++j)
            {
                // 64-bit accumulator: the square of a 32-bit distance does
                // not fit in 32 bits, let alone a sum of them.
                uint64_t cost = 0;
                for (size_t m = 0; m < members.size(); ++m)
                {
                    const uint64_t d = GetDistance(members[j], members[m]);
                    cost += d * d;
                }
                if (j == 0 || cost < bestCost)
                {
                    bestCost = cost;
                    best = j;
                }
            }
            m_center[c] = members[best];
        }
    }

    // Re-classification: every sample moves to the cluster of its nearest
    // center. Returns true when at least one sample changed cluster.
    bool Reassign()
    {
        bool moved = false;
        std::vector<std::vector<CElement> > next(m_k);
        for (uint32_t c = 0; c < m_k; ++c)
        {
            const std::vector<CElement>& members = m_result[c];
            for (size_t j = 0; j < members.size(); ++j)
            {
                const uint32_t target = GetElementNearestCenter(members[j]);
                if (target != c)
                {
                    moved = true;
                }
                next[target].push_back(members[j]);
            }
        }
        m_result.swap(next);
        return moved;
    }

    // Distance between two samples; the metric itself is supplied by CElement.
    uint32_t GetDistance(CElement a, CElement b)
    {
        return a.GetDistance(&b);
    }

    // Index in [0, k) of the center nearest to `a`; the lowest index wins
    // when several centers are equally close.
    uint32_t GetElementNearestCenter(CElement a)
    {
        uint32_t minDistance = static_cast<uint32_t>(-1);
        uint32_t minCenter = 0;

        for (uint32_t i = 0; i < m_k; ++i)
        {
            const uint32_t nowDistance = m_center[i].GetDistance(&a);
            if (nowDistance < minDistance)
            {
                minDistance = nowDistance;
                minCenter = i;
            }
        }
        return minCenter;
    }

    // Debug output: prints "running <round>" followed by the current centers.
    void PrintCenters()
    {
        printf("running %u", m_round++);
        for (uint32_t i = 0; i < m_k; ++i)
        {
            m_center[i].PrintSelf();
        }
        printf("\n");
    }
};

第四部分:业务的实现举例

// Example business type: a point on the 2-D plane. It derives from the
// abstract sample interface IElement.
class CPoint : public IElement
{
public:
    // Default constructor (the original listing marks it as mandatory):
    // the origin.
    CPoint() : x(0), y(0)
    {
    }

    // Convenience constructor from explicit coordinates.
    CPoint(uint32_t x, uint32_t y) : x(x), y(y)
    {
    }

    // Virtual destructor.
    virtual ~CPoint()
    {
    }

    // For testing: prints the point as "(x,y)".
    void PrintSelf()
    {
        printf("(%u,%u)", x, y);
    }

    // Core of the business type: squared Euclidean distance to `b`.
    // Returns the largest uint32_t value when `b` is null or not a CPoint
    // (so such an element is never the "nearest"), and also when the true
    // squared distance does not fit in 32 bits.
    uint32_t GetDistance(IElement* b)
    {
        const uint32_t kMax = static_cast<uint32_t>(-1);

        const CPoint* other = dynamic_cast<CPoint*>(b);
        if (!other)
        {
            return kMax;
        }

        // Absolute differences, computed without relying on unsigned
        // wrap-around being narrowed into a signed value.
        const uint64_t dx = (x > other->x) ? (x - other->x) : (other->x - x);
        const uint64_t dy = (y > other->y) ? (y - other->y) : (other->y - y);

        // Any component above 65535 already squares past the 32-bit range.
        if (dx > 65535u || dy > 65535u)
        {
            return kMax;
        }

        const uint64_t squared = dx * dx + dy * dy;
        return (squared < kMax) ? static_cast<uint32_t>(squared) : kMax;
    }

    // Assignment operator (the original listing marks it as mandatory):
    // copies both coordinates.
    CPoint& operator=(const CPoint& in)
    {
        this->x = in.x;
        this->y = in.y;
        return *this;
    }

private:
    // Coordinates of the point on the plane.
    uint32_t x;
    uint32_t y;
};

第五部分:测试代码

int main()
{
    // 生成测试数据
    vector< CPoint> points;
    points.push_back(CPoint(0,0));
#define NUM (5)
#define DEF (50)
    uint32_t x=0;
    uint32_t y=0;
    srand((unsigned)time(NULL));
    for(uint32_t i=0;i< NUM;++i)
    {
        x = random()%DEF;
        y = random()%DEF;
        points.push_back(CPoint(x,y));
        points.push_back(CPoint(x+DEF,y+DEF));
        points.push_back(CPoint(x+DEF*2,y+DEF*2));
    }

    // 调用k-means并输出结果
    CK_means< CPoint> kmeans;
    kmeans.Init(3, points);
    kmeans.Run();
    kmeans.PrintResult();

    return 0;
}

第六部分:结果输出
[shenjian@dev02 k-means]$ g++ k_means.cpp

[shenjian@dev02 k-means]$ ./a.out
running 0(0,0)(5,47)(55,97)
running 1(0,0)(14,39)(92,97)
result 0(0,0):(0,0)
result 1(14,39):(5,47)(14,39)(35,28)(42,47)(13,36)
result 2(92,97):(55,97)(105,147)(64,89)(114,139)(85,78)(135,128)(92,97)(142,147)(63,86)(113,136)

[shenjian@dev02 k-means]$ ./a.out
running 0(0,0)(25,48)(75,98)
running 1(2,7)(47,34)(102,107)
running 2(2,7)(52,57)(125,112)
running 3(2,7)(52,57)(125,112)
result 0(2,7):(0,0)(25,12)(2,7)
result 1(52,57):(25,48)(48,27)(47,34)(52,57)(75,62)(75,98)
result 2(125,112):(125,148)(125,112)(98,77)(148,127)(97,84)(147,134)(102,107)

附录:所有代码

#include< stdlib.h>
#include< stdio.h>
#include< stdint.h>
#include< time.h>
#include< vector>

// by shenjian 20140620
// qq 396009594
using namespace std;

// Abstract sample interface: the business side derives from it and supplies
// the distance metric (plus a print routine used by the test output).
class IElement
{
public:
    // For testing: print this sample.
    virtual void PrintSelf() = 0;

    // Distance between this sample and `b`; defined by the business side.
    virtual uint32_t GetDistance(IElement* b) = 0;

    // Default constructor; the interface itself holds no data.
    IElement()
    {
    }

    // Copy constructor, spelled out to pair with the user-provided
    // assignment operator below; there is nothing to copy.
    IElement(const IElement&)
    {
    }

    // Virtual destructor so derived objects can be destroyed through the base.
    virtual ~IElement()
    {
    }

    // Assignment. It used to be declared without a definition, which turns
    // any use of the base-class assignment (including the one generated
    // implicitly for a derived class) into a link error. The interface has
    // no state, so assigning is a no-op.
    IElement& operator=(const IElement&)
    {
        return *this;
    }
};

// Example business type: a point on the 2-D plane, implementing the
// IElement sample interface.
class CPoint : public IElement
{
public:
    // Default constructor: the origin.
    CPoint() : x(0), y(0)
    {
    }

    // Constructs a point from explicit coordinates.
    CPoint(uint32_t x, uint32_t y) : x(x), y(y)
    {
    }

    // Virtual destructor.
    virtual ~CPoint()
    {
    }

    // Prints the point as "(x,y)".
    void PrintSelf()
    {
        printf("(%u,%u)", x, y);
    }

    // Squared Euclidean distance to `b`.
    // Returns the largest uint32_t value when `b` is null or not a CPoint
    // (so such an element is never the "nearest"), and also when the true
    // squared distance does not fit in 32 bits.
    uint32_t GetDistance(IElement* b)
    {
        const uint32_t kMax = static_cast<uint32_t>(-1);

        const CPoint* other = dynamic_cast<CPoint*>(b);
        if (!other)
        {
            return kMax;
        }

        // Absolute differences, computed without relying on unsigned
        // wrap-around being narrowed into a signed value.
        const uint64_t dx = (x > other->x) ? (x - other->x) : (other->x - x);
        const uint64_t dy = (y > other->y) ? (y - other->y) : (other->y - y);

        // Any component above 65535 already squares past the 32-bit range.
        if (dx > 65535u || dy > 65535u)
        {
            return kMax;
        }

        const uint64_t squared = dx * dx + dy * dy;
        return (squared < kMax) ? static_cast<uint32_t>(squared) : kMax;
    }

    // Copies both coordinates from `in`.
    CPoint& operator=(const CPoint& in)
    {
        this->x = in.x;
        this->y = in.y;
        return *this;
    }

private:
    // Coordinates of the point on the plane.
    uint32_t x;
    uint32_t y;
};

// Generic k-means style clustering driver.
//
// CElement is the sample type. The code below relies on it being
// default-constructible, copyable and assignable, and on two members:
//   uint32_t GetDistance(&other)  - distance to another sample
//   void     PrintSelf()          - prints the sample (debug/result output)
//
// Centers are always actual samples: each round, the member with the smallest
// sum of squared distances to the other members of its cluster becomes the
// new center of that cluster.
template <typename CElement>
class CK_means
{
public:
    // Starts uninitialised (k == 0) so Run()/PrintResult() never loop over an
    // indeterminate cluster count.
    CK_means() : m_k(0), m_round(0)
    {
    }

    // Initialisation.
    // Uses the first k samples of `data` as the initial centers and assigns
    // every sample to its nearest center.
    // Returns -1 when k is 0 or 1, or when k >= data.size(); 0 on success.
    int32_t Init(uint32_t k, const std::vector<CElement>& data)
    {
        if (k < 2 || k >= data.size())
        {
            return -1;
        }

        m_k = k;
        m_round = 0;
        // assign() instead of resize(): memberships left over from an earlier
        // Init() call must not leak into this run.
        m_result.assign(k, std::vector<CElement>());
        m_center.resize(k);

        // Initial centers: the first k samples.
        for (uint32_t i = 0; i < m_k; ++i)
        {
            m_center[i] = data[i];
        }
        PrintCenters();

        // Initial classification: each sample joins its nearest center.
        for (size_t i = 0; i < data.size(); ++i)
        {
            const uint32_t cls = GetElementNearestCenter(data[i]);
            m_result[cls].push_back(data[i]);
        }

        return 0;
    }

    // Iterates until no sample changes cluster.
    // Returns -1 if Init() has not succeeded yet, 0 once the assignment is
    // stable. The loop has no iteration cap; it ends only when a full pass
    // moves nothing.
    int32_t Run()
    {
        if (m_k == 0)
        {
            return -1;
        }

        bool changed = true;
        while (changed)
        {
            UpdateCenters();
            PrintCenters();
            changed = Reassign();
        }
        return 0;
    }

    // Prints every cluster as "result <i><center>:<members...>".
    int32_t PrintResult()
    {
        for (uint32_t i = 0; i < m_k; ++i)
        {
            printf("result %u", i);
            m_center[i].PrintSelf();
            printf(":");
            std::vector<CElement>& members = m_result[i];
            for (size_t j = 0; j < members.size(); ++j)
            {
                members[j].PrintSelf();
            }
            printf("\n");
        }
        return 0;
    }

private:
    // Number of clusters; 0 until Init() succeeds.
    uint32_t m_k;
    // Rounds printed so far. Kept per instance (it used to be a function
    // static, which is shared by every object of the same instantiation).
    uint32_t m_round;
    // Members of each of the k clusters.
    std::vector<std::vector<CElement> > m_result;
    // The k centers.
    std::vector<CElement> m_center;

private:
    // Center update: for every non-empty cluster pick the member whose summed
    // squared distance to all members is smallest (first one wins on ties).
    void UpdateCenters()
    {
        for (uint32_t c = 0; c < m_k; ++c)
        {
            std::vector<CElement>& members = m_result[c];
            // A cluster can be empty (e.g. duplicate seed samples); keep its
            // previous center instead of indexing into an empty vector.
            if (members.empty())
            {
                continue;
            }

            size_t best = 0;
            uint64_t bestCost = 0;
            for (size_t j = 0; j < members.size(); ++j)
            {
                // 64-bit accumulator: the square of a 32-bit distance does
                // not fit in 32 bits, let alone a sum of them.
                uint64_t cost = 0;
                for (size_t m = 0; m < members.size(); ++m)
                {
                    const uint64_t d = GetDistance(members[j], members[m]);
                    cost += d * d;
                }
                if (j == 0 || cost < bestCost)
                {
                    bestCost = cost;
                    best = j;
                }
            }
            m_center[c] = members[best];
        }
    }

    // Re-classification: every sample moves to the cluster of its nearest
    // center. Returns true when at least one sample changed cluster.
    bool Reassign()
    {
        bool moved = false;
        std::vector<std::vector<CElement> > next(m_k);
        for (uint32_t c = 0; c < m_k; ++c)
        {
            const std::vector<CElement>& members = m_result[c];
            for (size_t j = 0; j < members.size(); ++j)
            {
                const uint32_t target = GetElementNearestCenter(members[j]);
                if (target != c)
                {
                    moved = true;
                }
                next[target].push_back(members[j]);
            }
        }
        m_result.swap(next);
        return moved;
    }

    // FUNC : distance between a and b, as defined by CElement
    // IN   : elements a and b
    // OUT  : distance
    uint32_t GetDistance(CElement a, CElement b)
    {
        return a.GetDistance(&b);
    }

    // FUNC : find the cluster an element belongs to
    // IN   : element a
    // OUT  : index in [0, k) of the nearest center; the lowest index wins
    //        when several centers are equally close
    uint32_t GetElementNearestCenter(CElement a)
    {
        uint32_t minDistance = static_cast<uint32_t>(-1);
        uint32_t minCenter = 0;

        for (uint32_t i = 0; i < m_k; ++i)
        {
            const uint32_t nowDistance = m_center[i].GetDistance(&a);
            if (nowDistance < minDistance)
            {
                minDistance = nowDistance;
                minCenter = i;
            }
        }
        return minCenter;
    }

    // FUNC : for test, prints "running <round>" followed by all centers
    void PrintCenters()
    {
        printf("running %u", m_round++);
        for (uint32_t i = 0; i < m_k; ++i)
        {
            m_center[i].PrintSelf();
        }
        printf("\n");
    }
};

int main()
{
    // gen test data
    vector< CPoint> points;
    points.push_back(CPoint(0,0));
#define NUM (5)
#define DEF (50)
    uint32_t x=0;
    uint32_t y=0;
    srand((unsigned)time(NULL));
    for(uint32_t i=0;i< NUM;++i)
    {
        x = random()%DEF;
        y = random()%DEF;
        points.push_back(CPoint(x,y));
        points.push_back(CPoint(x+DEF,y+DEF));
        points.push_back(CPoint(x+DEF*2,y+DEF*2));
    }

    // k-means
    CK_means< CPoint> kmeans;
    kmeans.Init(3, points);
    kmeans.Run();
    kmeans.PrintResult();

    return 0;
}

下面提供了源码的下载,wordpress不允许上传cpp文件,故加了一个.txt的后缀哈。
点击k-means.cpp下载源码。

评论关闭。