Replacing unary ops with LookUpTable and Take op to improve performance #17214
@jverma-quic and @quic-sanirudh can you please review the PR?
@tvm-bot rerun
33a86e1 to 0e42af8
if call_node.op != call_tir_op:
    return False
global_var = call_node.args[0]
return op_name in global_var.name_hint
Should we use a better solution than looking at the global_var name to determine the type of op? Names might not be the most reliable way. Perhaps look for an attribute like operator_name, similar to what is used by the AlterOpImpl pass?
I agree with you that relying on the global_var is not the best way to identify the operators for this transformation. However, I don't really think that operator_name will be much better. The problem here is that we lower the graph to Relay first and then during translation to Relax, the operator knowledge is lost. @Lunderberg's suggestion would have worked very well if we could have imported the graph directly to Relax and then before legalizing it, we could have replaced R.tanh with R.take(..).
@jverma-quic I agree that operator_name is not the best solution either, and yes, @Lunderberg's solution is certainly the best, but since we translate a Relay graph to Relax, we don't have that option. My only concern was that, at least with an attribute like operator_name, the function has to be deliberately annotated for the pass to work. If we use the function name instead, someone could apply the pass to a fused op with a name like add_tanh, which might then get replaced and result in correctness errors.
That being said, since this pass isn't added to any pipeline, and has to be consciously applied by someone, I think this might also be okay.
As for @Lunderberg's suggestion, we could perhaps support both types of ops (legalized and direct relax ops).
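The add_tanh concern can be illustrated with plain strings. This is only a sketch: `matches_exact_op` and its suffix rule are hypothetical and not part of the PR, and the example names are made up for illustration.

```python
import re


def matches_by_substring(op_name: str, name_hint: str) -> bool:
    # Mirrors the check under review: any name containing op_name matches.
    return op_name in name_hint


def matches_exact_op(op_name: str, name_hint: str) -> bool:
    # Hypothetical stricter rule: the name must be op_name itself, optionally
    # followed by a numeric de-duplication suffix such as "tanh1" or "tanh_2".
    return re.fullmatch(rf"{re.escape(op_name)}_?\d*", name_hint) is not None


# Made-up function names; "add_tanh" stands for a fused op.
for name in ["tanh", "tanh1", "add_tanh"]:
    print(name, matches_by_substring("tanh", name), matches_exact_op("tanh", name))
```

The substring check accepts the fused name, while the stricter rule rejects it. Either way the decision still rests on naming, which is why an explicit attribute was raised as an alternative.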
Thank you for clarifying the use case. That makes sense of why the pass would only have access to the legalized operators.
Thank you for the suggestions! @Lunderberg, @jverma-quic, @quic-sanirudh
I have a question: this pass works on quantized data, so does Relax support quantized unary ops? I couldn't find quantized direct Relax ops.
In that case we may not be able to do the optimization for direct relax ops yet.
Please let me know your suggestions on the same.
    return self.builder_.get()

def visit_call_(self, call_node: relax.Call) -> relax.Call:
    if call_node.args[1][0].struct_info.dtype == "uint8":
Should we verify whether the call_node is a relax.call_tir op before accessing the args?
> Should we verify whether the call_node is a relax.call_tir op before accessing the args?
@quic-sanirudh: wouldn't it be guaranteed since we're only visiting the call nodes?
Call nodes can be any relax call I think, and call_tir is just one type of call node right? What if there's a builtin call or some direct relax op call, etc.
To be safe, I will add a check for call_tir before invoking the pass. Thank you.
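The guard discussed above can be sketched without TVM. The `Op` and `Call` classes below are simplified stand-ins, not the real Relax classes, and the dict with a "dtype" key is a placeholder for the struct info. The point is only the ordering: check the operator first, then index into the arguments.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass(frozen=True)
class Op:
    # Stand-in for a Relax operator handle (hypothetical, not the TVM class).
    name: str


@dataclass
class Call:
    # Stand-in for relax.Call: an operator plus positional arguments.
    op: Op
    args: Tuple[Any, ...]


CALL_TIR = Op("relax.call_tir")


def first_input_dtype(call: Call) -> Optional[str]:
    """Return the dtype of the first call_tir input, or None for any other call."""
    if call.op != CALL_TIR:
        # Builtin calls and direct Relax ops lay out their args differently,
        # so indexing args[1][0] on them would be wrong or would raise.
        return None
    if len(call.args) < 2 or not call.args[1]:
        return None
    return call.args[1][0]["dtype"]


lowered = Call(CALL_TIR, ("tanh", [{"dtype": "uint8"}]))
direct = Call(Op("relax.add"), ({"dtype": "float32"}, {"dtype": "float32"}))
print(first_input_dtype(lowered), first_input_dtype(direct))
```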
if op_replace(call_node):
    inp, inp_scale, inp_zp, out_scale, out_zp = [x for x in call_node.args[1]]
    # LUT node creation
    LUT = hexagon_unary_ops.LUT_generation(
When is this pass intended to be applied? If it can be moved to before LegalizeOps, then that would make it easier to define the lookup table as a Relax expression (using R.arange(0,256,'uint8') as all possible quantized values, passing it through the relax operations, then finishing with R.take(computed_table, inp)). This would be simplified by the FoldConstantPass to the same R.take(R.const(...), inp) which is generated here, but wouldn't require explicit handling of each unary operation.
That would also allow the pattern-matching to be done based on the Relax operations themselves, rather than their lowered names.
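The idea of pushing every possible quantized value through the operation and then indexing the result can be sketched in NumPy rather than Relax. This is an illustration under stated assumptions, not the code in this PR: `build_lut` is a hypothetical helper, the scale and zero-point values are made up, and the affine scheme real = scale * (q - zero_point) is assumed.

```python
import numpy as np


def build_lut(fn, in_scale, in_zp, out_scale, out_zp):
    """Evaluate fn once for each of the 256 uint8 codes (hypothetical helper)."""
    codes = np.arange(256, dtype=np.int32)      # every possible quantized input
    real = (codes - in_zp) * in_scale           # dequantize (assumed affine scheme)
    out = fn(real)                              # the floating-point unary op
    q = np.round(out / out_scale) + out_zp      # requantize
    return np.clip(q, 0, 255).astype(np.uint8)  # saturate to the uint8 range


# Made-up quantization parameters for illustration only.
table = build_lut(np.tanh, in_scale=1 / 32, in_zp=128, out_scale=1 / 128, out_zp=128)

inp = np.array([[0, 128, 255]], dtype=np.uint8)
out = np.take(table, inp)  # the runtime work is a single indexed load per element
```

Because the builder takes the op as a callable, no per-op handling is needed, which mirrors the benefit described in the comment above.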
… during compile time and take op is used to access the values at runtime
…finition. Added check if the call node is of call tir type
0e42af8 to 71fd080
LGTM, Thanks @sdalvi-quic
@Lunderberg As @sdalvi-quic has mentioned above in her comment, since this pass is to replace quantized (qnn) ops with a lookup table, and since relax does not have quantized ops yet (AFAIK, pls correct me if I'm wrong about that), I think this version is probably the best we'll be able to do for now.
LGTM
Agreed, since the quantized ops have already been converted to Relax by the point this pass is applied, the high-level Relax information would no longer be present.
Thanks @sdalvi-quic @Lunderberg @jverma-quic.
Motivation: Unary ops compute their values at runtime, which is costly. Because the data is quantized, the result can instead be precomputed at compile time for every possible input value (0 to 255 for uint8) and stored in a lookup table that is accessed at runtime.
Implementation: This pass generates a lookup table for each unary op at compile time. It iterates over the original graph and replaces each unary op call node (for example tanh or sqrt) with a call node that reads the result from the lookup table corresponding to that op and returns it.
Testing: test_take checks both the value correctness and the structural correctness of the output generated by the pass.
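A value-correctness check of the kind described can be sketched in NumPy. This is not the PR's test_take: `quantized_sqrt_reference` and the quantization parameters are hypothetical, and the structural comparison of the transformed module is not shown.

```python
import numpy as np


def quantized_sqrt_reference(x, scale, zp):
    # Hypothetical reference: dequantize, apply sqrt, requantize with the same
    # scale and zero point (assumed affine scheme real = scale * (q - zp)).
    real = (x.astype(np.float32) - zp) * scale
    q = np.round(np.sqrt(real) / scale) + zp
    return np.clip(q, 0, 255).astype(np.uint8)


scale, zp = 0.05, 0  # made-up parameters; zp = 0 keeps sqrt inputs non-negative

# Compile-time side: one table entry per uint8 code.
table = quantized_sqrt_reference(np.arange(256, dtype=np.uint8), scale, zp)

# Runtime side: a random quantized tensor looked up through the table.
rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(2, 8, 8), dtype=np.uint8)
looked_up = np.take(table, x)

# The lookup must agree with evaluating the op directly on the tensor.
np.testing.assert_array_equal(looked_up, quantized_sqrt_reference(x, scale, zp))
```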
Reference: #15614