Docs: Supplement NPU training script samples and documentation instruction #1169
Changed file (device helper exports):

```diff
@@ -1 +1,2 @@
-from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type
+from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
+from .npu_compatible_device import IS_NPU_AVAILABLE
```
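A minimal usage sketch of the two newly exported helpers, which the pipeline change below consumes (the top-level package name `diffsynth` is an assumption; only the relative module path `.core.device` appears in the diffs):

```python
# Assumed import path; the diffs only show the relative package `.core.device`.
from diffsynth.core.device import IS_NPU_AVAILABLE, get_device_name

# On Ascend machines, resolve to a concrete NPU device name such as "npu:0";
# otherwise fall back to CUDA.
device = get_device_name() if IS_NPU_AVAILABLE else "cuda"
print(device)
```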
Changed file (pipeline code):

```diff
@@ -7,6 +7,7 @@
 from ..utils.lora import GeneralLoRALoader
 from ..models.model_loader import ModelPool
 from ..utils.controlnet import ControlNetInput
+from ..core.device import get_device_name, IS_NPU_AVAILABLE


 class PipelineUnit:
@@ -177,7 +178,7 @@ def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=t
     def get_vram(self):
-        device = self.device if self.device != "npu" else "npu:0"
+        device = self.device if not IS_NPU_AVAILABLE else get_device_name()
```
Contributor review comment on this change:

The current logic for determining the device is incorrect. If `IS_NPU_AVAILABLE` is true, it will always use `get_device_name()` to get an NPU device, even if `self.computation_device` is set to a CUDA device. This will cause `getattr(torch, self.computation_device_type).mem_get_info(device)` to fail when `self.computation_device_type` is 'cuda' but `device` is an NPU device string.

The logic should only use `get_device_name()` when the computation device is specifically 'npu', to resolve it to a full device name like 'npu:0', while preserving other device specifications like 'cuda' or 'npu:1'.
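A minimal sketch of the fix the comment describes (the reviewer's verbatim suggested change is not shown above), keeping the attribute names from the hunk:

```python
# Sketch of the logic the reviewer describes: only the bare "npu" alias is
# resolved to a concrete device name such as "npu:0"; explicit devices like
# "cuda" or "npu:1" are preserved as-is.
def get_vram(self):
    device = get_device_name() if self.device == "npu" else self.device
    return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
```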
```diff
         return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)

     def get_module(self, model, name):
```
New file — FLUX.1-Kontext-dev full training on NPU:

```bash
# Reduce NPU memory fragmentation (torch_npu analog of PYTORCH_CUDA_ALLOC_CONF)
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
# Enable torch_npu CPU-affinity binding for host-side threads
export CPU_AFFINITY_CONF=1

accelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata_kontext.csv \
  --data_file_keys "image,kontext_images" \
  --max_pixels 1048576 \
  --dataset_repeat 400 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \
  --learning_rate 1e-5 \
  --num_epochs 1 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/FLUX.1-Kontext-dev_full" \
  --trainable_models "dit" \
  --extra_inputs "kontext_images" \
  --use_gradient_checkpointing
```
New file — FLUX.1-dev full training on NPU:

```bash
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1

accelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata.csv \
  --max_pixels 1048576 \
  --dataset_repeat 400 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \
  --learning_rate 1e-5 \
  --num_epochs 1 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/FLUX.1-dev_full" \
  --trainable_models "dit" \
  --use_gradient_checkpointing
```
New file — Qwen-Image-Edit-2509 LoRA split training on NPU:

```bash
# Due to memory limitations, split training is required to train the model on NPU.
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1

# Stage 1: pre-compute the text-encoder and VAE cache.
accelerate launch examples/qwen_image/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata.csv \
  --max_pixels 1048576 \
  --dataset_repeat 1 \
  --model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:text_encoder/model*.safetensors,Qwen/Qwen-Image-Edit-2509:vae/diffusion_pytorch_model.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Qwen-Image-Edit-2509-LoRA-splited-cache" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
  --lora_rank 32 \
  --use_gradient_checkpointing \
  --dataset_num_workers 8 \
  --find_unused_parameters \
  --task "sft:data_process"

# Stage 2: train the DiT LoRA from the cached data.
accelerate launch examples/qwen_image/model_training/train.py \
  --dataset_base_path "./models/train/Qwen-Image-Edit-2509-LoRA-splited-cache" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Qwen-Image-Edit-2509-LoRA-splited" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
  --lora_rank 32 \
  --use_gradient_checkpointing \
  --dataset_num_workers 8 \
  --find_unused_parameters \
  --task "sft:train"
```
New file — Qwen-Image LoRA split training on NPU (same two-stage structure):

```bash
# Due to memory limitations, split training is required to train the model on NPU.
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1

accelerate launch examples/qwen_image/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata.csv \
  --max_pixels 1048576 \
  --dataset_repeat 1 \
  --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Qwen-Image-LoRA-splited-cache" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
  --lora_rank 32 \
  --use_gradient_checkpointing \
  --dataset_num_workers 8 \
  --find_unused_parameters \
  --task "sft:data_process"

accelerate launch examples/qwen_image/model_training/train.py \
  --dataset_base_path "./models/train/Qwen-Image-LoRA-splited-cache" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Qwen-Image-LoRA-splited" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
  --lora_rank 32 \
  --use_gradient_checkpointing \
  --dataset_num_workers 8 \
  --find_unused_parameters \
  --task "sft:train"
```
New file — Wan2.1-T2V-14B full training on NPU:

```bash
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1

accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata.csv \
  --height 480 \
  --width 832 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.1-T2V-14B_full" \
  --trainable_models "dit" \
  --initialize_model_on_cpu
```
New file — Wan2.2-T2V-A14B full training on NPU (high-noise and low-noise experts):

```bash
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1

# High-noise expert
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata.csv \
  --height 480 \
  --width 832 \
  --num_frames 49 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.2-T2V-A14B_high_noise_full" \
  --trainable_models "dit" \
  --max_timestep_boundary 0.417 \
  --min_timestep_boundary 0 \
  --initialize_model_on_cpu
# boundary corresponds to timesteps [875, 1000]

# Low-noise expert
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata.csv \
  --height 480 \
  --width 832 \
  --num_frames 49 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.2-T2V-A14B_low_noise_full" \
  --trainable_models "dit" \
  --max_timestep_boundary 1 \
  --min_timestep_boundary 0.417 \
  --initialize_model_on_cpu
# boundary corresponds to timesteps [0, 875)
```
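The bracketed timestep ranges in the comments map to the fractional `--max/min_timestep_boundary` values through the flow-matching sigma shift. A minimal sketch of that mapping, assuming the shift factor s = 5 commonly used for Wan training (the actual shift value is not stated in this diff):

```python
# Map a timestep boundary to the --max/min_timestep_boundary fraction used in
# the scripts above. Assumes the flow-matching sigma shift
#   sigma' = s * sigma / (1 + (s - 1) * sigma)
# with s = 5; that shift value is an assumption, not shown in this PR.
def boundary_fraction(timestep: int, num_timesteps: int = 1000, shift: float = 5.0) -> float:
    y = timestep / num_timesteps             # normalized (shifted) timestep
    sigma = y / (shift - (shift - 1.0) * y)  # invert the shift
    return 1.0 - sigma                       # boundary fraction for the flag

print(boundary_fraction(875))  # ~0.417, the Wan2.2-T2V-A14B split above
print(boundary_fraction(900))  # ~0.357, close to the 0.358 used for VACE below
```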
New file — Wan2.2-VACE-Fun-A14B full training on NPU (high-noise and low-noise experts):

```bash
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1

# High-noise expert
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \
  --data_file_keys "video,vace_video,vace_reference_image" \
  --height 480 \
  --width 832 \
  --num_frames 17 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.vace." \
  --output_path "./models/train/Wan2.2-VACE-Fun-A14B_high_noise_full" \
  --trainable_models "vace" \
  --extra_inputs "vace_video,vace_reference_image" \
  --use_gradient_checkpointing_offload \
  --max_timestep_boundary 0.358 \
  --min_timestep_boundary 0 \
  --initialize_model_on_cpu
# boundary corresponds to timesteps [900, 1000]

# Low-noise expert
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \
  --data_file_keys "video,vace_video,vace_reference_image" \
  --height 480 \
  --width 832 \
  --num_frames 17 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.vace." \
  --output_path "./models/train/Wan2.2-VACE-Fun-A14B_low_noise_full" \
  --trainable_models "vace" \
  --extra_inputs "vace_video,vace_reference_image" \
  --use_gradient_checkpointing_offload \
  --max_timestep_boundary 1 \
  --min_timestep_boundary 0.358 \
  --initialize_model_on_cpu
# boundary corresponds to timesteps [0, 900]
```
New file — Z-Image-Turbo full training on NPU:

```bash
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1

accelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata.csv \
  --max_pixels 1048576 \
  --dataset_repeat 400 \
  --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Z-Image-Turbo_full" \
  --trainable_models "dit" \
  --use_gradient_checkpointing \
  --dataset_num_workers 8
```