
Torchscript to CoreML conversion skips for loops #1995

Open
the-neural-networker opened this issue Sep 27, 2023 · 5 comments
Labels: bug (Unexpected behaviour that should be corrected), PyTorch (not traced)

Comments

@the-neural-networker

🐞Description

When TorchScript code containing for loops is converted to Core ML, the loops are skipped and the result computed before the loop is returned instead.

To Reproduce

import numpy as np
import torch
import torch.nn as nn
import coremltools as ct

@torch.jit.script
def experiment(val: torch.Tensor):
    val = int(val.item())
    result = torch.zeros(val)
    for i in range(val):
        result[i] = i
    return result

class Experiment(nn.Module):
    def __init__(self):
        super(Experiment, self).__init__()

    def forward(self, x):
        return experiment(x)

exp = Experiment()

# Use a tensor as input, not an integer
input_tensor = torch.tensor(100)
output = exp(input_tensor)
print(output)

scripted_exp = torch.jit.script(exp)  # torch.jit.script takes no example inputs
scripted_exp.eval()

output_scripted = scripted_exp(input_tensor)
print(output_scripted)

# Specify the input type as ct.TensorType(name="x", shape=(1,))
coreml_exp = ct.convert(
    scripted_exp,
    source="pytorch",
    inputs=[ct.TensorType(name="x", shape=(1,))],
    convert_to="mlprogram"
)

# Create an input dictionary with the necessary input data
input_data = {
    'x': np.array([100.0]),
}

# Make a prediction using the model
coreml_output = coreml_exp.predict(input_data)
print(coreml_output)


  • Output
tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
        28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41.,
        42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55.,
        56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69.,
        70., 71., 72., 73., 74., 75., 76., 77., 78., 79., 80., 81., 82., 83.,
        84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95., 96., 97.,
        98., 99.])
tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
        28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41.,
        42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55.,
        56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69.,
        70., 71., 72., 73., 74., 75., 76., 77., 78., 79., 80., 81., 82., 83.,
        84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95., 96., 97.,
        98., 99.])
Converting PyTorch Frontend ==> MIL Ops:   0%|          | 0/9 [00:00<?, ? ops/s]
Converting PyTorch Frontend ==> MIL Ops:  78%|███████▊  | 7/9 [00:00<00:00, 3895.47 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 23275.83 passes/s]
Running MIL default pipeline:   0%|          | 0/66 [00:00<?, ? passes/s]/Users/abhiroop/Developer/aikynetix/full_model_env/lib/python3.10/site-packages/coremltools/converters/mil/mil/passes/defs/preprocess.py:267: UserWarning: Output, 'result.1', of the source model, has been renamed to 'result_1' in the Core ML model.
  warnings.warn(msg.format(var.name, new_name))
Running MIL default pipeline: 100%|██████████| 66/66 [00:00<00:00, 2826.18 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 9124.66 passes/s]
{'result_1': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)}

System environment (please complete the following information):

  • coremltools version: 7.0
  • OS (e.g. MacOS version or Linux type): MacOS 14.0
  • Any other relevant version information (e.g. PyTorch or TensorFlow version): PyTorch 2.0.1 (same issue for other versions)

Additional Context

The for loop can easily be vectorized, but I kept it as an explicit loop to show that for loops get skipped during Core ML conversion. A vectorized equivalent is sketched below.
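
For reference, a minimal sketch of the vectorized form using torch.arange (the function name is ours, not from the original report). Note that even in this form, int(val.item()) is evaluated at trace time under torch.jit.trace, so the size would still be baked in as a constant:

import torch

# Hypothetical vectorized equivalent of experiment() above.
def experiment_vectorized(val: torch.Tensor) -> torch.Tensor:
    n = int(val.item())  # still evaluated at trace time under torch.jit.trace
    return torch.arange(n, dtype=torch.float32)

assert torch.equal(experiment_vectorized(torch.tensor(100)),
                   torch.arange(100, dtype=torch.float32))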

@the-neural-networker the-neural-networker added the bug Unexpected behaviour that should be corrected (type) label Sep 27, 2023
@TobyRoseman
Collaborator

We only have experimental support for PyTorch models that have not been created by torch.jit.trace.

@the-neural-networker
Author

The same issue persists if the PyTorch model Experiment is traced, but the function experiment is scripted. Does this still fall under experimental support?
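
For concreteness, a minimal sketch of that mixed setup (the Wrapper name is ours; PyTorch inlines a scripted function's graph, loop included, into a trace):

# Sketch: trace the module, but the inner function "experiment"
# (scripted via @torch.jit.script above) keeps its loop.
class Wrapper(nn.Module):
    def forward(self, x):
        return experiment(x)

mixed = torch.jit.trace(Wrapper(), torch.tensor(100))
print(mixed(torch.tensor(5)))  # tensor([0., 1., 2., 3., 4.]): loop preserved in TorchScript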

@TobyRoseman
Collaborator

@the-neural-networker - the following works for me:

import numpy as np
import torch
import torch.nn as nn
import coremltools as ct

def experiment(val: torch.Tensor):
    val = int(val.item())
    result = torch.zeros(val)
    for i in range(val):
        result[i] = i
    return result

class Experiment(nn.Module):
    def forward(self, x):
        return experiment(x)

exp = Experiment().eval()

# Use a tensor as input, not an integer
input_tensor = torch.tensor(100)
traced_model = torch.jit.trace(exp, input_tensor)
y_t = traced_model(input_tensor)

# Specify the input type as ct.TensorType(name="x", shape=(1,))
coreml_exp = ct.convert(
    traced_model,
    source="pytorch",
    inputs=[ct.TensorType(name="x", shape=(1,))],
    convert_to="mlprogram"
)

# Create an input dictionary with the necessary input data
input_data = {
    'x': np.array([100.0]),
}

# Make a prediction using the model
coreml_output = coreml_exp.predict(input_data)
coreml_output = list(coreml_output.values())[0]

assert all(coreml_output == y_t.numpy())

@the-neural-networker
Author

But won't this fail for other inputs (not 100)? Since the experiment function is traced and its for loop's trip count depends on val, the loop is unrolled for val == 100 at trace time; it is not dynamic.
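
A quick check of this (ours, continuing the traced_model snippet from the previous comment):

# The trace bakes in val == 100 regardless of the runtime input.
y_50 = traced_model(torch.tensor(50))
print(y_50.shape)  # still torch.Size([100]); the loop was unrolled at trace time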

@TobyRoseman
Collaborator

This example will give the wrong prediction when the input is not 100. However, that is a result of the tracing, not the conversion to Core ML. The PyTorch traced model gives the wrong results; the model is converted to Core ML correctly.

If you want to do something like this in Core ML, I think you'll need to write your own mlprogram. Take a look at our set of MIL ops:
https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html
Specifically, I think you'll want to look at the ones under:
coremltools.converters.mil.mil.ops.defs.iOS15.control_flow.*
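
To illustrate, here is a minimal sketch using the MIL Builder's while_loop op. This is our own example written against the documented Builder API, not code from this thread; it accumulates 0 + 1 + ... + n-1 with an explicit loop that survives conversion:

import numpy as np
import coremltools as ct
from coremltools.converters.mil import Builder as mb

# Sketch: an explicit MIL while_loop (names i, acc, n are ours).
@mb.program(input_specs=[mb.TensorSpec(shape=(1,))])
def prog(n):
    def cond(i, acc):
        return mb.less(x=i, y=n)  # keep looping while i < n

    def body(i, acc):
        # Increment the counter and accumulate it.
        return mb.add(x=i, y=1.0), mb.add(x=acc, y=i)

    i0 = np.zeros((1,), dtype=np.float32)
    acc0 = np.zeros((1,), dtype=np.float32)
    _, acc = mb.while_loop(_cond=cond, _body=body, loop_vars=(i0, acc0))
    return acc

mlmodel = ct.convert(prog, convert_to="mlprogram")
print(mlmodel.predict({"n": np.array([100.0], dtype=np.float32)}))
# expected value: array([4950.], dtype=float32), the sum of 0..99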
