suzuzusu日記

(´・ω・`)

Navigable Small Worldによる近似最近傍探索

Small World Networkのグラフ特性を利用したNavigable Small World(NSW)というグラフベースの近似最近傍探索をjuliaで実装します

Navigable Small Worldとは?

f:id:suzuzusu:20201214014102p:plain

上記の画像のようにSmall World Networkの特性を持つグラフベースの検索インデックスからqueryに対して近傍のノードを返すアルゴリズムです。 これを階層的に拡張したHierarchical Navigable Small World(HNSW)は非常に性能が良い近似最近傍探索アルゴリズムとして知られています。

実装

まず、n次元のデータ data と隣接ノード friend を持つ Node 構造体を作ります。

using Random
using LinearAlgebra
using DataStructures
using Base

mutable struct Node
    data
    friend::Set{Node}
end

function show(io::IO, n::Node)
    friend_str = join(map((x) -> string(x.data), collect(n.friend)), ", ")
    println(io, "data: {", n.data, "}, friend: {", friend_str, "}")
end

グラフ上のノードにおいて、queryに対してk近傍のデータを検索する knn_search を実装します。

function knn_search(nodes::Vector{Node}, q, m::Int, k::Int)
    visited_set = Set()
    canditates = PriorityQueue()
    result = PriorityQueue()
    for _ in 1:m
        if length(visited_set) == length(nodes)
            break
        end
        tmp_result = Vector{Node}()
        while true
            ri = rand(1:length(nodes), 1)[1]
            node = nodes[ri]
            if !in(node, visited_set)
                push!(visited_set, node)
                enqueue!(canditates, node=>norm(node.data .- q))
                push!(tmp_result, node)
                break
            end
        end
        while true
            if length(canditates) == 0
                break
            end
            c = dequeue!(canditates)
            c_d = norm(c.data .- q)
            result_collect = collect(result)
            if length(result_collect) >= k
                n = result_collect[k]
                if c_d >= n[2]
                    break
                end
            end
            for node in c.friend
                if !in(node, visited_set)
                    push!(visited_set, node)
                    enqueue!(canditates, node=>norm(node.data .- q))
                    push!(tmp_result, node)
                end
            end
        end
        for node in tmp_result
            enqueue!(result, node=>norm(node.data .- q))
        end
    end
    result_collect = collect(result)[1:k]
    result_v = Vector()
    for r in result_collect
        push!(result_v, r[1])
    end
    return result_v
end

新しいノードの追加は knn_search を使ってk近傍のノードを接続します。複数のノード群からnswを構築する nsw_build も作成しておきます。 初期の knn_search によって長距離のリンクが作られることがポイントです。

function nearest_neighbor_insert(nodes, new_node, f, w)
    neighbors = knn_search(nodes, new_node.data, w, f)
    for node in neighbors
        push!(node.friend, new_node)
        push!(new_node.friend, node)
    end
end

function nsw_build(nodes, f, w)
    for new_node in nodes
        nearest_neighbor_insert(nodes, new_node, f, w)
    end
end

queryに対する最近傍探索は隣接ノードから最も近いノードをgreedyに探索していきます。

function greedy_searh(nodes, q)
    ri = rand(1:length(nodes), 1)[1]
    near_node = nodes[ri]
    min_d = norm(near_node.data .- q)
    while true
        break_flg = true
        for node in near_node.friend 
            d = norm(node.data - q)
            if d < min_d
                min_d = d
                near_node = node
                break_flg = false
            end
        end
        if break_flg
            break
        end
    end
    return near_node
end

以上で実装は終わりです。 試しに以下のように2次元の範囲[0, 1)の乱数データ1000個に対してテストしてみます。

Random.seed!(1234)
n = 1000 # number of node
dim = 2 # dimension

nodes = [Node(rand(dim), Set()) for i in 1:1000]
nsw_build(nodes, 10, 10)

q = rand(dim)
println("query: ", q)
res = greedy_searh(nodes, q)
println("result: ", res.data)
l2 = norm(res.data .- q)
println("l2 distance: ", l2)
query: [0.10217918147914129, 0.6093148458481783]
result: [0.09006444691972337, 0.6142456375838046]
l2 distance: 0.013079736258246004

queryに対して近傍のデータが検索できていることが分かります。

gist

こちらにコードをおいておきます

gist.github.com

参考