diff --git a/pyserini/encode/__main__.py b/pyserini/encode/__main__.py index 6fe5a8403..afc4ded9a 100644 --- a/pyserini/encode/__main__.py +++ b/pyserini/encode/__main__.py @@ -33,9 +33,8 @@ "openai-api": OpenAIDocumentEncoder, "auto": AutoDocumentEncoder, } -ALLOWED_POOLING_OPTS = ["cls","mean"] -def init_encoder(encoder, encoder_class, device): +def init_encoder(encoder, encoder_class, device, pooling, l2_norm, prefix): _encoder_class = encoder_class # determine encoder_class @@ -52,6 +51,7 @@ def init_encoder(encoder, encoder_class, device): # if none of the class keyword was matched, # use the AutoDocumentEncoder if encoder_class is None: + _encoder_class = "auto" encoder_class = AutoDocumentEncoder # prepare arguments to encoder class @@ -60,6 +60,8 @@ def init_encoder(encoder, encoder_class, device): kwargs.update(dict(pooling='mean', l2_norm=True)) if (_encoder_class == "contriever") or ("contriever" in encoder): kwargs.update(dict(pooling='mean', l2_norm=False)) + if (_encoder_class == "auto"): + kwargs.update(dict(pooling=pooling, l2_norm=l2_norm, prefix=prefix)) return encoder_class(**kwargs) @@ -117,19 +119,16 @@ def parse_args(parser, commands): default='cuda:0', required=False) encoder_parser.add_argument('--fp16', action='store_true', default=False) encoder_parser.add_argument('--add-sep', action='store_true', default=False) - encoder_parser.add_argument('--pooling', type=str, default='cls', help='for auto classes, allow the ability to dictate pooling strategy', required=False) + encoder_parser.add_argument('--pooling', type=str, default='cls', help='for auto classes, allow the ability to dictate pooling strategy', choices=['cls', 'mean'], required=False) + encoder_parser.add_argument('--l2-norm', action='store_true', help='whether to normalize embedding', default=False, required=False) + encoder_parser.add_argument('--prefix', type=str, help='prefix of document input', default=None, required=False) encoder_parser.add_argument('--use-openai', help='use OpenAI text-embedding-ada-002 to retreive embeddings', action='store_true', default=False) encoder_parser.add_argument('--rate-limit', type=int, help='rate limit of the requests per minute for OpenAI embeddings', default=3500, required=False) args = parse_args(parser, commands) delimiter = args.input.delimiter.replace("\\n", "\n") # argparse would add \ prior to the passed '\n\n' - encoder = init_encoder(args.encoder.encoder, args.encoder.encoder_class, device=args.encoder.device) - if type(encoder).__name__ == "AutoDocumentEncoder": - if args.encoder.pooling in ALLOWED_POOLING_OPTS: - encoder.pooling = args.encoder.pooling - else: - raise ValueError(f"Only allowed to use pooling types {ALLOWED_POOLING_OPTS}. You entered {args.encoder.pooling}") + encoder = init_encoder(args.encoder.encoder, args.encoder.encoder_class, device=args.encoder.device, pooling=args.encoder.pooling, l2_norm=args.encoder.l2_norm, prefix=args.encoder.prefix) if args.output.to_faiss: embedding_writer = FaissRepresentationWriter(args.output.embeddings, dimension=args.encoder.dimension) else: diff --git a/pyserini/encode/_auto.py b/pyserini/encode/_auto.py index 5e8cf6cd1..d61ed5c26 100644 --- a/pyserini/encode/_auto.py +++ b/pyserini/encode/_auto.py @@ -22,7 +22,7 @@ class AutoDocumentEncoder(DocumentEncoder): - def __init__(self, model_name, tokenizer_name=None, device='cuda:0', pooling='cls', l2_norm=False): + def __init__(self, model_name, tokenizer_name=None, device='cuda:0', pooling='cls', l2_norm=False, prefix=None): self.device = device self.model = AutoModel.from_pretrained(model_name) self.model.to(self.device) @@ -33,8 +33,11 @@ def __init__(self, model_name, tokenizer_name=None, device='cuda:0', pooling='cl self.has_model = True self.pooling = pooling self.l2_norm = l2_norm + self.prefix = prefix def encode(self, texts, titles=None, max_length=256, add_sep=False, **kwargs): + if self.prefix is not None: + texts = [f'{self.prefix} {text}' for text in texts] shared_tokenizer_kwargs = dict( max_length=max_length, truncation=True, diff --git a/pyserini/encode/query.py b/pyserini/encode/query.py index 8eb0b15ae..d1e04a784 100644 --- a/pyserini/encode/query.py +++ b/pyserini/encode/query.py @@ -24,7 +24,7 @@ from pyserini.encode import UniCoilQueryEncoder, SpladeQueryEncoder, OpenAIQueryEncoder -def init_encoder(encoder, device): +def init_encoder(encoder, device, pooling, l2_norm, prefix): if 'dpr' in encoder.lower(): return DprQueryEncoder(encoder, device=device) elif 'tct' in encoder.lower(): @@ -40,7 +40,7 @@ def init_encoder(encoder, device): elif 'openai-api' in encoder.lower(): return OpenAIQueryEncoder() else: - return AutoQueryEncoder(encoder, device=device) + return AutoQueryEncoder(encoder, device=device, pooling=pooling, l2_norm=l2_norm, prefix=prefix) if __name__ == '__main__': @@ -54,9 +54,14 @@ def init_encoder(encoder, device): parser.add_argument('--device', type=str, help='device cpu or cuda [cuda:0, cuda:1...]', default='cpu', required=False) parser.add_argument('--max-length', type=int, help='max length', default=256, required=False) + parser.add_argument('--pooling', type=str, help='pooling strategy', default='cls', choices=['cls', 'mean'], + required=False) + parser.add_argument('--l2-norm', action='store_true', help='whether to normalize embedding', default=False, + required=False) + parser.add_argument('--prefx', type=str, help='prefix query input', default=None, required=False) args = parser.parse_args() - encoder = init_encoder(args.encoder, device=args.device) + encoder = init_encoder(args.encoder, device=args.device, pooling=args.pooling, l2_norm=args.l2_norm, prefix=args.prefx) query_iterator = DefaultQueryIterator.from_topics(args.topics) is_sparse = False diff --git a/pyserini/search/faiss/__main__.py b/pyserini/search/faiss/__main__.py index 39e9b5690..2700fdd6b 100644 --- a/pyserini/search/faiss/__main__.py +++ b/pyserini/search/faiss/__main__.py @@ -47,6 +47,11 @@ def define_dsearch_args(parser): parser.add_argument('--encoder', type=str, metavar='path to query encoder checkpoint or encoder name', required=False, help="Path to query encoder pytorch checkpoint or hgf encoder model name") + parser.add_argument('--pooling', type=str, metavar='pooling strategy', required=False, default='cls', + choices=['cls', 'mean'], + help="Pooling strategy for query encoder") + parser.add_argument('--l2-norm', action='store_true', help='whether to normalize embedding', default=False, + required=False) parser.add_argument('--tokenizer', type=str, metavar='name or path', required=False, help="Path to a hgf tokenizer name or path") @@ -85,7 +90,7 @@ def define_dsearch_args(parser): help="Set efSearch for HNSW index") -def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, encoded_queries, device, prefix, max_length): +def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, encoded_queries, device, max_length, pooling, l2_norm, prefix): encoded_queries_map = { 'msmarco-passage-dev-subset': 'tct_colbert-msmarco-passage-dev-subset', 'dpr-nq-dev': 'dpr_multi-nq-dev', @@ -126,6 +131,7 @@ def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, enco # if none of the class keyword was matched, # use the AutoQueryEncoder if encoder_class is None: + _encoder_class = "auto" encoder_class = AutoQueryEncoder # prepare arguments to encoder class @@ -136,6 +142,8 @@ def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, enco kwargs.update(dict(pooling='mean', l2_norm=False)) if (_encoder_class == "openai-api") or ("openai" in encoder): kwargs.update(dict(max_length=max_length)) + if (_encoder_class == "auto"): + kwargs.update(dict(pooling=pooling, l2_norm=l2_norm, prefix=prefix)) return encoder_class(**kwargs) if encoded_queries: @@ -188,7 +196,7 @@ def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, enco topics = query_iterator.topics query_encoder = init_query_encoder( - args.encoder, args.encoder_class, args.tokenizer, args.topics, args.encoded_queries, args.device, args.query_prefix, args.max_length) + args.encoder, args.encoder_class, args.tokenizer, args.topics, args.encoded_queries, args.device, args.max_length, args.pooling, args.l2_norm, args.query_prefix) if args.pca_model: query_encoder = PcaEncoder(query_encoder, args.pca_model) kwargs = {}