Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add unused node check for compile_dag_node #49382

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
fix style from suggestions
Signed-off-by: Lydia <wenqix3@uw.edu>
  • Loading branch information
LydiaXwQ committed Dec 20, 2024
commit 39da7480e61ff2c1c9345f8d14bebcd2d073e232
18 changes: 10 additions & 8 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,9 +929,9 @@ def _add_node(self, node: "ray.dag.DAGNode") -> None:
self.dag_node_to_idx[node] = idx
self.counter += 1

# Helper function to handle_unused_attributes.
def find_unused_input_attributes(self, output_node, input_attributes) -> Set[str]:
def find_unused_input_attributes(self, output_node: "ray.dag.MultiOutputNode", input_attributes: Set[str]) -> Set[str]:
"""
This is the helper function to handle_unused_attributes.
Traverse the DAG backwards from the output node to find unused attributes.
Args:
output_node: The starting node for the traversal.
Expand All @@ -942,12 +942,12 @@ def find_unused_input_attributes(self, output_node, input_attributes) -> Set[str
"""
from ray.dag import InputAttributeNode

used_attributes = set()
used_attributes: Set[str] = set()
# Keep track of visited nodes during backtracking.
visited_nodes = set()
visited_nodes: Set["ray.dag.MultiOutputNode"] = set()

# Traverse backwards from the output node to find all used inputs.
def traverse(node):
def traverse(node: "ray.dag.MultiOutputNode"):
if node in visited_nodes:
return
visited_nodes.add(node)
Expand Down Expand Up @@ -986,7 +986,7 @@ def _preprocess(self) -> None:
nccl_actors_p2p: Set["ray.actor.ActorHandle"] = set()
collective_ops: Set[_CollectiveOperation] = set()

input_attributes = set()
input_attributes: Set[str] = set()
# Find the input node and input attribute nodes in the DAG.
for idx, task in self.idx_to_task.items():
if isinstance(task.dag_node, InputNode):
Expand Down Expand Up @@ -1187,9 +1187,11 @@ def _preprocess(self) -> None:
unused_phrase = "is unused" if len(unused_attributes) == 1 else "are unused"

raise ValueError(
f"DAG expects input: {input_attributes_str}, "
"Compiled Graph expects input to be accessed "
f"using all of attributes {input_attributes_str}, "
f"but {unused_attributes_str} {unused_phrase}. "
"Ensure all accessed inputs from InputNode are connected to the output."
"Ensure all input attributes are used and contribute "
"to the computation of the Compiled Graph output."
)

# Collect all leaf nodes.
Expand Down
15 changes: 9 additions & 6 deletions python/ray/dag/tests/test_input_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,16 +394,19 @@ def combine(self, a, b):
actor = Actor.remote()

with ray.dag.InputNode() as dag_input:
input1, input2, input3 = dag_input[0], dag_input[1], dag_input[2]
_ = actor.f.bind(input2)
dag = actor.combine.bind(input1, input3)
input0, input1, input2 = dag_input[0], dag_input[1], dag_input[2]
_ = actor.f.bind(input1)
dag = actor.combine.bind(input0, input2)

with pytest.raises(
ValueError,
match="DAG expects input: 0, 1, 2, but 1 is unused. "
"Ensure all accessed inputs from InputNode are connected to the output.",
match="Compiled Graph expects input to be accessed "
"using all of attributes 0, 1, 2, "
"but 1 is unused. "
"Ensure all input attributes are used and contribute "
"to the computation of the Compiled Graph output."
):
_ = dag.experimental_compile()
dag.experimental_compile()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move this test to test_accelerated_dag.py? test_input_node.py is for general DAG, not necessarily compiled graph



if __name__ == "__main__":
Expand Down