Navigable Small Worldによる近似最近傍探索
Small World Networkのグラフ特性を利用したNavigable Small World(NSW)というグラフベースの近似最近傍探索をjuliaで実装します
Navigable Small Worldとは?
上記の画像のようにSmall World Networkの特性を持つグラフベースの検索インデックスからqueryに対して近傍のノードを返すアルゴリズムです。 これを階層的に拡張したHierarchical Navigable Small World(HNSW)は非常に性能が良い近似最近傍探索アルゴリズムとして知られています。
実装
まず、n次元のデータ data
と隣接ノード friend
を持つ Node
構造体を作ります。
using Random using LinearAlgebra using DataStructures using Base mutable struct Node data friend::Set{Node} end function show(io::IO, n::Node) friend_str = join(map((x) -> string(x.data), collect(n.friend)), ", ") println(io, "data: {", n.data, "}, friend: {", friend_str, "}") end
グラフ上のノードにおいて、queryに対してk近傍のデータを検索する knn_search
を実装します。
function knn_search(nodes::Vector{Node}, q, m::Int, k::Int) visited_set = Set() canditates = PriorityQueue() result = PriorityQueue() for _ in 1:m if length(visited_set) == length(nodes) break end tmp_result = Vector{Node}() while true ri = rand(1:length(nodes), 1)[1] node = nodes[ri] if !in(node, visited_set) push!(visited_set, node) enqueue!(canditates, node=>norm(node.data .- q)) push!(tmp_result, node) break end end while true if length(canditates) == 0 break end c = dequeue!(canditates) c_d = norm(c.data .- q) result_collect = collect(result) if length(result_collect) >= k n = result_collect[k] if c_d >= n[2] break end end for node in c.friend if !in(node, visited_set) push!(visited_set, node) enqueue!(canditates, node=>norm(node.data .- q)) push!(tmp_result, node) end end end for node in tmp_result enqueue!(result, node=>norm(node.data .- q)) end end result_collect = collect(result)[1:k] result_v = Vector() for r in result_collect push!(result_v, r[1]) end return result_v end
新しいノードの追加は knn_search
を使ってk近傍のノードを接続します。複数のノード群からnswを構築する nsw_build
も作成しておきます。
初期の knn_search
によって長距離のリンクが作られることがポイントです。
function nearest_neighbor_insert(nodes, new_node, f, w) neighbors = knn_search(nodes, new_node.data, w, f) for node in neighbors push!(node.friend, new_node) push!(new_node.friend, node) end end function nsw_build(nodes, f, w) for new_node in nodes nearest_neighbor_insert(nodes, new_node, f, w) end end
queryに対する最近傍探索は隣接ノードから最も近いノードをgreedyに探索していきます。
function greedy_searh(nodes, q) ri = rand(1:length(nodes), 1)[1] near_node = nodes[ri] min_d = norm(near_node.data .- q) while true break_flg = true for node in near_node.friend d = norm(node.data - q) if d < min_d min_d = d near_node = node break_flg = false end end if break_flg break end end return near_node end
以上で実装は終わりです。 試しに以下のように2次元の範囲[0, 1)の乱数データ1000個に対してテストしてみます。
Random.seed!(1234) n = 1000 # number of node dim = 2 # dimension nodes = [Node(rand(dim), Set()) for i in 1:1000] nsw_build(nodes, 10, 10) q = rand(dim) println("query: ", q) res = greedy_searh(nodes, q) println("result: ", res.data) l2 = norm(res.data .- q) println("l2 distance: ", l2)
query: [0.10217918147914129, 0.6093148458481783] result: [0.09006444691972337, 0.6142456375838046] l2 distance: 0.013079736258246004
queryに対して近傍のデータが検索できていることが分かります。
gist
こちらにコードをおいておきます
参考
- Malkov, Yury, et al. "Approximate nearest neighbor algorithm based on navigable small world graphs." Information Systems 45 (2014): 61-68.
- Malkov, Yury A., and Dmitry A. Yashunin. "Efficient and robust approximate nearest neighbor search using hierarchical navigable small world graphs." IEEE transactions on pattern analysis and machine intelligence (2018).