142 lines
4.1 KiB
Python
142 lines
4.1 KiB
Python
# Copyright (c) 2022, Manfred Moitzi
|
|
# License: MIT License
|
|
from __future__ import annotations
|
|
from typing import Iterator, Iterable, Optional
|
|
import random
|
|
import statistics
|
|
import itertools
|
|
import operator
|
|
from collections import defaultdict
|
|
from functools import reduce
|
|
|
|
from ezdxf.math import AnyVec, Vec3, spherical_envelope
|
|
from ezdxf.math.rtree import RTree
|
|
|
|
# Public names of this module, i.e. what a star-import exports.
__all__ = [
    "dbscan",
    "k_means",
    "average_cluster_radius",
    "average_intra_cluster_distance",
]
|
|
|
|
|
|
def dbscan(
    points: list[AnyVec],
    *,
    radius: float,
    min_points: int = 4,
    rtree: Optional[RTree] = None,
    max_node_size: int = 5,
) -> list[list[AnyVec]]:
    """Density based clustering of `points` (DBSCAN).

    https://en.wikipedia.org/wiki/DBSCAN

    Every point starts or joins exactly one cluster. A cluster grows from an
    arbitrary seed point: a candidate is added to the cluster and its
    neighborhood is scheduled for inspection only if the sphere query around
    the candidate yields at least `min_points` points. Candidates failing
    this check stay available and seed their own cluster later on, so noise
    shows up as clusters with a single entry.

    Args:
        points: list of points to cluster
        radius: radius of the dense regions
        min_points: minimum number of points that needs to be within the
            `radius` for a point to be a core point (must be >= 2)
        rtree: optional :class:`~ezdxf.math.rtree.RTree`, used for the
            sphere queries instead of an internally created tree
        max_node_size: max node size for internally created RTree

    Returns:
        list of clusters, each cluster is a list of points

    Raises:
        ValueError: `min_points` is smaller than 2

    """
    if min_points < 2:
        raise ValueError("min_points must be >= 2")
    search_tree = RTree(points, max_node_size) if rtree is None else rtree

    unassigned = set(points)
    found: list[set[AnyVec]] = []
    while unassigned:
        seed = unassigned.pop()
        # the seed is always a member, it stays alone if it is noise
        members = {seed}
        found.append(members)
        pending = {seed}
        while pending:
            candidate = pending.pop()
            in_reach = set(search_tree.points_in_sphere(candidate, radius))
            if len(in_reach) >= min_points:
                members.add(candidate)
                unassigned.discard(candidate)
                # only points not yet claimed by any cluster are inspected
                pending.update(in_reach & unassigned)

    return list(map(list, found))
|
|
|
|
|
|
def k_means(
    points: list[AnyVec], k: int, max_iter: int = 10
) -> list[list[AnyVec]]:
    """K-means clustering.

    https://en.wikipedia.org/wiki/K-means_clustering

    The initial centroids are `k` randomly sampled input points, therefore
    the result may differ between calls. The refinement stops as soon as a
    step does not change the clusters anymore or after `max_iter` steps.

    Args:
        points: list of points to cluster
        k: number of clusters
        max_iter: max iterations

    Returns:
        list of clusters, each cluster is a list of points in the same
        relative order as in `points`; the list can contain fewer than `k`
        clusters if a centroid is not the nearest centroid of any point

    Raises:
        ValueError: `k` is not in range [2, len(points)-1]

    """

    def classify(centroids: Iterable[AnyVec]) -> dict[AnyVec, list[AnyVec]]:
        # Assign each input point to its nearest centroid; a centroid which
        # attracts no point does not get an entry.
        new_clusters: dict[AnyVec, list[AnyVec]] = defaultdict(list)
        tree = RTree(centroids)
        for point in points:
            nn, _ = tree.nearest_neighbor(point)
            new_clusters[nn].append(point)
        return new_clusters

    def recenter() -> Iterator[AnyVec]:
        # New centroids are the mean locations of the current clusters.
        for cluster_points in clusters.values():
            yield Vec3.sum(cluster_points) / len(cluster_points)
        missing = k - len(clusters)
        if missing > 0:  # refill centroids if required
            yield from random.sample(points, missing)

    def is_equal_clustering(old_clusters, new_clusters) -> bool:
        # classify() always scans `points` in the same order, so clusters
        # with the same members have the same tuple representation. This is
        # an exact comparison: XOR-ing the member hashes cancels out
        # duplicated points and can report convergence on hash collisions.
        old_partition = set(map(tuple, old_clusters.values()))
        new_partition = set(map(tuple, new_clusters.values()))
        return old_partition == new_partition

    if not (1 < k < len(points)):
        raise ValueError(
            "invalid argument k: must be in range [2, len(points)-1]"
        )
    clusters: dict[AnyVec, list[AnyVec]] = classify(random.sample(points, k))
    for _ in range(max_iter):
        new_clusters = classify(recenter())
        if is_equal_clustering(clusters, new_clusters):
            break
        clusters = new_clusters
    return list(clusters.values())
|
|
|
|
|
|
def average_intra_cluster_distance(clusters: list[list[AnyVec]]) -> float:
    """Returns the average point-to-point intra cluster distance.

    The distance is measured for every unordered pair of points sharing a
    cluster, pairs of points from different clusters are not considered.

    Args:
        clusters: list of clusters, each cluster is a list of points

    Returns:
        mean distance of all intra cluster point pairs or 0.0 if no cluster
        contains at least two points, e.g. for an empty list or if all
        clusters are single point clusters

    """
    distances = [
        p.distance(q)
        for cluster in clusters
        for (p, q) in itertools.combinations(cluster, 2)
    ]
    # statistics.mean() raises StatisticsError for empty data, which happens
    # if there is not a single pair of points to measure.
    if not distances:
        return 0.0
    return statistics.mean(distances)
|
|
|
|
|
|
def average_cluster_radius(clusters: list[list[AnyVec]]) -> float:
    """Returns the mean radius of the given clusters.

    The radius of a single cluster is the second item of the result of
    :func:`spherical_envelope` for the points of this cluster.

    Args:
        clusters: list of clusters, each cluster is a list of points

    """
    radii = []
    for cluster_points in clusters:
        envelope = spherical_envelope(cluster_points)
        radii.append(envelope[1])
    return statistics.mean(radii)
|