如果 KNN 被淘汰了，那么取而代之的将是 ANN
点击上方“程序员大白”，选择“星标”公众号
重磅干货，第一时间送达
选自 | towardsdatascience
作者 | Marie Stephen Leo
转自 | 机器之心
编辑 | 小舟、杜伟
Spotify 的 ANNOY
Google 的 ScaNN
Facebook 的 Faiss
HNSW
import hnswlib
import numpy as np


def fit_hnsw_index(features, ef=100, M=16, save_index_file=False, space='l2'):
    """Build an HNSW approximate-nearest-neighbour index over ``features``.

    Parameters
    ----------
    features : sequence of equal-length float sequences
        The embeddings to index. Row ``i`` is stored under integer label ``i``.
    ef : int
        Used both as ``ef_construction`` when building the graph and as the
        query-time ``ef`` (set via ``set_ef``). Per hnswlib's guidance, the
        query-time ef should be larger than the ``k`` later passed to
        ``knn_query``.
    M : int
        HNSW graph-connectivity parameter, passed straight to ``init_index``.
    save_index_file : str or False
        If truthy, the built index is written to this path with ``save_index``.
    space : str
        Distance space for ``hnswlib.Index``: 'l2' (default), 'cosine' or 'ip'.

    Returns
    -------
    hnswlib.Index
        The populated index, ready for ``knn_query``.

    Raises
    ------
    ValueError
        If ``features`` is empty, since the embedding dimension cannot be
        inferred from it.
    """
    num_elements = len(features)
    if num_elements == 0:
        raise ValueError("features must contain at least one embedding")
    embedding_size = len(features[0])
    labels_index = np.arange(num_elements)

    # hnswlib requires the maximum number of elements to be declared up front.
    index = hnswlib.Index(space=space, dim=embedding_size)
    index.init_index(max_elements=num_elements, ef_construction=ef, M=M)
    index.add_items(features, labels_index)

    # Query-time ef controls the recall/speed trade-off.
    index.set_ef(ef)

    if save_index_file:
        index.save_index(save_index_file)
    return index
# Example usage of an index built by fit_hnsw_index: query the k nearest
# neighbours of every indexed embedding.
#     p = fit_hnsw_index(features)
#     ann_neighbor_indices, ann_distances = p.knn_query(features, k)
# Kept as a comment because `p`, `features` and `k` are not defined at this
# point in the module, so executing the call here would raise NameError.
# Imports
# --- Standard library: reading the raw data and timing ---
import gzip
import json
from datetime import datetime

# --- Third-party: data handling, plotting and progress bars ---
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

# --- Third-party: exact KNN vs approximate (HNSW) nearest-neighbour search ---
import hnswlib
from sklearn.neighbors import NearestNeighbors

# --- Text embeddings: English pre-trained fastText model ---
# download_model is called with if_exists='ignore', so an existing local copy
# is reused rather than re-downloaded; the .bin file is then loaded.
import fasttext.util
fasttext.util.download_model('en', if_exists='ignore')
ft = fasttext.load_model('cc.en.300.bin')
# Data: http://deepyeti.ucsd.edu/jianmo/amazon/
# The gzipped file is read line by line; each non-blank line is parsed as one
# JSON record (json.loads accepts bytes and ignores surrounding whitespace).
with gzip.open('meta_Cell_Phones_and_Accessories.json.gz') as f:
    data = [json.loads(line) for line in f if line.strip()]

# Pre-Processing: https://colab.research.google.com/drive/1Zv6MARGQcrBbLHyjPVVMZVnRWsRnVMpV#scrollTo=LgWrDtZ94w89
# Convert the list of records into a pandas DataFrame; missing fields become ''.
df = pd.DataFrame.from_dict(data)
df = df.fillna('')
# Filter unformatted rows: titles containing the literal text 'getTime'.
df = df[~df['title'].str.contains('getTime', regex=False)]
# Restrict to the 'Cell Phones & Accessories' main category.
df = df[df['main_cat'] == 'Cell Phones & Accessories']
df = df.reset_index(drop=True)
# Keep only the title column. .copy() makes df an independent frame, so adding
# the embedding column later does not trigger SettingWithCopyWarning.
df = df[['title']].copy()
# Check the df
print(df.shape)
df.head()
# Title embedding using fastText sentence vectors.
# fastText's get_sentence_vector raises ValueError on text containing '\n', so
# newlines inside titles are replaced with spaces before embedding.
df['emb'] = (
    df['title']
    .str.replace('\n', ' ', regex=False)
    .apply(ft.get_sentence_vector)
)
# Extract the embeddings column as a list of lists, the input format used by
# the nearest-neighbour benchmark below.
X = [vec.tolist() for vec in df['emb']]
# Benchmark exact KNN (scikit-learn) against HNSW ANN (hnswlib) over growing
# search spaces and several neighbourhood sizes, recording the time taken and
# how many of the exact neighbours the approximate search recovers.
from time import perf_counter  # monotonic clock, suited to measuring durations

# Number of products (search-space sizes) for the benchmark loop.
n_products = [1000, 10000, 100000, len(X)]
# Number of neighbours to retrieve for each search space.
n_neighbors = [10, 100]
# Metric results; one entry is appended per (products, k) iteration.
metrics = {'products': [], 'k': [], 'knn_time': [], 'ann_time': [], 'pct_overlap': []}

for products in tqdm(n_products):
    # Restrict the search space to the first `products` embeddings.
    features = X[:products]
    for k in tqdm(n_neighbors):
        # Exact KNN. Query the same `features` that were fitted, exactly as the
        # ANN step below does. Querying the full X here would inflate KNN time
        # for the smaller search spaces and make the comparison unfair.
        knn_start = perf_counter()
        nbrs = NearestNeighbors(n_neighbors=k, metric='euclidean').fit(features)
        knn_distances, knn_neighbor_indices = nbrs.kneighbors(features)
        metrics['knn_time'].append(perf_counter() - knn_start)

        # HNSW ANN: build the index and query every indexed product.
        ann_start = perf_counter()
        p = fit_hnsw_index(features, ef=k * 10)
        ann_neighbor_indices, ann_distances = p.knn_query(features, k)
        metrics['ann_time'].append(perf_counter() - ann_start)

        # Average fraction of each product's exact neighbours also found by ANN.
        metrics['pct_overlap'].append(np.mean([
            len(np.intersect1d(knn_row, ann_row)) / k
            for knn_row, ann_row in zip(knn_neighbor_indices, ann_neighbor_indices)
        ]))
        metrics['products'].append(products)
        metrics['k'].append(k)

metrics_df = pd.DataFrame(metrics)
metrics_df.to_csv('metrics_df.csv', index=False)
metrics_df
推荐阅读
关于程序员大白
程序员大白是一群哈工大、东北大学、西湖大学和上海交通大学的硕士博士运营维护的公众号，大家乐于分享高质量文章，喜欢总结知识，欢迎关注[程序员大白]，大家一起学习进步！
评论