forked from facebookresearch/faiss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclient_server.py
executable file
·91 lines (69 loc) · 2.71 KB
/
client_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from multiprocessing.pool import ThreadPool
import faiss
from typing import List, Tuple
from . import rpc
############################################################
# Server implementation
############################################################
class SearchServer(rpc.Server):
    """Server-side wrapper that exposes a faiss index over RPC.

    Methods defined on this class are callable by name from a remote client
    (presumably via rpc.Server's name-based dispatch — see the rpc module).
    Any other attribute lookup is forwarded to the wrapped index.
    """

    def __init__(self, s: int, index: faiss.Index):
        rpc.Server.__init__(self, s)
        self.index = index
        # nprobe lives on the IVF part of the index; presumably this raises
        # if `index` contains no IVF index — TODO confirm in faiss docs.
        self.index_ivf = faiss.extract_index_ivf(index)

    def set_nprobe(self, nprobe: int) -> None:
        """Set the nprobe field of the IVF index extracted from the served index."""
        self.index_ivf.nprobe = nprobe

    def set_omp_num_threads(self, nt: int) -> None:
        """Set the number of OpenMP threads faiss uses in this server process.

        ClientIndex.set_omp_num_threads calls this method on every server.
        Without it the call would be forwarded to the index by __getattr__,
        and the index has no such method.
        """
        faiss.omp_set_num_threads(nt)

    def get_ntotal(self) -> int:
        """Return the number of vectors stored in the served index."""
        return self.index.ntotal

    def __getattr__(self, f):
        """Forward any attribute not found on the server to the wrapped index."""
        # __getattr__ only runs when normal lookup fails. If `index` itself is
        # missing (instance built without __init__, e.g. by copy/pickle),
        # evaluating self.index would re-enter __getattr__ forever.
        if f == "index":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {f!r}")
        # all other functions get forwarded to the index
        return getattr(self.index, f)
def run_index_server(index: faiss.Index, port: int, v6: bool = False):
    """Serve RPC requests for `index` on `port` forever.

    `rpc.run_server` receives a factory that wraps its argument `s` in a
    SearchServer sharing this single index. `v6` is passed through to
    `rpc.run_server` unchanged (presumably it selects IPv6 — confirm in rpc).
    """

    def make_server(s):
        # every server built by the factory wraps the same index object
        return SearchServer(s, index)

    rpc.run_server(make_server, port, v6=v6)
############################################################
# Client implementation
############################################################
class ClientIndex:
    """Client that fans operations out to a set of remote sub-indexes.

    Each sub-index is reached through an rpc.Client and searches a subset
    of the data (presumably a subset of the inverted lists). Per-server
    search results are merged into a single top-k answer.
    """

    def __init__(self, machine_ports: List[Tuple[str, int]], v6: bool = False):
        """Connect to each (host, port) pair in `machine_ports`.

        Raises:
            ValueError: if `machine_ports` yields no (host, port) pair.
        """
        self.sub_indexes = [
            rpc.Client(machine, port, v6) for machine, port in machine_ports
        ]
        self.ni = len(self.sub_indexes)
        if self.ni == 0:
            # ThreadPool(0) would also fail, but with a less helpful message.
            raise ValueError("ClientIndex needs at least one (host, port) pair")
        # pool of threads. Each thread manages one sub-index.
        self.pool = ThreadPool(self.ni)
        # test connection (also caches the total vector count)...
        self.ntotal = self.get_ntotal()
        self.verbose = False

    def set_nprobe(self, nprobe: int) -> None:
        """Set nprobe on every sub-index, in parallel."""
        self.pool.map(
            lambda idx: idx.set_nprobe(nprobe),
            self.sub_indexes
        )

    def set_omp_num_threads(self, nt: int) -> None:
        """Set the OpenMP thread count on every server, in parallel."""
        self.pool.map(
            lambda idx: idx.set_omp_num_threads(nt),
            self.sub_indexes
        )

    def get_ntotal(self) -> int:
        """Return the total number of vectors across all sub-indexes."""
        return sum(self.pool.map(
            lambda idx: idx.get_ntotal(),
            self.sub_indexes
        ))

    def search(self, x, k: int):
        """Search every sub-index for the k nearest neighbors of each row of x.

        Returns the merged (D, I) pair produced by faiss.ResultHeap. The heap
        is built with its default ordering, which presumably keeps the
        smallest distances — TODO confirm before using inner-product indexes.
        """
        rh = faiss.ResultHeap(x.shape[0], k)
        # imap yields the per-server (D, I) results in sub-index order.
        for Di, Ii in self.pool.imap(lambda idx: idx.search(x, k), self.sub_indexes):
            rh.add_result(Di, Ii)
        rh.finalize()
        return rh.D, rh.I

    def close(self) -> None:
        """Shut down the worker thread pool. The RPC connections are not touched."""
        self.pool.close()
        self.pool.join()