Statistical Learning with Python - Clustering

G-Do

Suppose you are a medical researcher studying diabetes. Your boss has given you a big chart of data from diabetes patients. Each row of the chart has information for one patient, and each column is a health-related statistic such as height, weight, age, blood pressure, or cholesterol level. We call a patient's values across all the columns that patient's profile. The chart looks something like this:

Height  Weight  Age  Blood Pressure  Cholesterol
70      350     48   90              35.5
58      210     53   110             42.1
63      225     44   95              22.2
(and so on)

So the profile for patient one is just the first row, the profile for patient two is the second row, and so on. Your job is to figure out whether some narrow range of profiles is overrepresented among the patients on the chart, because you want to warn other people who fall into that range to get checked for diabetes.

As it turns out, there is a really simple way of doing this. If you treat each patient as a point in five-dimensional space, where each dimension corresponds to a column in the chart, all you need to do is find tight spatial clusters of points, which is a well-studied problem with standard algorithms. Enter the clustering algorithms.
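To make the geometric picture concrete, here is a minimal sketch (Python 3.8+, using the standard library's math.dist) that treats the three sample rows from the chart as points in five-dimensional space and measures the Euclidean distance between every pair of patients. The variable names are illustrative only and are not part of the module.

```python
import math
from itertools import combinations

# The three sample rows from the chart:
# (height, weight, age, blood pressure, cholesterol)
profiles = [
    (70, 350, 48, 90, 35.5),
    (58, 210, 53, 110, 42.1),
    (63, 225, 44, 95, 22.2),
]

# Euclidean distance between every pair of patients, keyed by (i, j) with i < j
distances = {
    (i, j): math.dist(profiles[i], profiles[j])
    for i, j in combinations(range(len(profiles)), 2)
}

for (i, j), d in sorted(distances.items()):
    print(f"patient {i + 1} vs patient {j + 1}: {d:.2f}")

# The closest pair of profiles is where a tight cluster would start to form
closest = min(distances, key=distances.get)
print(f"closest pair: patients {closest[0] + 1} and {closest[1] + 1}")
```

Patients two and three come out closest (about 30.85 apart). Note that the weight column, with its large spread, contributes most of each distance; whether to rescale the columns first is a modeling choice the module leaves to you.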

The module I have contributed contains two main classes, Point and Cluster, which (not surprisingly) model points and clusters of points. Once you have transformed your data into a list of my generic Point objects, pass it to either the kmeans or the agglo function, which perform k-means clustering and agglomerative clustering, respectively. In either case the output is a list of Cluster objects, which you may then process as you see fit.
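The "transform your data into Point objects" step might look like the sketch below. To keep it self-contained, it defines a stand-in Point class that mirrors the listing's Point(coords, reference=None) constructor; the chart dictionary and the patient IDs used as references are hypothetical. The idea is that the reference travels with each Point, so after clustering you can map every Point in a Cluster back to the patient it came from.

```python
# Stand-in mirroring the listing's Point(coords, reference=None) class
class Point:
    def __init__(self, coords, reference=None):
        self.coords = coords
        self.n = len(coords)
        self.reference = reference

    def __repr__(self):
        return str(self.coords)


# Hypothetical chart rows keyed by a made-up patient ID
chart = {
    "patient-1": [70, 350, 48, 90, 35.5],
    "patient-2": [58, 210, 53, 110, 42.1],
    "patient-3": [63, 225, 44, 95, 22.2],
}

# One Point per row; the reference records whose profile each Point is
points = [Point([float(x) for x in row], reference=pid) for pid, row in chart.items()]

# After clustering, [p.reference for p in cluster.points] recovers the patients
print([p.reference for p in points])
```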

A word on the inputs to the algorithms:

K-means is best suited to cases where you already expect a particular number of clusters; that number is the "k" input parameter. K-means is an iterative algorithm: it repeats its assign-and-update step until the biggest centroid shift in one iteration is smaller than your "cutoff" input parameter. In general, if your data is tightly packed and you want to make fine distinctions between clusters, use a smaller cutoff.
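The iteration can be sketched on plain tuples as follows. This is a minimal, self-contained sketch of the same Lloyd-style loop the listing uses (seed with k sampled points, assign each point to its nearest centroid, move each centroid to the mean of its group, stop when the biggest shift falls below the cutoff). The function name kmeans_sketch and its seed parameter are hypothetical additions for illustration and reproducibility; they are not part of the module.

```python
import math
import random


def kmeans_sketch(points, k, cutoff, seed=None):
    """Lloyd-style k-means on tuples of numbers.

    Stops as soon as the biggest centroid shift in one iteration
    is smaller than `cutoff`.
    """
    rng = random.Random(seed)
    # Seed the centroids with k input points sampled without replacement
    centroids = rng.sample(points, k)
    while True:
        # Assignment step: each point joins the group of its nearest centroid
        groups = [[] for _ in range(k)]
        for p in points:
            nearest = min(range(k), key=lambda i: math.dist(p, centroids[i]))
            groups[nearest].append(p)

        # Update step: move each centroid to the mean of its group;
        # an empty group keeps its previous centroid
        new_centroids = []
        for old, group in zip(centroids, groups):
            if group:
                new_centroids.append(tuple(sum(xs) / len(group) for xs in zip(*group)))
            else:
                new_centroids.append(old)

        # Stop once no centroid moved by `cutoff` or more
        biggest_shift = max(math.dist(a, b) for a, b in zip(centroids, new_centroids))
        centroids = new_centroids
        if biggest_shift < cutoff:
            return centroids, groups


# Two well-separated blobs of three points each
data = [(0, 0), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)]
centroids, groups = kmeans_sketch(data, k=2, cutoff=0.01, seed=1)
print(sorted(len(g) for g in groups))  # [3, 3]
```

With blobs this far apart, every choice of starting points ends in the same split. On messier data, different starting samples can lead to different clusters, so it can be worth running k-means a few times and comparing the results.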

The agglomerative algorithm can be used even if you have no idea how many clusters you should end up with. It takes two parameters: the linkage type (currently 's' for single linkage, 'c' for complete linkage, or 't' for centroid linkage) and the cutoff. In each iteration, the algorithm finds the pair of clusters with the smallest distance between them and fuses them, stopping when either all the clusters have been fused into one mega-cluster or the smallest distance is bigger than your "cutoff" input parameter. The distance between two clusters depends on the linkage type: single linkage is the distance between the closest pair of points, one from each cluster; complete linkage is the distance between the farthest pair of points, one from each cluster; and centroid linkage is simply the distance between the two cluster centroids.
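The three linkage definitions can give quite different numbers for the same pair of clusters. This small, self-contained sketch computes all three for two 2D clusters using plain tuples; the helper names are hypothetical and only restate the definitions above (the listing's own versions are the getSingleDistance, getCompleteDistance, and getCentroidDistance methods).

```python
import math
from itertools import product


def single_linkage(a, b):
    # Distance between the closest pair of points, one from each cluster
    return min(math.dist(p, q) for p, q in product(a, b))


def complete_linkage(a, b):
    # Distance between the farthest pair of points, one from each cluster
    return max(math.dist(p, q) for p, q in product(a, b))


def centroid(cluster):
    # Coordinate-wise mean of the cluster's points
    return tuple(sum(xs) / len(cluster) for xs in zip(*cluster))


def centroid_linkage(a, b):
    # Distance between the two cluster centroids
    return math.dist(centroid(a), centroid(b))


a = [(0, 0), (0, 2)]
b = [(3, 0), (7, 0)]
print(single_linkage(a, b))    # 3.0, from (0, 0) to (3, 0)
print(complete_linkage(a, b))  # sqrt(53), about 7.28, from (0, 2) to (7, 0)
print(centroid_linkage(a, b))  # sqrt(26), about 5.10, from (0, 1) to (5, 0)
```

Because single linkage looks only at the closest pair, it tends to fuse clusters sooner than complete linkage does under the same cutoff.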

The main function tests out these methods by generating 10 random points in two dimensions, then passing them into both clustering algorithms and printing the output clusters.

For more information on clustering, see Brian Luke's website:

http://fconyx.ncifcrf.gov/~lukeb/curact.html

# clustering.py contains classes and functions that cluster data points
import sys, math, random

# -- The Point class represents points in n-dimensional space
class Point:

    # Instance variables
    # self.coords is a list of coordinates for this Point
    # self.n is the number of dimensions this Point lives in (i.e., its space)
    # self.reference is an object bound to this Point
    # Initialize new Points
    def __init__(self, coords, reference=None):
        self.coords = coords
        self.n = len(coords)
        self.reference = reference

    # Return a string representation of this Point
    def __repr__(self):
        return str(self.coords)

# -- The Cluster class represents clusters of points in n-dimensional space
class Cluster:

    # Instance variables
    # self.points is a list of Points associated with this Cluster
    # self.n is the number of dimensions this Cluster's Points live in
    # self.centroid is the sample mean Point of this Cluster
    # Initialize new Clusters
    def __init__(self, points):

        # We forbid empty Clusters (they don't make mathematical sense!)
        if len(points) == 0: raise Exception("ILLEGAL: EMPTY CLUSTER")
        self.points = points
        self.n = points[0].n

        # We also forbid Clusters containing Points in different spaces
        # That is, no Clusters with 2D Points and 3D Points
        for p in points:
            if p.n != self.n: raise Exception("ILLEGAL: MULTISPACE CLUSTER")

        # Figure out what the centroid of this Cluster should be
        self.centroid = self.calculateCentroid()

    # Return a string representation of this Cluster
    def __repr__(self):
        return str(self.points)

    # Update function for the K-means algorithm
    # Assigns a new list of Points to this Cluster, returns centroid difference
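    def update(self, points):
        old_centroid = self.centroid
        self.points = points

        # A K-means Cluster can lose all its Points in an iteration; keep the
        # old centroid rather than dividing by zero in calculateCentroid
        if len(points) == 0: return 0.0
        self.centroid = self.calculateCentroid()
        return getDistance(old_centroid, self.centroid)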

    # Calculates the centroid Point - the centroid is the sample mean Point
    # (in plain English, the average of all the Points in the Cluster)
    def calculateCentroid(self):
        centroid_coords = []

        # For each coordinate:
        for i in range(self.n):

            # Take the average across all Points
            centroid_coords.append(0.0)
            for p in self.points:
                centroid_coords[i] = centroid_coords[i]+p.coords[i]
            centroid_coords[i] = centroid_coords[i]/len(self.points)

        # Return a Point object using the average coordinates
        return Point(centroid_coords)

    # Return the single-linkage distance between this and another Cluster
    def getSingleDistance(self, cluster):
        ret = getDistance(self.points[0], cluster.points[0])
        for p in self.points:
            for q in cluster.points:
                distance = getDistance(p, q)
                if distance < ret: ret = distance
        return ret

    # Return the complete-linkage distance between this and another Cluster
    def getCompleteDistance(self, cluster):
        ret = getDistance(self.points[0], cluster.points[0])
        for p in self.points:
            for q in cluster.points:
                distance = getDistance(p, q)
                if distance > ret: ret = distance
        return ret

    # Return the centroid-linkage distance between this and another Cluster
    def getCentroidDistance(self, cluster):
        return getDistance(self.centroid, cluster.centroid)

    # Return the fusion of this and another Cluster
    def fuse(self, cluster):

        # Forbid fusion of Clusters in different spaces
        if self.n != cluster.n: raise Exception("ILLEGAL FUSION")

        # Build a new list so neither input Cluster's points list is modified
        points = self.points + cluster.points
        return Cluster(points)

# -- Return Clusters of Points formed by K-means clustering
def kmeans(points, k, cutoff):

    # Randomly sample k Points from the points list, build Clusters around them
    initial = random.sample(points, k)
    clusters = []
    for p in initial: clusters.append(Cluster([p]))

    # Enter the program loop
    while True:

        # Make a list for each Cluster
        lists = []
        for c in clusters: lists.append([])

        # For each Point:
        for p in points:

            # Figure out which Cluster's centroid is the nearest
            smallest_distance = getDistance(p, clusters[0].centroid)
            index = 0
            for i in range(1, len(clusters)):
                distance = getDistance(p, clusters[i].centroid)
                if distance < smallest_distance:
                    smallest_distance = distance
                    index = i

            # Add this Point to that Cluster's corresponding list
            lists[index].append(p)

        # Update each Cluster with the corresponding list
        # Record the biggest centroid shift for any Cluster
        biggest_shift = 0.0
        for i in range(len(clusters)):
            shift = clusters[i].update(lists[i])
            biggest_shift = max(biggest_shift, shift)

        # If the biggest centroid shift is less than the cutoff, stop
        if biggest_shift < cutoff: break

    # Return the list of Clusters
    return clusters

# -- Return a distance matrix which captures distances between all Clusters
def makeDistanceMatrix(clusters, linkage):
    ret = dict()
    for i in range(len(clusters)):
        for j in range(len(clusters)):
            if j == i: break
            if linkage == 's':
                ret[(i,j)] = clusters[i].getSingleDistance(clusters[j])
            elif linkage == 'c':
                ret[(i,j)] = clusters[i].getCompleteDistance(clusters[j])
            elif linkage == 't':
                ret[(i,j)] = clusters[i].getCentroidDistance(clusters[j])
            else: raise Exception("INVALID LINKAGE")
    return ret

# -- Return Clusters of Points formed by agglomerative clustering
def agglo(points, linkage, cutoff):

    # Currently, we only allow single, complete, or centroid linkage
    if linkage not in ('s', 'c', 't'): raise Exception("INVALID LINKAGE")

    # Create singleton Clusters, one for each Point
    clusters = []
    for p in points: clusters.append(Cluster([p]))

    # Loop until a break statement is reached
    while True:

        # A single Cluster (or none at all) cannot be fused any further
        if len(clusters) <= 1: break

        # Compute a distance matrix for all Clusters
        distances = makeDistanceMatrix(clusters, linkage)

        # Find the key for the pair of Clusters which are closest together
        min_key = min(distances, key=distances.get)
        min_distance = distances[min_key]

        # If the min_distance is bigger than the cutoff, terminate the loop
        # Otherwise, agglomerate the closest clusters
        if min_distance > cutoff: break
        else:
            c1, c2 = clusters[min_key[0]], clusters[min_key[1]]
            clusters.remove(c1)
            clusters.remove(c2)
            clusters.append(c1.fuse(c2))

    # Return the list of Clusters
    return clusters

# -- Get the Euclidean distance between two Points
def getDistance(a, b):

    # Forbid measurements between Points in different spaces
    if a.n != b.n: raise Exception("ILLEGAL: NON-COMPARABLE POINTS")

    # Euclidean distance between a and b is sqrt(sum((a[i]-b[i])^2) for all i)
    ret = 0.0
    for i in range(a.n):
        ret = ret+pow((a.coords[i]-b.coords[i]), 2)
    return math.sqrt(ret)

# -- Create a random Point in n-dimensional space
def makeRandomPoint(n, lower, upper):
    coords = []
    for i in range(n): coords.append(random.uniform(lower, upper))
    return Point(coords)

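# -- Plot 2D Clusters using Tkinter, drawing each Cluster in its own color
# (assumes 2D Points whose coordinates lie within [lower, upper])
def plot(clusters, lower=-200, upper=200, size=400):
    import tkinter as tk
    root = tk.Tk()
    canvas = tk.Canvas(root, width=size, height=size, bg="white")
    canvas.pack()
    colors = ["red", "blue", "green", "orange", "purple", "brown", "black"]
    scale = size / float(upper - lower)
    for i, c in enumerate(clusters):
        color = colors[i % len(colors)]
        for p in c.points:
            # Map data coordinates to canvas pixels (canvas y grows downward)
            x = (p.coords[0] - lower) * scale
            y = size - (p.coords[1] - lower) * scale
            canvas.create_oval(x - 3, y - 3, x + 3, y + 3, fill=color, outline=color)
    root.mainloop()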

# -- Main function
def main(args):
    num_points, n, lower, upper = 10, 2, -200, 200
    k, kmeans_cutoff = 3, 0.5
    linkage, agglo_cutoff = 's', 150.0

    # Create num_points random Points in n-dimensional space, print them
    print("\nPOINTS:")
    points = []
    for i in range(num_points):
        p = makeRandomPoint(n, lower, upper)
        points.append(p)
        print("P:", p)

    # Cluster the points using the K-means algorithm, print the results
    clusters = kmeans(points, k, kmeans_cutoff)
    print("\nK-MEANS\nCLUSTERS:")
    for c in clusters: print("C:", c)

    # Cluster the points using the agglomerative algorithm, print the results
    clusters = agglo(points, linkage, agglo_cutoff)
    print("\nAGGLOMERATIVE\nCLUSTERS:")
    for c in clusters: print("C:", c)

# -- The following code executes upon command-line invocation
if __name__ == "__main__": main(sys.argv)