Commit f546404

jellyheadandrew authored and sayakpaul committed
Add new community pipeline for 'Adaptive Mask Inpainting', introduced in [ECCV2024] ComA (#9228)
* Add new community pipeline for 'Adaptive Mask Inpainting', introduced in [ECCV2024] Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models
1 parent d1c42c6 commit f546404

2 files changed

+1621
-0
lines changed

examples/community/README.md

Lines changed: 156 additions & 0 deletions
@@ -10,6 +10,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif

| Example | Description | Code Example | Colab | Author |
|:---|:---|:---|:---|---:|
|Adaptive Mask Inpainting|The Adaptive Mask Inpainting algorithm from [Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models](https://github.com/snuvclab/coma) (ECCV '24, Oral) inserts a human into a scene image without altering the background by inpainting with an adaptive mask.|[Adaptive Mask Inpainting](#adaptive-mask-inpainting)|-|[Hyeonwoo Kim](https://sshowbiz.xyz), [Sookwan Han](https://jellyheadandrew.github.io)|
|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|NA|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
|Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/exx8/differential-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
| HD-Painter | [HD-Painter](https://github.com/Picsart-AI-Research/HD-Painter) enables prompt-faithful and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method. | [HD-Painter](#hd-painter) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/PAIR/HD-Painter) | [Manukyan Hayk](https://github.com/haikmanukyan) and [Sargsyan Andranik](https://github.com/AndranikSargsyan) |
@@ -85,6 +86,161 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion

## Example usages

### Adaptive Mask Inpainting

**Hyeonwoo Kim\*, Sookwan Han\*, Patrick Kwon, Hanbyul Joo**

**Seoul National University, Naver Webtoon**

Adaptive Mask Inpainting, presented in the ECCV'24 oral paper [*Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models*](https://snuvclab.github.io/coma), is an algorithm designed to insert humans into scene images without altering the background. Traditional inpainting methods often fail to preserve object geometry and details within the masked region, leading to false affordances. Adaptive Mask Inpainting addresses this issue by progressively specifying the inpainting region over diffusion timesteps, ensuring that the inserted human integrates seamlessly with the existing scene.
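
The idea of progressively specifying the inpainting region can be illustrated with a toy sketch. This is not the pipeline's implementation: `dilate`, `refine_mask`, the dilation radius, and the stand-in "detected human" mask are all hypothetical, and the sketch assumes that a human segmentation of the intermediate prediction is available at some denoising steps. It only shows how an initially large mask could shrink toward the dilated human region while never growing beyond the user-provided default mask.

```python
import numpy as np


def dilate(mask: np.ndarray, radius: int) -> np.ndarray:
    """Binary dilation with a square window, done by OR-ing shifted copies."""
    padded = np.pad(mask, radius, mode="constant", constant_values=False)
    out = np.zeros_like(mask, dtype=bool)
    h, w = mask.shape
    for dy in range(2 * radius + 1):
        for dx in range(2 * radius + 1):
            out |= padded[dy:dy + h, dx:dx + w]
    return out


def refine_mask(default_mask: np.ndarray, detected_human: np.ndarray, radius: int = 1) -> np.ndarray:
    """Hypothetical update rule (an assumption, not the pipeline's code).

    If a human was detected, inpaint only a dilated region around it,
    clipped to the default mask. Otherwise fall back to the default mask.
    """
    if not detected_human.any():
        return default_mask.copy()
    return dilate(detected_human, radius) & default_mask


# Toy 8x8 example: a 6x6 default mask and a 2x2 "detected human".
default_mask = np.zeros((8, 8), dtype=bool)
default_mask[1:7, 1:7] = True
detected = np.zeros((8, 8), dtype=bool)
detected[3:5, 3:5] = True

mask = refine_mask(default_mask, detected, radius=1)
print(int(default_mask.sum()), int(mask.sum()))
```

In this toy setting the refined mask covers fewer pixels than the default mask, so more of the original scene is kept fixed in later steps.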

Here is a demonstration of Adaptive Mask Inpainting:

<video controls>
  <source src="https://snuvclab.github.io/coma/static/videos/adaptive_mask_inpainting_vis.mp4" type="video/mp4">
  Your browser does not support the video tag.
</video>

![teaser-img](https://snuvclab.github.io/coma/static/images/example_result_adaptive_mask_inpainting.png)

You can find additional information about Adaptive Mask Inpainting in the [paper](https://arxiv.org/pdf/2401.12978) or on the [project website](https://snuvclab.github.io/coma).

#### Usage example

First, clone the diffusers GitHub repository and run the following commands to set up the environment.
```Shell
git clone https://github.com/huggingface/diffusers.git
cd diffusers

conda create --name ami python=3.9 -y
conda activate ami

conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge -y
python -m pip install detectron2==0.6 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
pip install easydict
pip install diffusers==0.20.2 accelerate safetensors transformers
pip install setuptools==59.5.0
pip install opencv-python
pip install numpy==1.24.1
```
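
After installation, it can help to confirm that the pinned versions are the ones actually active in the environment. The following is a minimal sketch using only the standard library. The `PINNED` table mirrors some of the pins from the commands above, and `check_versions` is a hypothetical helper, not part of diffusers.

```python
from importlib import metadata

# Pins copied from the installation commands above.
PINNED = {
    "diffusers": "0.20.2",
    "numpy": "1.24.1",
    "setuptools": "59.5.0",
    "detectron2": "0.6",
}


def check_versions(pinned, get_version=metadata.version):
    """Return {package: (expected, found)} for each missing or mismatched package."""
    problems = {}
    for name, expected in pinned.items():
        try:
            # Drop a local version label such as "+cu113" before comparing.
            found = get_version(name).split("+")[0]
        except metadata.PackageNotFoundError:
            found = None
        if found != expected:
            problems[name] = (expected, found)
    return problems


if __name__ == "__main__":
    for name, (expected, found) in check_versions(PINNED).items():
        print(f"{name}: expected {expected}, found {found}")
```

If the script prints nothing, every package in the table matches its pin.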
Then, run the code below from the `diffusers` directory.
```python
import numpy as np
import torch
from PIL import Image

from diffusers import DDIMScheduler
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

from examples.community.adaptive_mask_inpainting import download_file, AdaptiveMaskInpaintPipeline, AMI_INSTALL_MESSAGE

print(AMI_INSTALL_MESSAGE)

from easydict import EasyDict


if __name__ == "__main__":
    # Download necessary files
    download_file(
        url="https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/model_final_edd263.pkl?download=true",
        output_file="model_final_edd263.pkl",
        exist_ok=True,
    )
    download_file(
        url="https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/pointrend_rcnn_R_50_FPN_3x_coco.yaml?download=true",
        output_file="pointrend_rcnn_R_50_FPN_3x_coco.yaml",
        exist_ok=True,
    )
    download_file(
        url="https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/input_img.png?download=true",
        output_file="input_img.png",
        exist_ok=True,
    )
    download_file(
        url="https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/input_mask.png?download=true",
        output_file="input_mask.png",
        exist_ok=True,
    )
    download_file(
        url="https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/Base-PointRend-RCNN-FPN.yaml?download=true",
        output_file="Base-PointRend-RCNN-FPN.yaml",
        exist_ok=True,
    )
    download_file(
        url="https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/Base-RCNN-FPN.yaml?download=true",
        output_file="Base-RCNN-FPN.yaml",
        exist_ok=True,
    )

    # Prepare the Adaptive Mask Inpainting pipeline
    # device
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    num_steps = 50

    # Scheduler
    scheduler = DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False
    )
    scheduler.set_timesteps(num_inference_steps=num_steps)

    ## load models as pipelines
    pipeline = AdaptiveMaskInpaintPipeline.from_pretrained(
        "Uminosachi/realisticVisionV51_v51VAE-inpainting",
        scheduler=scheduler,
        torch_dtype=torch.float16,
        requires_safety_checker=False
    ).to(device)

    ## disable safety checker
    enable_safety_checker = False
    if not enable_safety_checker:
        pipeline.safety_checker = None

    # Run Adaptive Mask Inpainting
    default_mask_image = Image.open("./input_mask.png").convert("L")
    init_image = Image.open("./input_img.png").convert("RGB")

    seed = 59
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)

    image = pipeline(
        prompt="a man sitting on a couch",
        negative_prompt="worst quality, normal quality, low quality, bad anatomy, artifacts, blurry, cropped, watermark, greyscale, nsfw",
        image=init_image,
        default_mask_image=default_mask_image,
        guidance_scale=11.0,
        strength=0.98,
        use_adaptive_mask=True,
        generator=generator,
        enforce_full_mask_ratio=0.0,
        visualization_save_dir="./ECCV2024_adaptive_mask_inpainting_demo",  # DON'T CHANGE THIS!!!
        human_detection_thres=0.015,
    ).images[0]

    image.save("final_img.png")
```
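
To try the pipeline on your own scene, you need a mask image for `default_mask_image`. The following is a minimal sketch with Pillow; `make_box_mask` and the box coordinates are hypothetical. It assumes that the mask should match the input image size and that this pipeline follows the usual diffusers inpainting convention (white pixels are repainted, black pixels are preserved), both of which should be verified against the pipeline source.

```python
from PIL import Image, ImageDraw


def make_box_mask(size, box):
    """Create an 8-bit grayscale ("L") mask: white inside `box`, black elsewhere.

    size: (width, height) of the scene image.
    box:  (left, top, right, bottom) in pixel coordinates.
    """
    mask = Image.new("L", size, 0)
    ImageDraw.Draw(mask).rectangle(box, fill=255)
    return mask


# Hypothetical region where the human should be inserted.
box_mask = make_box_mask((512, 512), (128, 96, 384, 480))
```

The resulting image could then be passed as `default_mask_image` in place of the downloaded `input_mask.png`.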
#### [Troubleshooting]

If you run into the error `cannot import name 'cached_download' from 'huggingface_hub'` (issue [1851](https://github.com/easydiffusion/easydiffusion/issues/1851)), remove `cached_download` from the import line in the file `diffusers/utils/dynamic_modules_utils.py`.

For example, edit the import line in `.../env/lib/python3.8/site-packages/diffusers/utils/dynamic_modules_utils.py`.
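
The edit described above can also be scripted. The sketch below assumes the offending statement is a single-line `from huggingface_hub import ...` without parentheses or trailing comments. `drop_imported_name` is a hypothetical helper and the sample line is illustrative only, so check the actual contents of your installed `dynamic_modules_utils.py` before applying it.

```python
import re


def drop_imported_name(source: str, module: str, name: str) -> str:
    """Remove `name` from single-line `from <module> import a, b, c` statements."""
    pattern = re.compile(
        rf"^([ \t]*from[ \t]+{re.escape(module)}[ \t]+import[ \t]+)(.+)$",
        re.MULTILINE,
    )

    def rewrite(match):
        names = [n.strip() for n in match.group(2).split(",")]
        # Leave the line untouched if the name is absent or is the only import.
        if name not in names or len(names) == 1:
            return match.group(0)
        kept = [n for n in names if n != name]
        return match.group(1) + ", ".join(kept)

    return pattern.sub(rewrite, source)


# Illustrative input, not copied from the installed file.
sample = "import os\nfrom huggingface_hub import HfFolder, cached_download, hf_hub_download\n"
print(drop_imported_name(sample, "huggingface_hub", "cached_download"))
```

To patch the real file, read it with `pathlib.Path.read_text`, pass the text through the helper, and write it back with `write_text`, ideally after making a backup copy.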

### Flux with CFG

Learn more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since Flux doesn't use CFG, this implementation provides one, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md).
