
When converting the SDXL UNet weights to a single file, the inference result is wrong #5765

Open
williamlzw opened this issue Nov 22, 2023 · 4 comments

williamlzw commented Nov 22, 2023

The ONNX model exported by PyTorch, with its weights scattered across separate external data files, gives correct inference results. But after the weights are consolidated into a single file with onnx.save_model, inference no longer works.

==================
model link:
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main/unet/diffusion_pytorch_model.safetensors

code link:
https://github.com/cloneofsimo/minDiffusion/tree/master/mindiffusion/unet.py

test code:

    import onnx
    import torch

    def export_unet1():
        from diffusers import StableDiffusionXLPipeline
        from unet import UNet2DConditionModel

        # Copy the SDXL UNet weights from the diffusers pipeline into the local UNet.
        pipe = StableDiffusionXLPipeline.from_pretrained("G:/stable-model/stable-diffusion-xl", use_safetensors=True)
        model = UNet2DConditionModel()
        model.load_state_dict(pipe.unet.state_dict())
        model.eval()

        # Dummy inputs matching the SDXL UNet signature.
        latents = torch.randn(2, 4, 128, 128)
        timesteps = torch.tensor([77], dtype=torch.int64)
        prompt_embeds = torch.randn(2, 77, 2048)
        text_embeds = torch.randn(2, 1280)
        time_ids = torch.randn(2, 6)
        input_list = ['latents', 'timesteps', 'prompt_embeds', 'text_embeds', 'time_ids']
        output_list = ['noise_pred']

        # Export; the model is larger than 2 GB, so the weights end up
        # scattered across external data files next to unet1.onnx.
        torch.onnx.export(model, (latents, timesteps, prompt_embeds, text_embeds, time_ids),
                          f="./onnxmodel/unet1.onnx",
                          input_names=input_list, output_names=output_list,
                          do_constant_folding=True)

        # Re-save with all tensors consolidated into a single external file.
        unet_onnx = onnx.load("./onnxmodel/unet1.onnx")
        onnx.save_model(
            unet_onnx,
            "./onnxmodel/unet.onnx",
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location="unet.pb",
            size_threshold=0,
            convert_attribute=True,
        )

[ONNXRuntimeError] : 1 : FAIL : Load model from ./onnxmodel1/unet.onnx failed:Node (/Unsqueeze) Op (Unsqueeze) [ShapeInferenceError] Cannot parse data from external tensors. Please load external data into raw data for tensor: /Constant_2_output_0
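
The failing tensor, /Constant_2_output_0, is the value of a Constant node rather than a weight, which suggests that size_threshold=0 together with convert_attribute=True pushed even tiny attribute tensors into the external file, where ONNX Runtime's shape inference cannot read them. A minimal diagnostic sketch (assuming the paths from the export code above) to list which tensors were externalized:

    import onnx
    from onnx.external_data_helper import uses_external_data

    # Load only the graph structure; leave the external weight data on disk.
    model = onnx.load("./onnxmodel/unet.onnx", load_external_data=False)

    # Initializers moved into the external file.
    for init in model.graph.initializer:
        if uses_external_data(init):
            print("initializer:", init.name, list(init.dims))

    # Attribute tensors (e.g. Constant node values) moved out by convert_attribute=True.
    for node in model.graph.node:
        for attr in node.attribute:
            if attr.type == onnx.AttributeProto.TENSOR and uses_external_data(attr.t):
                print("attribute:", node.name, attr.name, list(attr.t.dims))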

    import numpy as np
    import onnxruntime

    def test_unet_onnx():
        # Single-file model: fails to load with the shape-inference error above.
        unet = onnxruntime.InferenceSession("./onnxmodel/unet.onnx")
        latents = np.ones((2, 4, 128, 128)).astype(np.float32)
        timesteps = np.array([77]).astype(np.int64)
        prompt_embeds = np.ones((2, 77, 2048)).astype(np.float32)
        text_embeds = np.ones((2, 1280)).astype(np.float32)
        time_ids = np.ones((2, 6)).astype(np.float32)
        input_list = {'latents': latents, 'timesteps': timesteps, 'prompt_embeds': prompt_embeds,
                      'text_embeds': text_embeds, 'time_ids': time_ids}
        output_list = ['noise_pred']
        noise_pred = unet.run(output_list, input_list)[0]
        print('noise_pred', noise_pred.shape)
        print(noise_pred)  # not correct


    def test_unet_onnx1():
        # Scattered-weights model straight from torch.onnx.export: loads and runs correctly.
        unet = onnxruntime.InferenceSession("./onnxmodel/unet1.onnx")
        latents = np.ones((2, 4, 128, 128)).astype(np.float32)
        timesteps = np.array([77]).astype(np.int64)
        prompt_embeds = np.ones((2, 77, 2048)).astype(np.float32)
        text_embeds = np.ones((2, 1280)).astype(np.float32)
        time_ids = np.ones((2, 6)).astype(np.float32)
        input_list = {'latents': latents, 'timesteps': timesteps, 'prompt_embeds': prompt_embeds,
                      'text_embeds': text_embeds, 'time_ids': time_ids}
        output_list = ['noise_pred']
        noise_pred = unet.run(output_list, input_list)[0]
        print('noise_pred', noise_pred.shape)
        print(noise_pred)  # correct
williamlzw added the bug label Nov 22, 2023
justinchuby closed this as not planned Nov 29, 2023
justinchuby reopened this Nov 29, 2023
justinchuby (Contributor) commented:

When you say the inference result is wrong, do you mean it causes the ONNX Runtime error? Does ONNX Runtime produce a result at all?

williamlzw (Author) commented:

When the weights are saved into a single ONNX file, ONNX Runtime reports the load error above and no inference result can be obtained.

justinchuby (Contributor) commented:

Try relaxing the size_threshold a bit. Something like 1024 may be appropriate.
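
Applied to the save call from the report, the suggestion would look like this (a sketch reusing the unet_onnx proto loaded above, not verified on this model):

    onnx.save_model(
        unet_onnx,
        "./onnxmodel/unet.onnx",
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location="unet.pb",
        size_threshold=1024,  # tensors smaller than 1024 bytes stay inline in the graph
        convert_attribute=True,
    )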

williamlzw (Author) commented:

I tried it, same error.
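
Since raising size_threshold alone did not resolve it, and the tensor named in the error is a Constant node value, one further variant to try (an assumption drawn from the error message, not a fix confirmed in this thread) is to keep attribute tensors inline with convert_attribute=False, so that only initializers go to the external file:

    onnx.save_model(
        unet_onnx,
        "./onnxmodel/unet.onnx",
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location="unet.pb",
        size_threshold=1024,
        convert_attribute=False,  # leave Constant node values in the graph so shape inference can read them
    )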

github-actions bot added the stale label Nov 30, 2024
github-actions bot closed this as not planned Dec 22, 2024
andife reopened this Dec 22, 2024
github-actions bot removed the stale label Dec 23, 2024