[fx] support meta tracing for aten level computation graphs like functorch. #1536

Merged (5 commits) on Sep 5, 2022

@super-dainiu (Contributor) commented on Sep 2, 2022

What's new?

In this PR, I moved _meta_registration.py, introduced in the previous PRs #1515 and #1530, to the root directory of Colossal-AI so that the meta implementations of aten ops are all registered in advance.
With meta tracing, we can easily obtain the aten-level computation graph of a model, much as functorch does.
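
As background, here is a minimal sketch of what running on the meta device means. It uses only stock PyTorch, not the code in this PR, and the layer sizes (9216 in, 4096 out) are an assumption picked to match the first fully connected layer visible in the table further down. Tensors on the meta device carry shape and dtype but no data, so a forward call only propagates metadata:

```python
import torch
import torch.nn as nn

# Parameters created on the "meta" device have shape and dtype but no storage.
# The sizes are illustrative (assumed to mirror AlexNet's first Linear layer).
fc = nn.Linear(9216, 4096, device="meta")

# Same batch size as the tracing example; no memory is allocated for the data.
x = torch.empty(1000, 9216, device="meta")

# The call dispatches to the meta kernel, which only infers the output metadata.
y = fc(x)

print(y.shape)    # torch.Size([1000, 4096])
print(y.is_meta)  # True
```

This only works for ops that have a meta kernel registered, which is presumably why the registrations need to be in place before tracing. The PR's own entry point, meta_trace, is used as follows.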

>>> from colossalai.fx import meta_trace
>>> import torch
>>> import torchvision.models as tm
>>> model = tm.alexnet()
>>> graph = meta_trace(model, torch.rand(1000, 3, 224, 224))

You will get the complete graph, covering both the forward and the backward pass:

opcode         name                                        target                                         args                                                                                                                                               kwargs
-------------  ------------------------------------------  ---------------------------------------------  -------------------------------------------------------------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------------------------------------------------------
placeholder    input_1                                     placeholder                                    (,)                                                                                                                                                {}
placeholder    weight                                      placeholder                                    (,)                                                                                                                                                {}
placeholder    weight_1                                    placeholder                                    (,)                                                                                                                                                {}
call_function  convolution_default                         aten.convolution.default                       (input_1, weight, weight_1, [4, 4], [2, 2], [1, 1], False, [0, 0], 1)                                                                              {}
call_function  relu__default                               aten.relu_.default                             (convolution_default,)                                                                                                                             {}
call_function  detach_default                              aten.detach.default                            (convolution_default,)                                                                                                                             {}
call_function  max_pool2d_with_indices_default             aten.max_pool2d_with_indices.default           (convolution_default, [3, 3], [2, 2])                                                                                                              {}
placeholder    weight_2                                    placeholder                                    (,)                                                                                                                                                {}
placeholder    weight_3                                    placeholder                                    (,)                                                                                                                                                {}
call_function  convolution_default_1                       aten.convolution.default                       (max_pool2d_with_indices_default, weight_2, weight_3, [1, 1], [2, 2], [1, 1], False, [0, 0], 1)                                                    {}
call_function  relu__default_1                             aten.relu_.default                             (convolution_default_1,)                                                                                                                           {}
call_function  detach_default_1                            aten.detach.default                            (convolution_default_1,)                                                                                                                           {}
call_function  max_pool2d_with_indices_default_1           aten.max_pool2d_with_indices.default           (convolution_default_1, [3, 3], [2, 2])                                                                                                            {}
placeholder    weight_4                                    placeholder                                    (,)                                                                                                                                                {}
placeholder    weight_5                                    placeholder                                    (,)                                                                                                                                                {}
call_function  convolution_default_2                       aten.convolution.default                       (max_pool2d_with_indices_default_1, weight_4, weight_5, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)                                                  {}
call_function  relu__default_2                             aten.relu_.default                             (convolution_default_2,)                                                                                                                           {}
call_function  detach_default_2                            aten.detach.default                            (convolution_default_2,)                                                                                                                           {}
placeholder    weight_6                                    placeholder                                    (,)                                                                                                                                                {}
placeholder    weight_7                                    placeholder                                    (,)                                                                                                                                                {}
call_function  convolution_default_3                       aten.convolution.default                       (convolution_default_2, weight_6, weight_7, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)                                                              {}
call_function  relu__default_3                             aten.relu_.default                             (convolution_default_3,)                                                                                                                           {}
call_function  detach_default_3                            aten.detach.default                            (convolution_default_3,)                                                                                                                           {}
placeholder    weight_8                                    placeholder                                    (,)                                                                                                                                                {}
placeholder    weight_9                                    placeholder                                    (,)                                                                                                                                                {}
call_function  convolution_default_4                       aten.convolution.default                       (convolution_default_3, weight_8, weight_9, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)                                                              {}
call_function  relu__default_4                             aten.relu_.default                             (convolution_default_4,)                                                                                                                           {}
call_function  detach_default_4                            aten.detach.default                            (convolution_default_4,)                                                                                                                           {}
call_function  max_pool2d_with_indices_default_2           aten.max_pool2d_with_indices.default           (convolution_default_4, [3, 3], [2, 2])                                                                                                            {}
call_function  _adaptive_avg_pool2d_default                aten._adaptive_avg_pool2d.default              (max_pool2d_with_indices_default_2, [6, 6])                                                                                                        {}
call_function  _reshape_alias_default                      aten._reshape_alias.default                    (_adaptive_avg_pool2d_default, [1000, 9216], [9216, 1])                                                                                            {}
call_function  empty_like_default                          aten.empty_like.default                        (_reshape_alias_default,)                                                                                                                          {'memory_format': torch.contiguous_format}
call_function  bernoulli__float                            aten.bernoulli_.float                          (empty_like_default,)                                                                                                                              {}
call_function  div__scalar                                 aten.div_.Scalar                               (empty_like_default, 0.5)                                                                                                                          {}
call_function  mul_tensor                                  aten.mul.Tensor                                (_reshape_alias_default, empty_like_default)                                                                                                       {}
placeholder    weight_10                                   placeholder                                    (,)                                                                                                                                                {}
placeholder    weight_11                                   placeholder                                    (,)                                                                                                                                                {}
call_function  addmm_default                               aten.addmm.default                             (weight_10, mul_tensor, weight_11)                                                                                                                 {}
call_function  relu__default_5                             aten.relu_.default                             (addmm_default,)                                                                                                                                   {}
call_function  detach_default_5                            aten.detach.default                            (addmm_default,)                                                                                                                                   {}
call_function  empty_like_default_1                        aten.empty_like.default                        (addmm_default,)                                                                                                                                   {'memory_format': torch.contiguous_format}
call_function  bernoulli__float_1                          aten.bernoulli_.float                          (empty_like_default_1,)                                                                                                                            {}
call_function  div__scalar_1                               aten.div_.Scalar                               (empty_like_default_1, 0.5)                                                                                                                        {}
call_function  mul_tensor_1                                aten.mul.Tensor                                (addmm_default, empty_like_default_1)                                                                                                              {}
placeholder    weight_12                                   placeholder                                    (,)                                                                                                                                                {}
placeholder    weight_13                                   placeholder                                    (,)                                                                                                                                                {}
call_function  addmm_default_1                             aten.addmm.default                             (weight_12, mul_tensor_1, weight_13)                                                                                                               {}
call_function  relu__default_6                             aten.relu_.default                             (addmm_default_1,)                                                                                                                                 {}
call_function  detach_default_6                            aten.detach.default                            (addmm_default_1,)                                                                                                                                 {}
placeholder    weight_14                                   placeholder                                    (,)                                                                                                                                                {}
placeholder    weight_15                                   placeholder                                    (,)                                                                                                                                                {}
call_function  addmm_default_2                             aten.addmm.default                             (weight_14, addmm_default_1, weight_15)                                                                                                            {}
call_function  sum_default                                 aten.sum.default                               (addmm_default_2,)                                                                                                                                 {}
call_function  ones_like_default                           aten.ones_like.default                         (sum_default,)                                                                                                                                     {'dtype': torch.float32, 'layout': torch.strided, 'device': device(type='cpu'), 'pin_memory': False, 'memory_format': torch.preserve_format}
call_function  expand_default                              aten.expand.default                            (ones_like_default, [1000, 1000])                                                                                                                  {}
placeholder    weight_16                                   placeholder                                    (,)                                                                                                                                                {}
call_function  mm_default                                  aten.mm.default                                (expand_default, weight_16)                                                                                                                        {}
call_function  t_default                                   aten.t.default                                 (expand_default,)                                                                                                                                  {}
call_function  mm_default_1                                aten.mm.default                                (t_default, addmm_default_1)                                                                                                                       {}
call_function  t_default_1                                 aten.t.default                                 (mm_default_1,)                                                                                                                                    {}
call_function  sum_dim_int_list                            aten.sum.dim_IntList                           (expand_default, [0], True)                                                                                                                        {}
call_function  view_default                                aten.view.default                              (sum_dim_int_list, [1000])                                                                                                                         {}
call_function  detach_default_7                            aten.detach.default                            (view_default,)                                                                                                                                    {}
call_function  detach_default_8                            aten.detach.default                            (detach_default_7,)                                                                                                                                {}
call_function  t_default_2                                 aten.t.default                                 (t_default_1,)                                                                                                                                     {}
call_function  detach_default_9                            aten.detach.default                            (t_default_2,)                                                                                                                                     {}
call_function  detach_default_10                           aten.detach.default                            (detach_default_9,)                                                                                                                                {}
call_function  detach_default_11                           aten.detach.default                            (detach_default_6,)                                                                                                                                {}
call_function  threshold_backward_default                  aten.threshold_backward.default                (mm_default, detach_default_11, 0)                                                                                                                 {}
placeholder    weight_17                                   placeholder                                    (,)                                                                                                                                                {}
call_function  mm_default_2                                aten.mm.default                                (threshold_backward_default, weight_17)                                                                                                            {}
call_function  t_default_3                                 aten.t.default                                 (threshold_backward_default,)                                                                                                                      {}
call_function  mm_default_3                                aten.mm.default                                (t_default_3, mul_tensor_1)                                                                                                                        {}
call_function  t_default_4                                 aten.t.default                                 (mm_default_3,)                                                                                                                                    {}
call_function  sum_dim_int_list_1                          aten.sum.dim_IntList                           (threshold_backward_default, [0], True)                                                                                                            {}
call_function  view_default_1                              aten.view.default                              (sum_dim_int_list_1, [4096])                                                                                                                       {}
call_function  detach_default_12                           aten.detach.default                            (view_default_1,)                                                                                                                                  {}
call_function  detach_default_13                           aten.detach.default                            (detach_default_12,)                                                                                                                               {}
call_function  t_default_5                                 aten.t.default                                 (t_default_4,)                                                                                                                                     {}
call_function  detach_default_14                           aten.detach.default                            (t_default_5,)                                                                                                                                     {}
call_function  detach_default_15                           aten.detach.default                            (detach_default_14,)                                                                                                                               {}
call_function  mul_tensor_2                                aten.mul.Tensor                                (mm_default_2, empty_like_default_1)                                                                                                               {}
call_function  detach_default_16                           aten.detach.default                            (detach_default_5,)                                                                                                                                {}
call_function  threshold_backward_default_1                aten.threshold_backward.default                (mul_tensor_2, detach_default_16, 0)                                                                                                               {}
placeholder    weight_18                                   placeholder                                    (,)                                                                                                                                                {}
call_function  mm_default_4                                aten.mm.default                                (threshold_backward_default_1, weight_18)                                                                                                          {}
call_function  t_default_6                                 aten.t.default                                 (threshold_backward_default_1,)                                                                                                                    {}
call_function  mm_default_5                                aten.mm.default                                (t_default_6, mul_tensor)                                                                                                                          {}
call_function  t_default_7                                 aten.t.default                                 (mm_default_5,)                                                                                                                                    {}
call_function  sum_dim_int_list_2                          aten.sum.dim_IntList                           (threshold_backward_default_1, [0], True)                                                                                                          {}
call_function  view_default_2                              aten.view.default                              (sum_dim_int_list_2, [4096])                                                                                                                       {}
call_function  detach_default_17                           aten.detach.default                            (view_default_2,)                                                                                                                                  {}
call_function  detach_default_18                           aten.detach.default                            (detach_default_17,)                                                                                                                               {}
call_function  t_default_8                                 aten.t.default                                 (t_default_7,)                                                                                                                                     {}
call_function  detach_default_19                           aten.detach.default                            (t_default_8,)                                                                                                                                     {}
call_function  detach_default_20                           aten.detach.default                            (detach_default_19,)                                                                                                                               {}
call_function  mul_tensor_3                                aten.mul.Tensor                                (mm_default_4, empty_like_default)                                                                                                                 {}
call_function  _reshape_alias_default_1                    aten._reshape_alias.default                    (mul_tensor_3, [1000, 256, 6, 6], [9216, 36, 6, 1])                                                                                                {}
call_function  _adaptive_avg_pool2d_backward_default       aten._adaptive_avg_pool2d_backward.default     (_reshape_alias_default_1, max_pool2d_with_indices_default_2)                                                                                      {}
call_function  max_pool2d_with_indices_backward_default    aten.max_pool2d_with_indices_backward.default  (_adaptive_avg_pool2d_backward_default, convolution_default_4, [3, 3], [2, 2], [0, 0], [1, 1], False, max_pool2d_with_indices_default_2)           {}
call_function  detach_default_21                           aten.detach.default                            (detach_default_4,)                                                                                                                                {}
call_function  threshold_backward_default_2                aten.threshold_backward.default                (max_pool2d_with_indices_backward_default, detach_default_21, 0)                                                                                   {}
placeholder    weight_19                                   placeholder                                    (,)                                                                                                                                                {}
call_function  convolution_backward_default                aten.convolution_backward.default              (threshold_backward_default_2, convolution_default_3, weight_19, [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True])              {}
call_function  detach_default_22                           aten.detach.default                            (convolution_backward_default,)                                                                                                                    {}
call_function  detach_default_23                           aten.detach.default                            (detach_default_22,)                                                                                                                               {}
call_function  detach_default_24                           aten.detach.default                            (convolution_backward_default,)                                                                                                                    {}
call_function  detach_default_25                           aten.detach.default                            (detach_default_24,)                                                                                                                               {}
call_function  detach_default_26                           aten.detach.default                            (detach_default_3,)                                                                                                                                {}
call_function  threshold_backward_default_3                aten.threshold_backward.default                (convolution_backward_default, detach_default_26, 0)                                                                                               {}
placeholder    weight_20                                   placeholder                                    (,)                                                                                                                                                {}
call_function  convolution_backward_default_1              aten.convolution_backward.default              (threshold_backward_default_3, convolution_default_2, weight_20, [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True])              {}
call_function  detach_default_27                           aten.detach.default                            (convolution_backward_default_1,)                                                                                                                  {}
call_function  detach_default_28                           aten.detach.default                            (detach_default_27,)                                                                                                                               {}
call_function  detach_default_29                           aten.detach.default                            (convolution_backward_default_1,)                                                                                                                  {}
call_function  detach_default_30                           aten.detach.default                            (detach_default_29,)                                                                                                                               {}
call_function  detach_default_31                           aten.detach.default                            (detach_default_2,)                                                                                                                                {}
call_function  threshold_backward_default_4                aten.threshold_backward.default                (convolution_backward_default_1, detach_default_31, 0)                                                                                             {}
placeholder    weight_21                                   placeholder                                    (,)                                                                                                                                                {}
call_function  convolution_backward_default_2              aten.convolution_backward.default              (threshold_backward_default_4, max_pool2d_with_indices_default_1, weight_21, [384], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True])  {}
call_function  detach_default_32                           aten.detach.default                            (convolution_backward_default_2,)                                                                                                                  {}
call_function  detach_default_33                           aten.detach.default                            (detach_default_32,)                                                                                                                               {}
call_function  detach_default_34                           aten.detach.default                            (convolution_backward_default_2,)                                                                                                                  {}
call_function  detach_default_35                           aten.detach.default                            (detach_default_34,)                                                                                                                               {}
call_function  max_pool2d_with_indices_backward_default_1  aten.max_pool2d_with_indices_backward.default  (convolution_backward_default_2, convolution_default_1, [3, 3], [2, 2], [0, 0], [1, 1], False, max_pool2d_with_indices_default_1)                  {}
call_function  detach_default_36                           aten.detach.default                            (detach_default_1,)                                                                                                                                {}
call_function  threshold_backward_default_5                aten.threshold_backward.default                (max_pool2d_with_indices_backward_default_1, detach_default_36, 0)                                                                                 {}
placeholder    weight_22                                   placeholder                                    (,)                                                                                                                                                {}
call_function  convolution_backward_default_3              aten.convolution_backward.default              (threshold_backward_default_5, max_pool2d_with_indices_default, weight_22, [192], [1, 1], [2, 2], [1, 1], False, [0, 0], 1, [True, True, True])    {}
call_function  detach_default_37                           aten.detach.default                            (convolution_backward_default_3,)                                                                                                                  {}
call_function  detach_default_38                           aten.detach.default                            (detach_default_37,)                                                                                                                               {}
call_function  detach_default_39                           aten.detach.default                            (convolution_backward_default_3,)                                                                                                                  {}
call_function  detach_default_40                           aten.detach.default                            (detach_default_39,)                                                                                                                               {}
call_function  max_pool2d_with_indices_backward_default_2  aten.max_pool2d_with_indices_backward.default  (convolution_backward_default_3, convolution_default, [3, 3], [2, 2], [0, 0], [1, 1], False, max_pool2d_with_indices_default)                      {}
call_function  detach_default_41                           aten.detach.default                            (detach_default,)                                                                                                                                  {}
call_function  threshold_backward_default_6                aten.threshold_backward.default                (max_pool2d_with_indices_backward_default_2, detach_default_41, 0)                                                                                 {}
placeholder    weight_23                                   placeholder                                    (,)                                                                                                                                                {}
call_function  convolution_backward_default_4              aten.convolution_backward.default              (threshold_backward_default_6, input_1, weight_23, [64], [4, 4], [2, 2], [1, 1], False, [0, 0], 1, [False, True, True])                            {}
call_function  detach_default_42                           aten.detach.default                            (convolution_backward_default_4,)                                                                                                                  {}
call_function  detach_default_43                           aten.detach.default                            (detach_default_42,)                                                                                                                               {}
call_function  detach_default_44                           aten.detach.default                            (convolution_backward_default_4,)                                                                                                                  {}
call_function  detach_default_45                           aten.detach.default                            (detach_default_44,)                                                                                                                               {}
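
A listing this long is easier to check when summarized. The sketch below is not part of this PR or of `colossalai.fx`. It is a plain-Python helper that tallies the `target` column of the `call_function` rows in a tabulated dump like the one above. It assumes the columns are separated by two or more spaces, as in that listing. The name `count_targets` and the shortened sample rows are hypothetical.

```python
import re
from collections import Counter


def count_targets(table_text):
    """Tally the `target` column of call_function rows in a tabulated graph dump."""
    counts = Counter()
    for line in table_text.splitlines():
        # Assumption: columns are separated by runs of two or more spaces,
        # while items inside the args column use single spaces.
        cols = re.split(r"\s{2,}", line.strip())
        if len(cols) >= 3 and cols[0] == "call_function":
            counts[cols[2]] += 1
    return counts


# Rows shortened from the listing above (args abbreviated for illustration).
sample = """\
placeholder    weight_23                       placeholder                        (,)                                  {}
call_function  convolution_backward_default_4  aten.convolution_backward.default  (threshold_backward_default_6, input_1, weight_23)  {}
call_function  detach_default_42               aten.detach.default                (convolution_backward_default_4,)    {}
call_function  detach_default_43               aten.detach.default                (detach_default_42,)                 {}
"""

counts = count_targets(sample)
for target, n in counts.most_common():
    print(f"{n:3d}  {target}")
```

Run over the full listing, a tally like this shows at a glance how many `aten.detach.default` nodes accompany each `aten.convolution_backward.default` call.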

@YuliangLiu0306 YuliangLiu0306 merged commit 7012960 into hpcaitech:main Sep 5, 2022
@super-dainiu super-dainiu deleted the feature/meta_trace branch September 6, 2022 09:52