#!/usr/bin/env python3
"""
Create a list of regions by splitting a reference into chunks containing roughly equal
amounts of data from a set of bam files, using each bam's `.bai` index (the bam files
themselves are never read). Useful for submitting approximately equal-sized jobs to a cluster.
"""
import sys
import os
import argparse
import time
import logging
import struct
import numpy as np
from scipy import interpolate
DEFAULT_LOGGING_LEVEL = logging.INFO
MAX_LOGGING_LEVEL = logging.CRITICAL
def setup_logger(verbose_level):
fmt=('%(levelname)s %(asctime)s [%(module)s:%(lineno)s %(funcName)s] :: '
'%(message)s')
    logging.basicConfig(format=fmt, level=max(0, min(MAX_LOGGING_LEVEL,
                        DEFAULT_LOGGING_LEVEL - verbose_level*10)))
def Main(argv):
tic_total = time.time()
# parse arguments
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('bamfiles', metavar='BAMFILE', nargs='*', help="bam file(s); a matching BAMFILE.bai index must exist alongside each")
    parser.add_argument('-L', '--bam-list', nargs='*', help="file(s) listing bam paths, one per line ('#' starts a comment)")
parser.add_argument('-r', '--reference-fai', help="reference fasta index file", required=True)
    parser.add_argument('-s', '--target-data-size', default='100e6', help="target combined data size of bam files in each region, in bytes (scientific notation accepted)")
    parser.add_argument('--bai-interval-size', default=16384, type=int, help="size in base pairs of each interval in the bam index (bai)")
parser.add_argument('-v', '--verbose', action='count', default=0,
help="increase logging verbosity")
parser.add_argument('-q', '--quiet', action='count', default=0,
help="decrease logging verbosity")
args = parser.parse_args(argv)
# setup logger
setup_logger(verbose_level=args.verbose-args.quiet)
if argv is not None:
logging.warning('Using passed arguments: '+str(argv))
logging.info('args: '+str(args))
# additional argument parsing and datatype handling
if not args.bamfiles and not args.bam_list:
logging.error("Must provide an BAMFILE and/or --bam-list argument")
sys.exit(2)
    args.target_data_size = int(float(args.target_data_size))
logging.info('target-data-size: '+str(args.target_data_size)+' bytes')
# read bam-lists if provided
if args.bam_list:
for bamlistfile in args.bam_list:
with open(bamlistfile,'r') as fh:
for x in fh:
x = x.split('#')[0].strip()
if x:
args.bamfiles.append(x)
#logging.info('bam files: '+", ".join(args.bamfiles)) # output complete list of bam files being used
# read the reference fasta index
fai_chrom = []
fai_len = []
with open(args.reference_fai,'r') as fh:
for x in fh:
x = x.strip().split(sep='\t')
fai_chrom.append(str(x[0]))
fai_len.append(int(x[1]))
## read bai indexes, skipping bin info
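    # BAI layout (SAM/BAM index specification): magic "BAI\1", n_ref (int32),
    # then per reference: n_bin (int32) and the binning index (per bin: bin id
    # uint32, n_chunk int32, n_chunk begin/end virtual-offset pairs of 8 bytes
    # each), followed by n_intv (int32) and n_intv 64-bit ioffsets -- the
    # linear index, one virtual file offset per 16384 bp window. Only the
    # linear index is used below; the binning index is skipped over.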
# list by chrom of number of intervals
n_intvs = np.array([int(np.ceil(x/args.bai_interval_size)) for x in fai_len])
    # list by chrom of per-interval cumulative data sizes, summed over all bam files
icumsz = [] # cumulative size of data by interval
for i,n in enumerate(n_intvs):
icumsz.append(np.zeros((n,), dtype=np.int64))
for bamfn in args.bamfiles:
baifn = bamfn+'.bai'
with open(baifn,'rb') as fh:
logging.info("processing: "+baifn)
# filetype magic check
assert struct.unpack('4s', fh.read(4))[0] == b'BAI\x01'
# number of reference sequences (chroms)
n_ref = struct.unpack('i', fh.read(4))[0]
            assert n_ref == len(fai_len), "fasta index and bam index must have the same number of chroms"
for ci in range(n_ref):
# skip over the binning index
n_bin = struct.unpack('i', fh.read(4))[0]
for bini in range(n_bin):
bin_id = struct.unpack('I', fh.read(4))[0]
n_chunk = struct.unpack('i', fh.read(4))[0]
fh.seek(n_chunk*16, os.SEEK_CUR)
# read interval index
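                # Each ioffset is a BGZF virtual file offset (compressed file
                # offset << 16 | offset within the uncompressed block) of the
                # first alignment overlapping that window, so differences
                # between successive ioffsets only approximate the bytes of
                # data per window -- good enough for load balancing. Offsets
                # are taken relative to the chrom's first window and summed
                # over all bam files in icumsz.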
n_intv = struct.unpack('i', fh.read(4))[0]
if n_intv > 0:
ioff = np.array(struct.unpack(str(n_intv)+'Q', fh.read(n_intv*8)), dtype=np.int64)
                    # pad to the expected number of windows for this chrom; the +1 per
                    # padded window keeps the cumulative curve increasing past the last indexed window
                    while len(ioff) < len(icumsz[ci]):
                        ioff = np.append(ioff, ioff[-1]+1)
                    icumsz[ci] += ioff-ioff[0]
## make the list of regions
regions = []
for ci,chrom in enumerate(fai_chrom):
        # sanity check the last point if there is more than one
if len(icumsz[ci]) > 1:
assert icumsz[ci][-1] >= icumsz[ci][-2]
# tiny chroms just get 1 region
if len(icumsz[ci]) < 2:
regions.extend([ (fai_chrom[ci], 0, fai_len[ci]) ])
continue
ds = icumsz[ci]
pos = np.arange(0, ds.shape[0])*args.bai_interval_size
# estimate total data size for the chrom
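        # pos marks the start of each index window; the chromosome end lies
        # beyond the last window start, so allow linear extrapolation when
        # evaluating the cumulative-size curve at the chromosome length.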
f = interpolate.interp1d(pos, ds, fill_value='extrapolate', kind='linear')
ds_total = f([fai_len[ci]])[0]
num_regions = int(np.ceil(ds_total/args.target_data_size))
# approx equal LENGTH regions
# tmp = np.linspace(0, fai_len[ci], num=num_regions+1, endpoint=True, dtype=int)
# approx equal DATA SIZE regions
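        # Invert the map: interpolate position as a function of cumulative data
        # size, then sample it at num_regions+1 evenly spaced data-size values;
        # the resulting positions bound regions holding roughly equal amounts
        # (about target-data-size bytes) of bam data.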
f = interpolate.interp1d(ds, pos, fill_value='extrapolate', kind='linear')
dsx = np.linspace(0, ds_total, num=num_regions+1, endpoint=True, dtype=int)
tmp = f(dsx).astype(int)
# ensure we exactly hit the endpoints
tmp[0] = 0
tmp[-1] = fai_len[ci]
regions.extend([ (fai_chrom[ci], tmp[i], tmp[i+1]) for i in range(len(tmp)-1) ])
## Output regions file
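    # one region per line: chrom, start, end (tab-separated, 0-based,
    # end-exclusive); consecutive regions within a chrom share boundaries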
for r in regions:
print(*r, sep='\t')
logging.info("Number of chroms: {}".format(len(fai_len)))
logging.info("Number of splits: {}".format(len(regions)-len(fai_len)))
logging.info("Number of regions: {}".format(len(regions)))
logging.info("Done: {:.2f} sec elapsed".format(time.time()-tic_total))
return 0
#########################################################################
# Main loop hook... if run as script run main, else this is just a module
if __name__ == '__main__':
if 'TESTING_ARGS' in globals():
sys.exit(Main(argv=TESTING_ARGS))
else:
sys.exit(Main(argv=None))