Skip to content

Commit

Permalink
Preseq: fix parsing nan and add strict type hints (#3038)
Browse files Browse the repository at this point in the history
* Remove debug print

* Preseq: strict type hint and fix potential bug

* More typing

* More typing

* Fix parsing nan
  • Loading branch information
vladsavelyev authored Jan 4, 2025
1 parent 491502c commit 3834ac5
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 66 deletions.
138 changes: 74 additions & 64 deletions multiqc/modules/preseq/preseq.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import logging
from typing import List
from typing import Dict, List, Optional, Tuple

import numpy as np

from multiqc import config
from multiqc.base_module import BaseMultiqcModule, ModuleNoSamplesFound
from multiqc.plots import linegraph
from multiqc.plots.plotly.line import Series, Marker
from multiqc.plots.plotly.line import LinePlotConfig, Series, Marker
from multiqc.plots.plotly.plot import PConfig
from multiqc.utils import mqc_colour

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -94,8 +95,10 @@ def __init__(self):
reads_data = dict()
bases_data = dict()
for f in self.find_log_files("preseq"):
sample_data_raw, sample_data_is_bases = _parse_preseq_logs(f)
if sample_data_raw is None:
try:
sample_data_raw, sample_data_is_bases = _parse_preseq_logs(f)
except ValueError as e:
log.warning(f"Skipping {f['fn']}: {e}")
continue

if f["s_name"] in sample_data_raw:
Expand Down Expand Up @@ -131,7 +134,7 @@ def __init__(self):
if reads_data:
self._make_preseq_length_trimmed_plot(reads_data, False)

def _make_preseq_length_trimmed_plot(self, data_raw, is_basepairs):
def _make_preseq_length_trimmed_plot(self, data_raw: Dict[str, Dict[float, float]], is_basepairs: bool) -> None:
"""Generate the preseq plot.
For Y axis, plot coverages if `config.preseq.read_length` and `config.preseq.genome_size`
Expand All @@ -140,56 +143,65 @@ def _make_preseq_length_trimmed_plot(self, data_raw, is_basepairs):
For X axis, plot counts (or base pairs) in X axis (unless coverages requested
explicitly in `config.preseq.x_axis`).
"""
counts_in_1x = _get_counts_in_1x(is_basepairs)
x_axis = getattr(config, "preseq", {}).get("x_axis", "counts")
y_axis = getattr(config, "preseq", {}).get("y_axis", "coverage" if counts_in_1x else "counts")

# Modify counts
d_cnts = {sn: _modify_raw_data(sample_data, is_basepairs) for sn, sample_data in data_raw.items()}

# Convert counts (base pairs) -> depths
d_covs = {
sn: _counts_to_coverages(sample_data, counts_in_1x) if counts_in_1x else None
for sn, sample_data in data_raw.items()
}

# Prepare final dataset for plotting
data = dict()
for sn, s_d_cnts, s_d_covs in zip(d_cnts, d_cnts.values(), d_covs.values()):
keys = s_d_covs.keys() if x_axis == "coverage" else s_d_cnts.keys()
values = s_d_covs.values() if y_axis == "coverage" else s_d_cnts.values()
data[sn] = dict(zip(keys, values))
# Plot depths of counts
counts_in_1x: Optional[float] = _calc_count_in_1x(is_basepairs)
x_axis = getattr(config, "preseq", {}).get("x_axis", "counts")
y_axis = getattr(config, "preseq", {}).get("y_axis", "coverage" if counts_in_1x is not None else "counts")

# Count maximum values to draw the "ideal" line
max_y_raw, max_sn = max((max(sd.values()), sn) for sn, sd in data_raw.items())
max_y_cnt = list(_modify_raw_data({max_y_raw: max_y_raw}, is_basepairs).items())[0][0]
max_y_cov = list(_counts_to_coverages({max_y_raw: max_y_raw}, counts_in_1x).items())[0][0]
max_y = max_y_cov if y_axis == "coverage" else max_y_cnt
max_yx = max_y_cov if x_axis == "coverage" else max_y_cnt

data: Dict[str, Dict[float, float]]
max_y: float
max_yx: float
if counts_in_1x is not None:
# Convert counts (base pairs) -> depths
d_covs = {sn: _counts_to_coverages(sample_data, counts_in_1x) for sn, sample_data in data_raw.items()}
# Prepare final dataset for plotting
data = dict()
for sn, s_d_covs in zip(d_cnts, d_covs.values()):
keys = s_d_covs.keys()
values = s_d_covs.values()
data[sn] = dict(zip(keys, values))

# Count maximum values to draw the "ideal" line
max_y_cov = list(_counts_to_coverages({max_y_raw: max_y_raw}, counts_in_1x).items())[0][0]
max_y = max_y_cov
max_yx = max_y_cov

else:
d_covs = {sn: {} for sn in data_raw.keys()}
# Prepare final dataset for plotting
data = d_cnts
# Count maximum values to draw the "ideal" line
max_y_cnt = list(_modify_raw_data({max_y_raw: max_y_raw}, is_basepairs).items())[0][0]
max_y = max_y_cnt
max_yx = max_y_cnt

# Preparing axis and tooltip labels
x_suffix, y_tt_lbl, x_axis_name, y_suffix, x_tt_lbl, y_axis_name = _prepare_labels(
is_basepairs, max_y_cov, x_axis, y_axis
is_basepairs, max_y, x_axis, y_axis
)

name = "Complexity curve"
description = ""
section_id = "preseq_plot"
pconfig = {
"id": "preseq_complexity_plot",
"title": "Preseq: Complexity curve",
"xlab": x_axis_name,
"ylab": y_axis_name,
"xmin": 0,
"ymin": 0,
"tt_label": "<b>" + y_tt_lbl + "</b>: " + x_tt_lbl,
"xsuffix": x_suffix,
"ysuffix": y_suffix,
"extra_series": [],
}
pconfig = LinePlotConfig(
id="preseq_complexity_plot",
title="Preseq: Complexity curve",
xlab=x_axis_name,
ylab=y_axis_name,
xmin=0,
ymin=0,
tt_label="<b>" + y_tt_lbl + "</b>: " + x_tt_lbl,
xsuffix=x_suffix,
ysuffix=y_suffix,
)
if not is_basepairs:
pconfig["title"] += " (molecule count)"
pconfig["id"] += "_molecules"
pconfig.title += " (molecule count)"
pconfig.id += "_molecules"
name += " (molecule count)"
section_id += "_molecules"

Expand All @@ -198,8 +210,8 @@ def _make_preseq_length_trimmed_plot(self, data_raw, is_basepairs):
real_vals_all, real_vals_unq = _prep_real_counts(
real_cnts_all, real_cnts_unq, is_basepairs, counts_in_1x, x_axis, y_axis
)
pconfig["extra_series"].extend(
_real_counts_to_plot_series(data, real_vals_unq, real_vals_all, x_suffix, y_suffix, y_tt_lbl)
pconfig.extra_series = _real_counts_to_plot_series(
data, real_vals_unq, real_vals_all, x_suffix, y_suffix, y_tt_lbl
)
if real_vals_unq:
description += "<p>Points show read count versus deduplicated read counts (externally calculated).</p>"
Expand All @@ -210,17 +222,17 @@ def _make_preseq_length_trimmed_plot(self, data_raw, is_basepairs):
if getattr(config, "preseq", {}).get("notrim", False) is not True:
max_y *= 0.8
max_yx *= 0.8
max_x = 0
max_x = 0.0
for x in sorted(list(data[max_sn].keys())):
max_x = max(max_x, x)
if data[max_sn][x] > max_y and x > real_vals_all.get(max_sn, 0) and x > real_vals_unq.get(max_sn, 0):
break
pconfig["xmax"] = max_x
pconfig.xmax = max_x
description += "<p>Note that the x-axis is trimmed at the point where all the datasets \
show 80% of their maximum y-value, to avoid ridiculous scales.</p>"

# Plot perfect library as dashed line
pconfig["extra_series"].append(
pconfig.extra_series.append(
Series(
path_in_cfg=("perfect_library",),
name="A perfect library where each read is unique",
Expand Down Expand Up @@ -267,7 +279,7 @@ def _parse_real_counts(self, sample_names):
return real_counts_total, real_counts_unique


def _parse_preseq_logs(f):
def _parse_preseq_logs(f) -> Tuple[Dict[float, float], bool]:
"""Go through log file looking for preseq output"""

lines = f["f"].splitlines()
Expand All @@ -281,16 +293,18 @@ def _parse_preseq_logs(f):
elif header.startswith("total_reads distinct_reads"):
pass
else:
log.debug(f"First line of preseq file {f['fn']} did not look right")
return None, None
raise ValueError(f"First line of preseq file does not look right: {header}")

data = dict()
data: Dict[float, float] = dict()
for line in lines:
s = line.split()
# Sometimes the Expected_distinct count drops to 0, not helpful
if float(s[1]) == 0 and float(s[0]) > 0:
continue
data[float(s[0])] = float(s[1])
x, y = float(s[0]), float(s[1])
if not np.isfinite(y):
continue
data[x] = y

return data, data_is_bases

Expand All @@ -302,28 +316,25 @@ def _modify_raw_data(sample_data, is_basepairs):
return {_modify_raw_val(x, is_basepairs): _modify_raw_val(y, is_basepairs) for x, y in sample_data.items()}


def _modify_raw_val(val, is_basepairs):
def _modify_raw_val(val: float, is_basepairs: bool) -> float:
"""Modify counts or base pairs according to `read_count_multiplier`
or `base_count_multiplier`.
"""
return float(val) * (config.base_count_multiplier if is_basepairs else config.read_count_multiplier)


def _counts_to_coverages(sample_data, counts_in_1x):
def _counts_to_coverages(sample_data: Dict[float, float], counts_in_1x: float) -> Dict[float, float]:
"""If the user specified read length and genome size in the config,
convert the raw counts/bases into the depth of coverage.
"""
if not counts_in_1x:
return {None: None}

return {_count_to_coverage(x, counts_in_1x): _count_to_coverage(y, counts_in_1x) for x, y in sample_data.items()}


def _count_to_coverage(val, counts_in_1x):
def _count_to_coverage(val: float, counts_in_1x: float) -> float:
return val / counts_in_1x


def _get_counts_in_1x(data_is_basepairs):
def _calc_count_in_1x(data_is_basepairs: bool) -> Optional[float]:
"""Read length and genome size from the config and calculate
the approximate number of counts (or base pairs) in 1x of depth
"""
Expand Down Expand Up @@ -352,19 +363,18 @@ def _get_counts_in_1x(data_is_basepairs):
return genome_size
elif read_length:
return genome_size / read_length
else:
return None
return None


def _prepare_labels(is_basepairs, max_y_cov, x_axis, y_axis):
def _prepare_labels(is_basepairs: bool, max_y: float, x_axis: str, y_axis: str) -> Tuple[str, str, str, str, str, str]:
cov_suffix = "x"

cov_lbl = None
cov_lbl: str = ""
if x_axis == "coverage" or y_axis == "coverage":
cov_precision = "2"
if max_y_cov > 30: # no need to be so precise when the depth numbers are high
if max_y > 30: # no need to be so precise when the depth numbers are high
cov_precision = "1"
if max_y_cov > 300: # when the depth are very high, decimal digits are excessive
if max_y > 300: # when the depth are very high, decimal digits are excessive
cov_precision = "0"
cov_lbl = "{value:,." + cov_precision + "f}x"

Expand Down
2 changes: 0 additions & 2 deletions multiqc/plots/plotly/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,6 @@ def create_figure(
for series in self.lines:
xs = [x[0] for x in series.pairs]
ys = [x[1] for x in series.pairs]
if series.dash:
print(series)
params: Dict[str, Any] = {
"showlegend": series.showlegend,
"line": {
Expand Down

0 comments on commit 3834ac5

Please sign in to comment.