Created July 27, 2017 20:01
Agglomerative clustering using Scikit-Learn (with a custom distance metric)
"""
Hierarchical clustering.

This gist shows how to use scikit-learn to perform agglomerative clustering when:
1. A custom distance metric is needed (such as Levenshtein distance).
2. That distance has to be used through sklearn's API.

Adapted from sklearn's FAQ:
http://scikit-learn.org/stable/faq.html
"""
# Use the editdistance package to calculate the Levenshtein distance between strings.
import editdistance as edist
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import pairwise_distances

# Simple data (sample genome strings).
data = ["ACCTCCTAGAAG", "ACCTACTAGAAGTT", "GAATATTAGGCCGA"]


# Custom distance metric built on editdistance.
# x and y are rows of the index array X below, so x[0] and y[0] are positions in data.
def lev_metric(x, y):
    return int(edist.eval(data[int(x[0])], data[int(y[0])]))


# pairwise_distances expects a numeric 2-D array, so represent each string
# by its index in data, reshaped into a single column.
X = np.arange(len(data)).reshape(-1, 1)
print(X.shape)

# Calculate pairwise distances with the new metric.
m = pairwise_distances(X, X, metric=lev_metric)
print(m)

# Perform agglomerative clustering.
# The metric is 'precomputed' because the distances are already calculated.
# Use 'average' linkage, or any linkage other than 'ward' ('ward' only accepts Euclidean distances).
# Note: scikit-learn 1.2 renamed the 'affinity' parameter to 'metric';
# on older versions, pass affinity='precomputed' instead.
agg = AgglomerativeClustering(n_clusters=2, metric='precomputed',
                              linkage='average')

# Use the distance matrix directly.
u = agg.fit_predict(m)
print(u)
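If the editdistance package is not available, the same precomputed distance matrix can be built with a small pure-Python Levenshtein routine. The sketch below is an illustration and not part of the original gist: the helper names `levenshtein` and `distance_matrix` are hypothetical, and the dynamic-programming loop will be slower than editdistance on long strings.

```python
def levenshtein(a, b):
    """Edit distance between strings a and b (insert, delete, substitute)."""
    # prev[j] holds the distance between the processed prefix of a and b[:j].
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]  # distance from a[:i] to the empty string
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(prev[j] + 1,          # delete ca
                            curr[j - 1] + 1,      # insert cb
                            prev[j - 1] + cost))  # substitute or match
        prev = curr
    return prev[-1]


def distance_matrix(strings):
    """Symmetric matrix of pairwise edit distances with a zero diagonal."""
    n = len(strings)
    m = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            m[i][j] = m[j][i] = levenshtein(strings[i], strings[j])
    return m


data = ["ACCTCCTAGAAG", "ACCTACTAGAAGTT", "GAATATTAGGCCGA"]
m = distance_matrix(data)
print(m)
```

The resulting nested list can be converted with `np.asarray(m)` and passed to `fit_predict` in the same way as the matrix returned by `pairwise_distances`.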