Skip to content

Commit

Permalink
fix: cpd floating point comparison issue (#172)
Browse files Browse the repository at this point in the history
* fix: cpd floating point comparison issue

* fix: imports ordering
  • Loading branch information
ivandkh authored Sep 13, 2022
1 parent 4262609 commit c42d0e2
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion causalnex/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"""
import copy
import inspect
import math
import re
import types
from typing import Any, Callable, Dict, Hashable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -205,7 +206,7 @@ def _do(self, observation: str, state: Dict[Hashable, float]):
Raises:
ValueError: if states do not match original states of the node, or probabilities do not sum to 1.
"""
if sum(state.values()) != 1.0:
if not math.isclose(sum(state.values()), 1.0):
raise ValueError("The cpd for the provided observation must sum to 1")

if max(state.values()) > 1.0 or min(state.values()) < 0:
Expand Down

0 comments on commit c42d0e2

Please sign in to comment.