Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add unused node check for compile_dag_node #49382

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
debugging for data timeout
Signed-off-by: Lydia <wenqix3@uw.edu>
  • Loading branch information
LydiaXwQ committed Dec 20, 2024
commit ad717e566e2ecda18612bcf43147375c83cf56e5
112 changes: 56 additions & 56 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,41 +929,41 @@ def _add_node(self, node: "ray.dag.DAGNode") -> None:
self.dag_node_to_idx[node] = idx
self.counter += 1

def find_unused_input_attributes(
self, output_node: "ray.dag.MultiOutputNode", input_attributes: Set[str]
) -> Set[str]:
"""
This is the helper function to handle_unused_attributes.
Traverse the DAG backwards from the output node to find unused attributes.
Args:
output_node: The starting node for the traversal.
input_attributes: A set of attributes accessed by the InputNode.
Returns:
A set:
- unused_attributes: A set of attributes that are unused.
"""
from ray.dag import InputAttributeNode

used_attributes: Set[str] = set()
# Keep track of visited nodes during backtracking.
visited_nodes: Set["ray.dag.MultiOutputNode"] = set()

# Traverse backwards from the output node to find all used inputs.
def traverse(node: "ray.dag.MultiOutputNode"):
if node in visited_nodes:
return
visited_nodes.add(node)

if isinstance(node, InputAttributeNode):
used_attributes.add(node.key)

for upstream_node in node._upstream_nodes:
traverse(upstream_node)

traverse(output_node)

unused_attributes = input_attributes - used_attributes
return unused_attributes
# def find_unused_input_attributes(
# self, output_node: "ray.dag.MultiOutputNode", input_attributes: Set[str]
# ) -> Set[str]:
# """
# This is the helper function to handle_unused_attributes.
# Traverse the DAG backwards from the output node to find unused attributes.
# Args:
# output_node: The starting node for the traversal.
# input_attributes: A set of attributes accessed by the InputNode.
# Returns:
# A set:
# - unused_attributes: A set of attributes that are unused.
# """
# from ray.dag import InputAttributeNode

# used_attributes: Set[str] = set()
# # Keep track of visited nodes during backtracking.
# visited_nodes: Set["ray.dag.MultiOutputNode"] = set()

# # Traverse backwards from the output node to find all used inputs.
# def traverse(node: "ray.dag.MultiOutputNode"):
# if node in visited_nodes:
# return
# visited_nodes.add(node)

# if isinstance(node, InputAttributeNode):
# used_attributes.add(node.key)

# for upstream_node in node._upstream_nodes:
# traverse(upstream_node)

# traverse(output_node)

# unused_attributes = input_attributes - used_attributes
# return unused_attributes

def _preprocess(self) -> None:
"""Before compiling, preprocess the DAG to build an index from task to
Expand Down Expand Up @@ -994,10 +994,10 @@ def _preprocess(self) -> None:
if isinstance(task.dag_node, InputNode):
assert self.input_task_idx is None, "More than one InputNode found"
self.input_task_idx = idx
# handle_unused_attributes:
# Save input attributes in a set.
input_node = task.dag_node
input_attributes.update(input_node.input_attribute_nodes.keys())
# # handle_unused_attributes:
# # Save input attributes in a set.
# input_node = task.dag_node
# input_attributes.update(input_node.input_attribute_nodes.keys())
elif isinstance(task.dag_node, InputAttributeNode):
self.input_attr_task_idxs.append(idx)

Expand Down Expand Up @@ -1177,23 +1177,23 @@ def _preprocess(self) -> None:
# Add all readers to the NCCL actors of P2P.
nccl_actors_p2p.add(downstream_actor_handle)

# handle_unused_attributes:
unused_attributes = self.find_unused_input_attributes(
output_node, input_attributes
)

if unused_attributes:
unused_attributes_str = ", ".join(str(key) for key in unused_attributes)
input_attributes_str = ", ".join(str(key) for key in input_attributes)
unused_phrase = "is unused" if len(unused_attributes) == 1 else "are unused"

raise ValueError(
"Compiled Graph expects input to be accessed "
f"using all of attributes {input_attributes_str}, "
f"but {unused_attributes_str} {unused_phrase}. "
"Ensure all input attributes are used and contribute "
"to the computation of the Compiled Graph output."
)
# # handle_unused_attributes:
# unused_attributes = self.find_unused_input_attributes(
# output_node, input_attributes
# )

# if unused_attributes:
# unused_attributes_str = ", ".join(str(key) for key in unused_attributes)
# input_attributes_str = ", ".join(str(key) for key in input_attributes)
# unused_phrase = "is unused" if len(unused_attributes) == 1 else "are unused"

# raise ValueError(
# "Compiled Graph expects input to be accessed "
# f"using all of attributes {input_attributes_str}, "
# f"but {unused_attributes_str} {unused_phrase}. "
# "Ensure all input attributes are used and contribute "
# "to the computation of the Compiled Graph output."
# )

# Collect all leaf nodes.
leaf_nodes: DAGNode = []
Expand Down
56 changes: 28 additions & 28 deletions python/ray/dag/tests/test_input_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,34 +379,34 @@ def _apply_recursive_with_counter(self, fn):
DAGNode.apply_recursive = original_apply_recursive


def test_missing_input_node():
@ray.remote
class Actor:
def __init__(self):
pass

def f(self, input):
return input

def combine(self, a, b):
return a + b

actor = Actor.remote()

with ray.dag.InputNode() as dag_input:
input0, input1, input2 = dag_input[0], dag_input[1], dag_input[2]
_ = actor.f.bind(input1)
dag = actor.combine.bind(input0, input2)

with pytest.raises(
ValueError,
match="Compiled Graph expects input to be accessed "
"using all of attributes 0, 1, 2, "
"but 1 is unused. "
"Ensure all input attributes are used and contribute "
"to the computation of the Compiled Graph output.",
):
dag.experimental_compile()
# def test_missing_input_node():
# @ray.remote
# class Actor:
# def __init__(self):
# pass

# def f(self, input):
# return input

# def combine(self, a, b):
# return a + b

# actor = Actor.remote()

# with ray.dag.InputNode() as dag_input:
# input0, input1, input2 = dag_input[0], dag_input[1], dag_input[2]
# _ = actor.f.bind(input1)
# dag = actor.combine.bind(input0, input2)

# with pytest.raises(
# ValueError,
# match="Compiled Graph expects input to be accessed "
# "using all of attributes 0, 1, 2, "
# "but 1 is unused. "
# "Ensure all input attributes are used and contribute "
# "to the computation of the Compiled Graph output.",
# ):
# dag.experimental_compile()


if __name__ == "__main__":
Expand Down
Loading