@@ -432,6 +432,38 @@ def _assign_ones_like_indice(self, node: Node, node_idx: int):
432
432
"""
433
433
self ._assign_all_indice (node , node_idx )
434
434
435
+ def _assign_cat_indice (self , node : Node , node_idx : int ):
436
+ """
437
+ Assign indice for cat op.
438
+
439
+ Args:
440
+ node (node)
441
+ node_idx (int)
442
+ """
443
+ nodes_in = flat_list (node .args [0 ])
444
+ self ._assign_indice_as_input (node , node_idx , input_node = nodes_in [0 ])
445
+ for n in nodes_in [1 :]:
446
+ self ._mark_computation_from_node (n , node )
447
+ cat_dim = node .kwargs ["dim" ]
448
+ self ._del_dim (node_idx , cat_dim )
449
+ self ._add_dim (node_idx , cat_dim )
450
+
451
+ def _assign_sum_indice (self , node : Node , node_idx : int ):
452
+ """
453
+ Assign indice for sum op.
454
+
455
+ Args:
456
+ node (node)
457
+ node_idx (int)
458
+ """
459
+ nodes_in = flat_list (node .args [0 ])
460
+ self ._add_dim (node_idx , 0 )
461
+ self ._assign_indice_as_input (node , node_idx , input_node = nodes_in [0 ])
462
+ for n in nodes_in [1 :]:
463
+ self ._mark_computation_from_node (n , node )
464
+ cat_dim = node .kwargs ["dim" ]
465
+ self ._del_dim (node_idx , cat_dim )
466
+
435
467
def _assign_getitem_indice (self , node : Node , node_idx : int ):
436
468
"""
437
469
Assign indice for getitem.
@@ -442,7 +474,16 @@ def _assign_getitem_indice(self, node: Node, node_idx: int):
442
474
node_idx (int)
443
475
"""
444
476
node_args = flat_list (node .args [1 :])
445
- if not any (i == str (node_arg ) for i in ["None" , "Ellipsis" ] for node_arg in node_args ):
477
+ flag = False
478
+ for node_arg in node_args :
479
+ node_arg_str = str (node_arg )
480
+ if any (i == node_arg_str for i in ["None" , "Ellipsis" ]):
481
+ flag = True
482
+ break
483
+ if "slice" in node_arg_str :
484
+ flag = True
485
+ break
486
+ if flag == False :
446
487
return
447
488
448
489
# node args should be like [Ellipsis, slice(start, step, end), None]
@@ -461,8 +502,11 @@ def _assign_getitem_indice(self, node: Node, node_idx: int):
461
502
shape_gap = len (node_shape ) - len (node_args ) + 1
462
503
origin_idx_count += shape_gap
463
504
new_idx_count += shape_gap
464
- # slice(None, None, None) means all indexes, doesn't support other slice
465
- elif "slice(None, None, None)" == node_arg_str :
505
+ # slice(None, None, None) means all indexes
506
+ elif "slice" in node_arg_str :
507
+ if "slice(None, None, None)" != node_arg_str :
508
+ self ._del_dim (node_idx , new_idx_count )
509
+ self ._add_dim (node_idx , new_idx_count )
466
510
origin_idx_count += 1
467
511
new_idx_count += 1
468
512
# None means a new dim
@@ -565,7 +609,7 @@ def trace_indice(self):
565
609
self ._assign_view_reshape_indice (node , idx )
566
610
elif "unsqueeze" in node .name :
567
611
self ._assign_unsqueeze_indice (node , idx )
568
- elif any (i in node .name for i in ["to" , "contiguous" ]):
612
+ elif any (i in node .name for i in ["to" , "contiguous" , "clone" ]):
569
613
self ._assgin_no_change_indice (node , idx )
570
614
elif "new_ones" in node .name :
571
615
self ._assign_ones_like_indice (node , idx )
@@ -574,6 +618,8 @@ def trace_indice(self):
574
618
elif node .op == "call_function" :
575
619
if "linear" in node .name :
576
620
self ._assign_linear_indice (node , idx )
621
+ elif "cat" in node .name :
622
+ self ._assign_cat_indice (node , idx )
577
623
elif "matmul" in node .name :
578
624
self ._assign_matmul_indice (node , idx )
579
625
elif "softmax" in node .name :
@@ -586,6 +632,8 @@ def trace_indice(self):
586
632
self ._assign_dropout_indice (node , idx )
587
633
elif "einsum" in node .name :
588
634
self ._assign_einsum_indice (node , idx )
635
+ elif "sum" in node .name :
636
+ self ._assign_sum_indice (node , idx )
589
637
elif "layer_norm" in node .name :
590
638
self ._assign_layernorm_indice (node , idx )
591
639
elif "getitem" in node .name :
0 commit comments