
K-Means 算法
K-Means 是一种无监督学习算法,用于将数据点划分为 k
个不同的簇。下面是使用 PyTorch 实现 K-Means 算法的代码示例:
import torch
def kmeans(X, k, max_iter=100):
"""
K-Means 算法的 PyTorch 实现
参数:
X (torch.Tensor): 输入数据,形状为 (n_samples, n_features)
k (int): 簇的数量
max_iter (int): 最大迭代次数
返回:
torch.Tensor: 每个数据点所属的簇的索引
torch.Tensor: 每个簇的中心点
"""
# 随机初始化中心点
n_samples, n_features = X.shape
indices = torch.randperm(n_samples)[:k]
centroids = X[indices]
for _ in range(max_iter):
# 计算每个数据点到每个中心点的距离
distances = torch.cdist(X, centroids)
# 找到每个数据点最近的中心点
labels = torch.argmin(distances, dim=1)
# 更新中心点
new_centroids = torch.zeros(k, n_features, dtype=X.dtype, device=X.device)
for i in range(k):
cluster_points = X[labels == i]
if len(cluster_points) > 0:
new_centroids[i] = torch.mean(cluster_points, dim=0)
else:
# 如果某个簇为空,重新随机初始化该中心点
new_centroids[i] = X[torch.randint(0, n_samples, (1,))]
# 判断是否收敛
if torch.allclose(centroids, new_centroids):
break
centroids = new_centroids
return labels, centroids
你可以使用以下方式调用这个函数:
# 生成一些示例数据
n_samples = 100
n_features = 2
X = torch.randn(n_samples, n_features)
# 运行 K-Means 算法
k = 3
labels, centroids = kmeans(X, k)
print("每个数据点所属的簇的索引:", labels)
print("每个簇的中心点:", centroids)
代码解释:
- 初始化中心点:从输入数据中随机选择
k
个点作为初始中心点。 - 迭代更新:
- 计算每个数据点到每个中心点的距离。
- 找到每个数据点最近的中心点,并将其分配到该簇。
- 更新每个簇的中心点,即计算该簇中所有数据点的平均值。
- 收敛判断:如果中心点不再发生变化,则认为算法收敛,停止迭代。
注意事项:
- 该实现使用欧几里得距离来计算数据点之间的距离。
- 如果某个簇为空,我们重新随机初始化该簇的中心点,以避免出现中心点为零向量的情况。

