Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replacing unary ops with LookUpTable and Take op to improve performance #17214

Merged
merged 4 commits into from
Aug 9, 2024

Conversation

sdalvi-quic
Copy link
Contributor

@sdalvi-quic sdalvi-quic commented Jul 30, 2024

Motivation : Unary ops calculate the value during runtime which is costly. Since we are handling quantized data hence, it is better to pre compute the value during compile time (for 0 -256 value) and store in a Look Up Table which is accessed during the runtime.

Implementation : This pass generates the Look Up Table for each unary op during compile time. It iterates over the original graph and replaces the unary ops call node, (example tanh, sqrt) to a call node that accesses a Look Up Table corresponding to the unary op, and return the value.

Testing : The test_take checks for both value correctness and structural correctness of the output generated after the pass.

Reference: #15614

@sdalvi-quic
Copy link
Contributor Author

sdalvi-quic commented Jul 30, 2024

@jverma-quic and @quic-sanirudh can you please review the PR?
I implemented this sometime back but, could not merge it then, and the PR was closed.
Reference : #15614.

@sdalvi-quic sdalvi-quic changed the title Replacing unary ops with LookUpTable and Take op to improve performance. Replacing unary ops with LookUpTable and Take op to improve performance Jul 30, 2024
@sdalvi-quic
Copy link
Contributor Author

@tvm-bot rerun

if call_node.op != call_tir_op:
return False
global_var = call_node.args[0]
return op_name in global_var.name_hint
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use a better solution than looking at the global_var name to determine the type of op? Names might not be the most reliable way. Perhaps look for an attribute like operator_name similar to what is used by AlterOpImpl pass?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with you that relying on the global_var is not the best way to identify the operators for this transformation. However, I don't really think that operator_name will be much better. The problem here is that we lower the graph to Relay first and then during translation to Relax, the operator knowledge is lost. @Lunderberg's suggestion would have worked very well if we could have imported the graph directly to Relax and then before legalizing it, we could have replaced R.tanh with R.take(..).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jverma-quic I agree that operator_name is not the best solution either, and yeah @Lunderberg's solution is certainly the best, but yes, since we translate a relay graph to relax, we don't have that option. My only concern was that, atleast with an attribute like operator_name, it has to be deliberately annotated if the pass has to work. If we use the function name instead, there could potentially be a case where someone could apply the pass on some fused op that has say something like add_tanh and that might get replaced and result in correctness errors.

That being said, since this pass isn't added to any pipeline, and has to be consciously applied by someone, I think this might also be okay.

As for @Lunderberg's suggestion, we could perhaps support both types of ops (legalized and direct relax ops).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you on clarifying the use case, that that makes sense with why it would only have access to the legalized operators.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the suggestions! @Lunderberg, @jverma-quic, @quic-sanirudh

I had few questions, this pass works for quantized data, does relax support quantized unary ops? I couldn't find quantized direct relax ops.
In that case we may not be able to do the optimization for direct relax ops yet.

Please let me know your suggestions on the same.

return self.builder_.get()

def visit_call_(self, call_node: relax.Call) -> relax.Call:
if call_node.args[1][0].struct_info.dtype == "uint8":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we verify whether the call_node is a relax.call_tir op before accessing the args?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we verify whether the call_node is a relax.call_tir op before accessing the args?

@quic-sanirudh: wouldn't it be guaranteed since we're only visiting the call nodes?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Call nodes can be any relax call I think, and call_tir is just one type of call node right? What if there's a builtin call or some direct relax op call, etc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be safer I will add check for call_tir before invoking the pass. Thank you.

if op_replace(call_node):
inp, inp_scale, inp_zp, out_scale, out_zp = [x for x in call_node.args[1]]
# LUT node creation
LUT = hexagon_unary_ops.LUT_generation(
Copy link
Contributor

@Lunderberg Lunderberg Jul 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When is this pass intended to be applied? If it can be moved to before LegalizeOps, then that would make it easier to define the lookup table as a Relax expression (using R.arange(0,256,'uint8') as all possible quantized values, passing it through the relax operations, then finishing with R.take(computed_table, inp)). This would be simplified by the FoldConstantPass to the same R.take(R.const(...), inp) which is generated here, but wouldn't require explicit handling of each unary operation.

That would also allow the pattern-matching to be done based on the Relax operations themselves, rather than their lowered names.

Copy link
Contributor

@quic-sanirudh quic-sanirudh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, Thanks @sdalvi-quic

@quic-sanirudh
Copy link
Contributor

@Lunderberg As @sdalvi-quic has mentioned above in her comment, since this pass is to replace quantized (qnn) ops with a lookup table, and since relax does not have quantized ops yet (AFAIK, pls correct me if I'm wrong about that), I think this version is probably the best we'll be able to do for now.

Copy link
Contributor

@jverma-quic jverma-quic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Lunderberg
Copy link
Contributor

Agreed, since the quantized ops have already been converted to Relax by the point this pass is applied, the high-level Relax information would no longer be present.

@quic-sanirudh quic-sanirudh merged commit 7739171 into apache:main Aug 9, 2024
20 checks passed
@quic-sanirudh
Copy link
Contributor

Thanks @sdalvi-quic @Lunderberg @jverma-quic.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants