Commit: upload full model

flamehaze1115 committed Nov 28, 2023
1 parent 9f7e4df commit f3f7015
Showing 6 changed files with 126 additions and 49 deletions.
94 changes: 67 additions & 27 deletions README.md
@@ -6,21 +6,45 @@ Single Image to 3D using Cross-Domain Diffusion

Wonder3D reconstructs highly-detailed textured meshes from a single-view image in only 2 ∼ 3 minutes. Wonder3D first generates consistent multi-view normal maps with corresponding color images via a cross-domain diffusion model, and then leverages a novel normal fusion method to achieve fast and high-quality reconstruction.

## Share your reconstructions!
If you get any interesting reconstructions and would like to share them with others, feel free to upload the input image and reconstructed mesh to this [onedrive repo](https://connecthkuhk-my.sharepoint.com/:f:/g/personal/xxlong_connect_hku_hk/EvIYSlYlkwNJpQkEnlD7uHsB-C6zU1oSNjqAGz_K7VGG2Q).

Data structure:
```
{yourname}/{scenename}-input.png       # the input image
{yourname}/{scenename}-screenshot.png  # a front view screenshot of the reconstructed mesh
{yourname}/{scenename}-mesh.obj        # the reconstructed mesh, .obj or .ply
# example:
# create a folder named `xxlong`, then upload the files to that folder
xxlong/apple-input.png
xxlong/apple-screenshot.png
xxlong/apple-mesh.obj
```
## Usage
```python
import torch
import requests
from PIL import Image
import numpy as np
from torchvision.utils import make_grid, save_image
from diffusers import DiffusionPipeline

def load_wonder3d_pipeline():
    pipeline = DiffusionPipeline.from_pretrained(
        'flamehaze1115/wonder3d-v1.0',  # or use a local checkpoint, e.g. './ckpts'
        custom_pipeline='flamehaze1115/wonder3d-pipeline',
        torch_dtype=torch.float16
    )

    # enable xformers memory-efficient attention
    pipeline.unet.enable_xformers_memory_efficient_attention()

    if torch.cuda.is_available():
        pipeline.to('cuda:0')
    return pipeline

pipeline = load_wonder3d_pipeline()

# Download an example image.
cond = Image.open(requests.get("https://d.skis.ltd/nrp/sample-data/lysol.png", stream=True).raw)

# Drop the alpha channel. The object should be centered and occupy roughly 80% of the image height.
cond = Image.fromarray(np.array(cond)[:, :, :3])

# Run the pipeline!
images = pipeline(cond, num_inference_steps=20, output_type='pt').images

# Arrange the 12 outputs (6 normal maps and 6 color images) into a 2x6 grid.
result = make_grid(images, nrow=6, padding=0, value_range=(0, 1))

save_image(result, 'result.png')
```
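If you would rather save each generated view to its own file than assemble a single grid, here is a minimal sketch that continues from the snippet above (the filenames are purely positional; the view/domain ordering of the 12 outputs is defined by the pipeline):

```python
# Save the 12 generated views individually. The index order follows the
# pipeline's internal (view, domain) ordering, so these names are positional.
for i, img in enumerate(images):
    save_image(img, f'view_{i:02d}.png')
```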
## Collaborations
@@ -30,15 +54,15 @@ Our overarching mission is to enhance the speed, affordability, and quality of 3
The repo is still under construction; thanks for your patience.
- [x] Local gradio demo.
- [ ] Detailed tutorial.
- [x] Detailed tutorial.
- [x] GUI demo for mesh reconstruction
- [x] Windows support
- [ ] Docker support
## Schedule
- [x] Inference code and pretrained models.
- [x] Huggingface demo.
- [ ] New model trained on the whole Objaverse dataset.
- [ ] New model with higher resolution.
### Preparation for inference
@@ -54,55 +78,71 @@ pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/to
Please switch to the `main-windows` branch for details of the Windows setup.
#### Download pre-trained checkpoint.
Download the [checkpoints](https://connecthkuhk-my.sharepoint.com/:f:/g/personal/xxlong_connect_hku_hk/EgSHPyJAtaJFpV_BjXM3zXwB-UMIrT4v-sQwGgw-coPtIA) and put them into the root folder.
### Inference
1. Make sure you have the following models.
1. Optional. If you have trouble connecting to Hugging Face, make sure you have downloaded the following models (or pull them from the Hugging Face Hub; see the sketch after the layout below).
Download the [checkpoints](https://connecthkuhk-my.sharepoint.com/:f:/g/personal/xxlong_connect_hku_hk/Ej7fMT1PwXtKvsELTvDuzuMBebQXEkmf2IwhSjBWtKAJiA) and put them into the root folder.
```bash
Wonder3D
|-- ckpts
|-- unet
|-- scheduler.bin
|-- scheduler
|-- vae
...
```
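Alternatively, a minimal sketch for pulling the same weights from the Hugging Face Hub into `./ckpts` (this assumes `huggingface_hub` is installed; the repo id is the one used in the Usage snippet above):

```python
from huggingface_hub import snapshot_download

# Download the full Wonder3D model repository into ./ckpts so it can be used
# offline via pretrained_model_name_or_path: './ckpts'.
snapshot_download(repo_id='flamehaze1115/wonder3d-v1.0', local_dir='./ckpts')
```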
* Download the [SAM](https://huggingface.co/spaces/abhishek/StableSAM/blob/main/sam_vit_h_4b8939.pth) model. Put it in the ``sam_pt`` folder.
Then modify the file ./configs/mvdiffusion-joint-ortho-6views.yaml and set `pretrained_model_name_or_path: './ckpts'`.
2. Download the [SAM](https://huggingface.co/spaces/abhishek/StableSAM/blob/main/sam_vit_h_4b8939.pth) model. Put it in the ``sam_pt`` folder.
```
Wonder3D
|-- sam_pt
|-- sam_vit_h_4b8939.pth
```
2. Predict the foreground mask as the alpha channel. We use [Clipdrop](https://clipdrop.co/remove-background) to segment the foreground object interactively.
3. Predict the foreground mask as the alpha channel. We use [Clipdrop](https://clipdrop.co/remove-background) to segment the foreground object interactively.
You may also use `rembg` to remove the background.
```python
# pip install rembg
import rembg
from PIL import Image

input_image = Image.open('example_images/owl.png')
result = rembg.remove(input_image)  # returns an RGBA image with the background removed
result.show()
```
3. Run Wonder3D to produce multiview-consistent normal maps and color images, then check the results in the folder `./outputs`. (We use `rembg` to remove the backgrounds of the results, but the segmentations are not always perfect. Consider using [Clipdrop](https://clipdrop.co/remove-background) to get masks for the generated normal maps and color images, since mask quality significantly influences the quality of the reconstructed mesh.)
4. Run Wonder3D to produce multiview-consistent normal maps and color images, then check the results in the folder `./outputs`. (We use `rembg` to remove the backgrounds of the results, but the segmentations are not always perfect. Consider using [Clipdrop](https://clipdrop.co/remove-background) to get masks for the generated normal maps and color images, since mask quality significantly influences the quality of the reconstructed mesh; see the re-masking sketch after the example command below.)
```bash
accelerate launch --config_file 1gpu.yaml test_mvdiffusion_seq.py \
--config mvdiffusion-joint-ortho-6views.yaml
--config configs/mvdiffusion-joint-ortho-6views.yaml validation_dataset.root_dir={your_data_path} \
validation_dataset.filepaths=['your_img_file'] save_dir={your_save_path}
```
or
see example:
```bash
bash run_test.sh
accelerate launch --config_file 1gpu.yaml test_mvdiffusion_seq.py \
--config configs/mvdiffusion-joint-ortho-6views.yaml validation_dataset.root_dir=./example_images \
validation_dataset.filepaths=['owl.png'] save_dir=./outputs
```
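As referenced above, a minimal re-masking sketch; the glob pattern is an assumption about the layout under `./outputs`, and the files are overwritten in place:

```python
import glob

import rembg
from PIL import Image

# Re-run rembg over the generated views to refresh their alpha masks before
# reconstruction. Adjust the pattern to the save_dir you actually used.
for path in glob.glob('./outputs/**/*.png', recursive=True):
    img = Image.open(path).convert('RGB')
    rembg.remove(img).save(path)  # saves an RGBA image with the new mask
```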
#### Interactive inference: run your local gradio demo
```bash
python gradio_app_mv.py # generate multi-view normals and colors
```
4. Mesh Extraction
5. Mesh Extraction
#### Instant-NSR Mesh Extraction
```bash
cd ./instant-nsr-pl
bash run.sh output_folder_path scene_name
python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=../{your_save_path}/cropsize-{crop_size}-cfg{guidance_scale:.1f}/ dataset.scene={scene}
```
see example:
```bash
cd ./instant-nsr-pl
python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=../outputs/cropsize-192-cfg1.0/ dataset.scene=owl
```
Our generated normals and color images are defined in orthographic views, so the reconstructed mesh is also in orthographic camera space. If you use MeshLab to view the meshes, you can click `Toggle Orthographic Camera` in the `View` tab.
#### Interactive inference: run your local gradio demo
7 changes: 3 additions & 4 deletions configs/mvdiffusion-joint-ortho-6views.yaml
@@ -1,5 +1,4 @@
pretrained_model_name_or_path: './ckpts'
# pretrained_unet_path: './ckpts'
pretrained_model_name_or_path: 'flamehaze1115/wonder3d-v1.0' # or './ckpts'
revision: null
validation_dataset:
root_dir: "./example_images" # the folder path stores testing images
@@ -23,7 +22,7 @@ pipe_kwargs:
camera_embedding_type: 'e_de_da_sincos'
num_views: 6

validation_guidance_scales: [3.0]
validation_guidance_scales: [1.0]
pipe_validation_kwargs:
eta: 1.0
validation_grid_nrow: 6
@@ -33,7 +32,7 @@ unet_from_pretrained_kwargs:
projection_class_embeddings_input_dim: 10
num_views: 6
sample_size: 32
cd_attention_mid: True
cd_attention_mid: true
zero_init_conv_in: false
zero_init_camera_projection: false

6 changes: 3 additions & 3 deletions gradio_app_mv.py
@@ -156,7 +156,7 @@ def preprocess(predictor, input_image, chk_group=None, segment=True, rescale=Fal
def load_wonder3d_pipeline(cfg):

pipeline = MVDiffusionImagePipeline.from_pretrained(
"./wonder3D-model",
cfg.pretrained_model_name_or_path,
torch_dtype=weight_dtype
)

@@ -210,7 +210,7 @@ def run_pipeline(pipeline, cfg, single_image, guidance_scale, steps, seed, crop_

out = pipeline(
imgs_in,
camera_embeddings,
# camera_embeddings,
generator=generator,
guidance_scale=guidance_scale,
num_inference_steps=steps,
@@ -225,7 +225,7 @@ def run_pipeline(pipeline, cfg, single_image, guidance_scale, steps, seed, crop_
num_views = 6
if write_image:
VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
cur_dir = os.path.join("./outputs", f"cropsize-{crop_size}-cfg{guidance_scale:.1f}")
cur_dir = os.path.join("./outputs", f"cropsize-{int(crop_size)}-cfg{guidance_scale:.1f}")

scene = 'scene'+datetime.now().strftime('@%Y%m%d-%H%M%S')
scene_dir = os.path.join(cur_dir, scene)
8 changes: 4 additions & 4 deletions gradio_app_recon.py
@@ -156,7 +156,7 @@ def preprocess(predictor, input_image, chk_group=None, segment=True, rescale=Fal
def load_wonder3d_pipeline(cfg):

pipeline = MVDiffusionImagePipeline.from_pretrained(
"./ckpts",
cfg.pretrained_model_name_or_path,
torch_dtype=weight_dtype
)

@@ -225,7 +225,7 @@ def run_pipeline(pipeline, cfg, single_image, guidance_scale, steps, seed, crop_
num_views = 6
if write_image:
VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
cur_dir = os.path.join("./outputs", f"cropsize-{crop_size}-cfg{guidance_scale:.1f}")
cur_dir = os.path.join("./outputs", f"cropsize-{int(crop_size)}-cfg{guidance_scale:.1f}")

scene = 'scene'+datetime.now().strftime('@%Y%m%d-%H%M%S')
scene_dir = os.path.join(cur_dir, scene)
@@ -263,7 +263,7 @@ def process_3d(mode, data_dir, guidance_scale, crop_size):
cur_dir = os.path.dirname(os.path.abspath(__file__))

subprocess.run(
f'cd instant-nsr-pl && python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=../{data_dir}/cropsize-{crop_size:.1f}-cfg{guidance_scale:.1f}/ dataset.scene={scene} && cd ..',
f'cd instant-nsr-pl && python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=../{data_dir}/cropsize-{int(crop_size)}-cfg{guidance_scale:.1f}/ dataset.scene={scene} && cd ..',
shell=True,
)
import glob
@@ -404,7 +404,7 @@ def run_demo():
# method = gr.Radio(choices=['instant-nsr-pl', 'NeuS'], label='Method (Default: instant-nsr-pl)', value='instant-nsr-pl')
# run_btn = gr.Button('Generate Normals and Colors', variant='primary', interactive=True)
run_btn = gr.Button('Reconstruct 3D model', variant='primary', interactive=True)
gr.Markdown("<span style='color:red'> Reconstruction may cost several minutes.</span>")
gr.Markdown("<span style='color:red'> Reconstruction may cost several minutes. Check results in instant-nsr-pl/exp/scene@{current-time}/ </span>")

with gr.Row():
view_1 = gr.Image(interactive=False, height=240, show_label=False)
24 changes: 19 additions & 5 deletions mvdiffusion/models/transformer_mv2d.py
@@ -27,7 +27,7 @@
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.import_utils import is_xformers_available

from einops import rearrange
from einops import rearrange, repeat
import pdb
import random

@@ -38,6 +38,15 @@
else:
xformers = None

def my_repeat(tensor, num_repeats):
"""
Repeat a tensor along its batch dimension (handles both 3D and 4D inputs).
"""
if len(tensor.shape) == 3:
return repeat(tensor, "b d c -> (b v) d c", v=num_repeats)
elif len(tensor.shape) == 4:
return repeat(tensor, "a b d c -> (a v) b d c", v=num_repeats)


@dataclass
class TransformerMV2DModelOutput(BaseOutput):
@@ -501,7 +510,7 @@ def __init__(
self.cd_attention_mid = cd_attention_mid

if self.cd_attention_mid:
print("cross-domain attn in the middle")
# print("cross-domain attn in the middle")
# Joint task -Attn
self.attn_joint_mid = CustomJointAttention(
query_dim=dim,
@@ -772,9 +781,14 @@ def __call__(
# pdb.set_trace()
# multi-view self-attention
if multiview_attention:

key = rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
value = rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
if not sparse_mv_attention:
key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
else:
key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) # [(b t), d, c]
value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views)
key = torch.cat([key_front, key_raw], dim=1) # shape (b t) (2 d) c
value = torch.cat([value_front, value_raw], dim=1)
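# In this sparse path, each view's queries attend to their own tokens plus the
# front view's (index 0) tokens, rather than to tokens from all views as in the
# dense branch above.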

else:
# print("don't use multiview attention.")
36 changes: 30 additions & 6 deletions mvdiffusion/pipelines/pipeline_mvdiffusion_image.py
@@ -30,7 +30,7 @@
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

from einops import rearrange, repeat

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -77,7 +77,7 @@ def __init__(
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
camera_embedding_type: str = 'e_de_da_sincos',
num_views: int = 4
num_views: int = 6
):
super().__init__()

@@ -133,6 +133,20 @@ def __init__(
self.camera_embedding_type: str = camera_embedding_type
self.num_views: int = num_views

self.camera_embedding = torch.tensor(
[[ 0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
[ 0.0000, -0.2362, 0.8125, 1.0000, 0.0000],
[ 0.0000, -0.1686, 1.6934, 1.0000, 0.0000],
[ 0.0000, 0.5220, 3.1406, 1.0000, 0.0000],
[ 0.0000, 0.6904, 4.8359, 1.0000, 0.0000],
[ 0.0000, 0.3733, 5.5859, 1.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
[ 0.0000, -0.2362, 0.8125, 0.0000, 1.0000],
[ 0.0000, -0.1686, 1.6934, 0.0000, 1.0000],
[ 0.0000, 0.5220, 3.1406, 0.0000, 1.0000],
[ 0.0000, 0.6904, 4.8359, 0.0000, 1.0000],
[ 0.0000, 0.3733, 5.5859, 0.0000, 1.0000]], dtype=torch.float16)
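# The 12 rows above are the default per-view conditionings: 6 views for each of
# the two output domains. The last two columns act as a one-hot domain switch;
# the remaining columns are per-view pose terms for the 'e_de_da_sincos'
# embedding type (the exact column semantics are inferred from the values and
# should be treated as an assumption).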

def _encode_image(self, image_pil, device, num_images_per_prompt, do_classifier_free_guidance):
dtype = next(self.image_encoder.parameters()).dtype

@@ -288,7 +302,7 @@ def __call__(
# elevation_cond: torch.FloatTensor,
# elevation: torch.FloatTensor,
# azimuth: torch.FloatTensor,
camera_embedding: torch.FloatTensor,
camera_embedding: Optional[torch.FloatTensor]=None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
@@ -384,10 +398,15 @@
# 2. Define call parameters
if isinstance(image, list):
batch_size = len(image)
else:
elif isinstance(image, torch.Tensor):
batch_size = image.shape[0]
assert batch_size >= self.num_views and batch_size % self.num_views == 0
assert batch_size >= self.num_views and batch_size % self.num_views == 0
elif isinstance(image, PIL.Image.Image):
image = [image]*self.num_views*2
batch_size = self.num_views*2

device = self._execution_device
dtype = self.vae.dtype
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@@ -410,7 +429,12 @@

# assert len(elevation_cond) == batch_size and len(elevation) == batch_size and len(azimuth) == batch_size
# camera_embeddings = self.prepare_camera_condition(elevation_cond, elevation, azimuth, do_classifier_free_guidance=do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt)
assert len(camera_embedding) == batch_size

if camera_embedding is not None:
assert len(camera_embedding) == batch_size
else:
camera_embedding = self.camera_embedding.to(dtype)
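# Tile the 12 default rows across the batch (this assumes batch_size is a
# multiple of the default embedding's length).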
camera_embedding = repeat(camera_embedding, "Nv Nce -> (B Nv) Nce", B=batch_size//len(camera_embedding))
camera_embeddings = self.prepare_camera_embedding(camera_embedding, do_classifier_free_guidance=do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt)

# 4. Prepare timesteps
