"""Data pre-processing and training entry point for the ReCamMaster pipeline.

The script has two tasks, selected with ``--task``:

* ``data_process``: encode every video listed in the metadata CSV into
  latents / prompt embeddings and cache them next to the video as
  ``<video>.tensors.pth``.
* ``train``: train the camera-conditioned DiT on the cached tensors.
"""
import argparse
import copy
import json
import os
import random
import re
import shutil

import torch
import imageio
from torchvision.transforms import v2
from einops import rearrange
import lightning as pl
import pandas as pd
from diffsynth import WanVideoReCamMasterPipeline, ModelManager, load_state_dict
import torchvision
from PIL import Image
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


class TextVideoDataset(torch.utils.data.Dataset):
    """Dataset of (text, video) pairs read from a metadata CSV.

    The CSV must provide a ``file_name`` and a ``text`` column; files are
    resolved as ``<base_path>/train/<file_name>``.
    """

    def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False):
        """Read the metadata and build the per-frame transform.

        Args:
            base_path: Dataset root; media files live under ``<base_path>/train``.
            metadata_path: CSV file with ``file_name`` and ``text`` columns.
            max_num_frames: Minimum number of frames a video must contain to
                be accepted (see ``load_frames_using_imageio``).
            frame_interval: Step between two sampled frames.
            num_frames: Number of frames sampled from each video.
            height: Output frame height.
            width: Output frame width.
            is_i2v: When true, samples also carry the first frame as an array.
        """
        metadata = pd.read_csv(metadata_path)
        self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()

        self.max_num_frames = max_num_frames
        self.frame_interval = frame_interval
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.is_i2v = is_i2v

        self.frame_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.Resize(size=(height, width), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Resize ``image`` so that it covers ``(self.height, self.width)``.

        The aspect ratio is kept; the larger of the two scale factors is used,
        so the result is at least as large as the target in both dimensions.
        The actual cropping is done later by ``self.frame_process``.
        """
        width, height = image.size
        scale = max(self.width / width, self.height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        return image

    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        """Sample ``num_frames`` frames from a video file.

        Args:
            file_path: Path of the video.
            max_num_frames: Videos with fewer frames than this are rejected.
            start_frame_id: Index of the first sampled frame.
            interval: Step between two sampled frames.
            num_frames: Number of frames to sample.
            frame_process: Callable applied to every resized PIL frame.

        Returns:
            ``None`` when the video is too short. Otherwise the stacked frames
            rearranged to ``C T H W``; in i2v mode a ``(frames, first_frame)``
            tuple where ``first_frame`` is the first resized frame as a numpy
            array (captured before ``frame_process`` is applied).
        """
        reader = imageio.get_reader(file_path)
        # try/finally so the reader is released even when decoding fails.
        try:
            total_frames = reader.count_frames()
            last_frame_id = start_frame_id + (num_frames - 1) * interval
            if total_frames < max_num_frames or total_frames - 1 < last_frame_id:
                return None

            frames = []
            first_frame = None
            for frame_id in range(num_frames):
                frame = reader.get_data(start_frame_id + frame_id * interval)
                frame = Image.fromarray(frame)
                frame = self.crop_and_resize(frame)
                if first_frame is None:
                    first_frame = np.array(frame)
                frames.append(frame_process(frame))
        finally:
            reader.close()

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")

        if self.is_i2v:
            return frames, first_frame
        return frames

    def load_video(self, file_path):
        """Load a video starting at frame 0 using the dataset's sampling settings."""
        start_frame_id = 0
        frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
        return frames

    def is_image(self, file_path):
        """Return True when the file extension is one of jpg/jpeg/png/webp."""
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]

    def load_image(self, file_path):
        """Load a still image as a one-frame video laid out as ``C 1 H W``."""
        frame = Image.open(file_path).convert("RGB")
        frame = self.crop_and_resize(frame)
        frame = self.frame_process(frame)
        frame = rearrange(frame, "C H W -> C 1 H W")
        return frame

    def __getitem__(self, data_id):
        """Return the sample at ``data_id``, falling forward past broken samples.

        A sample that cannot be loaded (unreadable file, video that is too
        short, image in i2v mode) is reported and the next index is tried,
        wrapping around at the end of the dataset.

        Returns:
            Dict with ``text``, ``video`` and ``path``; in i2v mode also
            ``first_frame``.

        Raises:
            IndexError: If ``data_id`` is out of range.
            RuntimeError: If no sample in the whole dataset can be loaded.
        """
        num_samples = len(self.path)
        for _ in range(num_samples):
            # Looked up outside the try so an out-of-range index is not swallowed.
            text = self.text[data_id]
            path = self.path[data_id]
            try:
                if self.is_image(path):
                    if self.is_i2v:
                        raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
                    video = self.load_image(path)
                else:
                    video = self.load_video(path)
                    if video is None:
                        # load_video returns None for videos that are too short.
                        raise ValueError(f"{path} does not contain enough frames.")
                if self.is_i2v:
                    video, first_frame = video
                    return {"text": text, "video": video, "path": path, "first_frame": first_frame}
                return {"text": text, "video": video, "path": path}
            except Exception as e:
                print(f"ERROR WHEN LOADING: {e}")
                data_id = (data_id + 1) % num_samples
        raise RuntimeError("No loadable sample found in the dataset.")

    def __len__(self):
        """Number of entries listed in the metadata file."""
        return len(self.path)


class LightningModelForDataProcess(pl.LightningModule):
    """Lightning module that encodes videos and caches the results on disk."""

    def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
        """Load the encoders and build the pipeline.

        Args:
            text_encoder_path: Path of the text encoder weights.
            vae_path: Path of the VAE weights.
            image_encoder_path: Optional path of the image encoder weights.
            tiled: Forwarded to ``encode_video`` as ``tiled``.
            tile_size: Forwarded to ``encode_video`` as ``tile_size``.
            tile_stride: Forwarded to ``encode_video`` as ``tile_stride``.
        """
        super().__init__()
        model_path = [text_encoder_path, vae_path]
        if image_encoder_path is not None:
            model_path.append(image_encoder_path)
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models(model_path)
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)

        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

    def test_step(self, batch, batch_idx):
        """Encode one sample and save it to ``<path>.tensors.pth``.

        Only the first element of the batch is used for ``text``, ``path`` and
        ``first_frame`` (``data_process`` builds its DataLoader with
        ``batch_size=1``). Samples whose cache file already exists are skipped.
        """
        text, video, path = batch["text"][0], batch["video"], batch["path"][0]

        self.pipe.device = self.device
        if video is not None:
            pth_path = path + ".tensors.pth"
            if not os.path.exists(pth_path):
                # prompt
                prompt_emb = self.pipe.encode_prompt(text)
                # video
                video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
                latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
                # image
                if "first_frame" in batch:
                    first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy())
                    _, _, num_frames, height, width = video.shape
                    image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
                else:
                    image_emb = {}
                data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
                torch.save(data, pth_path)
            else:
                print(f"File {pth_path} already exists, skipping.")


class Camera(object):
    """Camera pose holding a 4x4 camera-to-world matrix and its inverse."""

    def __init__(self, c2w):
        """Store ``c2w`` reshaped to 4x4 and its inverse as world-to-camera."""
        c2w_mat = np.array(c2w).reshape(4, 4)
        self.c2w_mat = c2w_mat
        self.w2c_mat = np.linalg.inv(c2w_mat)


class TensorDataset(torch.utils.data.Dataset):
    """Dataset of cached ``.tensors.pth`` files paired with camera trajectories."""

    def __init__(self, base_path, metadata_path, steps_per_epoch):
        """Collect the cached tensor files that exist on disk.

        Args:
            base_path: Dataset root; files live under ``<base_path>/train``.
            metadata_path: CSV file with a ``file_name`` column.
            steps_per_epoch: Value reported by ``__len__``.
        """
        metadata = pd.read_csv(metadata_path)
        self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
        print(len(self.path), "videos in metadata.")
        self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
        print(len(self.path), "tensors cached in metadata.")
        assert len(self.path) > 0, "No cached .tensors.pth files found; run the data_process task first."
        self.steps_per_epoch = steps_per_epoch

    def parse_matrix(self, matrix_str):
        """Parse a string such as ``"[1 2] [3 4]"`` into a 2-D numpy array.

        Rows are separated by ``"] ["``; values inside a row by whitespace.
        """
        rows = matrix_str.strip().split('] [')
        matrix = []
        for row in rows:
            row = row.replace('[', '').replace(']', '')
            matrix.append(list(map(float, row.split())))
        return np.array(matrix)

    def get_relative_pose(self, cam_params):
        """Express camera poses relative to the first camera in ``cam_params``.

        Returns:
            float32 array of shape ``(len(cam_params), 4, 4)``. Entry 0 is the
            identity; entry ``i`` is ``inv(c2w_0) @ c2w_i``.
        """
        abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
        abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
        # The reference camera is placed at the origin of the relative frame.
        target_cam_c2w = np.eye(4)
        abs2rel = target_cam_c2w @ abs_w2cs[0]
        ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
        ret_poses = np.array(ret_poses, dtype=np.float32)
        return ret_poses

    def _load_camera_embedding(self, path_tgt, cond_idx, tgt_idx):
        """Build the flattened relative-pose embedding for a target view.

        The extrinsics are read from ``cameras/camera_extrinsics.json`` located
        two path components above ``path_tgt``. Every target pose is expressed
        relative to the first sampled pose of the condition view, and the top
        three rows of each 4x4 pose are flattened to 12 values.

        Returns:
            bfloat16 tensor with one 12-value row per sampled frame.
        """
        base_path = path_tgt.rsplit('/', 2)[0]
        tgt_camera_path = os.path.join(base_path, "cameras", "camera_extrinsics.json")
        with open(tgt_camera_path, 'r') as file:
            cam_data = json.load(file)

        # Every 4th of 81 frames -> 21 poses.
        cam_idx = range(0, 81, 4)
        multiview_c2ws = []
        for view_idx in [cond_idx, tgt_idx]:
            traj = [self.parse_matrix(cam_data[f"frame{idx}"][f"cam{view_idx:02d}"]) for idx in cam_idx]
            traj = np.stack(traj).transpose(0, 2, 1)
            c2ws = []
            for c2w in traj:
                # Fancy indexing returns a copy, so the in-place edits below
                # do not touch `traj`.
                # NOTE(review): the column permutation, sign flip and /100
                # look like a source-convention and unit conversion; confirm
                # against the dataset's documentation.
                c2w = c2w[:, [1, 2, 0, 3]]
                c2w[:3, 1] *= -1.
                c2w[:3, 3] /= 100
                c2ws.append(c2w)
            multiview_c2ws.append(c2ws)
        cond_cam_params = [Camera(cam_param) for cam_param in multiview_c2ws[0]]
        tgt_cam_params = [Camera(cam_param) for cam_param in multiview_c2ws[1]]

        relative_poses = []
        for tgt_cam in tgt_cam_params:
            relative_pose = self.get_relative_pose([cond_cam_params[0], tgt_cam])
            relative_poses.append(torch.as_tensor(relative_pose)[:, :3, :][1])
        pose_embedding = torch.stack(relative_poses, dim=0)  # 21x3x4
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
        return pose_embedding.to(torch.bfloat16)

    def _load_sample(self, index):
        """Load one (target, condition) training sample; raises on any failure."""
        data_id = torch.randint(0, len(self.path), (1,))[0]
        data_id = int((data_id + index) % len(self.path))  # For fixed seed.
        path_tgt = self.path[data_id]
        data_tgt = torch.load(path_tgt, weights_only=True, map_location="cpu")

        # load the condition latent: same path, different `cam<N>` index
        match = re.search(r'cam(\d+)', path_tgt)
        if match is None:
            raise ValueError(f"No camera index (cam<N>) found in {path_tgt}")
        tgt_idx = int(match.group(1))
        # Condition view is drawn from cameras 1..10, excluding the target.
        cond_idx = random.randint(1, 10)
        while cond_idx == tgt_idx:
            cond_idx = random.randint(1, 10)
        path_cond = re.sub(r'cam(\d+)', f'cam{cond_idx:02}', path_tgt)
        data_cond = torch.load(path_cond, weights_only=True, map_location="cpu")

        data = {}
        # Target latents first, condition latents second along dim 1.
        data['latents'] = torch.cat((data_tgt['latents'], data_cond['latents']), dim=1)
        data['prompt_emb'] = data_tgt['prompt_emb']
        data['image_emb'] = {}

        # load the target trajectory
        data['camera'] = self._load_camera_embedding(path_tgt, cond_idx, tgt_idx)
        return data

    def __getitem__(self, index):
        """Return a training sample, retrying with a random index on failure.

        Example shapes (carried over from the original comment, not checked
        here):
            data['latents']: torch.Size([16, 21*2, 60, 104])
            data['camera']: torch.Size([21, 12])
            data['prompt_emb']["context"][0]: torch.Size([512, 4096])
        """
        while True:
            try:
                return self._load_sample(index)
            except Exception as e:
                print(f"ERROR WHEN LOADING: {e}")
                index = random.randrange(len(self.path))

    def __len__(self):
        """Number of steps per epoch (independent of the number of cached files)."""
        return self.steps_per_epoch


class LightningModelForTrain(pl.LightningModule):
    """Lightning module that trains the camera-conditioned DiT."""

    def __init__(
        self,
        dit_path,
        learning_rate=1e-5,
        use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
        resume_ckpt_path=None
    ):
        """Load the DiT, attach camera modules and select trainable parameters.

        Args:
            dit_path: Path of the DiT weights, or a comma-separated list of
                paths when it is not a single existing file.
            learning_rate: Learning rate for AdamW.
            use_gradient_checkpointing: Forwarded to the denoising model.
            use_gradient_checkpointing_offload: Forwarded to the denoising model.
            resume_ckpt_path: Optional DiT state dict to resume from; loaded
                with ``strict=True`` after the camera modules are attached.
        """
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        if os.path.isfile(dit_path):
            model_manager.load_models([dit_path])
        else:
            # NOTE(review): the split paths are passed as ONE nested list
            # entry, presumably the shards of a single model; confirm against
            # ModelManager.load_models.
            dit_path = dit_path.split(",")
            model_manager.load_models([dit_path])

        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.pipe.scheduler.set_timesteps(1000, training=True)

        dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]
        for block in self.pipe.dit.blocks:
            # Camera encoder starts at zero and projector at identity.
            block.cam_encoder = nn.Linear(12, dim)
            block.projector = nn.Linear(dim, dim)
            block.cam_encoder.weight.data.zero_()
            block.cam_encoder.bias.data.zero_()
            block.projector.weight = nn.Parameter(torch.eye(dim))
            block.projector.bias = nn.Parameter(torch.zeros(dim))

        if resume_ckpt_path is not None:
            # The checkpoint is a plain state dict (see on_save_checkpoint),
            # so it is loaded with weights_only=True like the cached tensors.
            state_dict = torch.load(resume_ckpt_path, weights_only=True, map_location="cpu")
            self.pipe.dit.load_state_dict(state_dict, strict=True)

        self.freeze_parameters()
        # Unfreeze every module whose name mentions one of the keywords.
        for name, module in self.pipe.denoising_model().named_modules():
            if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn"]):
                print(f"Trainable: {name}")
                for param in module.parameters():
                    param.requires_grad = True

        # Module.parameters() yields each parameter once, so no manual
        # de-duplication is needed.
        trainable_params = sum(
            param.numel()
            for param in self.pipe.denoising_model().parameters()
            if param.requires_grad
        )
        print(f"Total number of trainable parameters: {trainable_params}")

        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

    def freeze_parameters(self):
        """Freeze the whole pipeline and put only the denoising model in train mode."""
        # Freeze parameters
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        self.pipe.denoising_model().train()

    def training_step(self, batch, batch_idx):
        """Compute the loss on the target half of the latents.

        Along dim 2 the first half of ``latents`` is the target view and the
        second half the condition view (see ``TensorDataset``). Noise is added
        to everything, the condition half is then restored to its clean
        values, and the MSE is taken on the target half only.
        """
        # Data
        latents = batch["latents"].to(self.device)
        prompt_emb = batch["prompt_emb"]
        prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
        image_emb = batch["image_emb"]

        if "clip_feature" in image_emb:
            image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
        if "y" in image_emb:
            image_emb["y"] = image_emb["y"][0].to(self.device)

        cam_emb = batch["camera"].to(self.device)

        # Loss
        self.pipe.device = self.device
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
        extra_input = self.pipe.prepare_extra_input(latents)
        # Clean copy taken before noising, used to restore the condition half.
        origin_latents = latents.clone()
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        tgt_latent_len = noisy_latents.shape[2] // 2
        noisy_latents[:, :, tgt_latent_len:, ...] = origin_latents[:, :, tgt_latent_len:, ...]
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # Compute loss
        noise_pred = self.pipe.denoising_model()(
            noisy_latents, timestep=timestep, cam_emb=cam_emb, **prompt_emb, **extra_input, **image_emb,
            use_gradient_checkpointing=self.use_gradient_checkpointing,
            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
        )
        loss = torch.nn.functional.mse_loss(noise_pred[:, :, :tgt_latent_len, ...].float(), training_target[:, :, :tgt_latent_len, ...].float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)

        # Record log
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over the parameters with ``requires_grad``."""
        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters())
        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
        return optimizer

    def on_save_checkpoint(self, checkpoint):
        """Save the full denoising-model state dict as ``step<N>.ckpt``.

        The Lightning checkpoint dict is emptied; the complete state dict is
        written instead because resuming loads it with ``strict=True``.
        """
        checkpoint_dir = self.trainer.checkpoint_callback.dirpath
        print(f"Checkpoint directory: {checkpoint_dir}")
        current_step = self.global_step
        print(f"Current step: {current_step}")

        checkpoint.clear()
        state_dict = self.pipe.denoising_model().state_dict()
        # The callback's directory may not exist yet when this hook runs.
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}.ckpt"))


def parse_args():
    """Parse the command-line arguments of the script."""
    parser = argparse.ArgumentParser(description="Train ReCamMaster")
    parser.add_argument(
        "--task",
        type=str,
        default="data_process",
        required=True,
        choices=["data_process", "train"],
        help="Task. `data_process` or `train`.",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default=None,
        required=True,
        help="The path of the Dataset.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./",
        help="Path to save the model.",
    )
    parser.add_argument(
        "--text_encoder_path",
        type=str,
        default=None,
        help="Path of text encoder.",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        default=None,
        help="Path of image encoder.",
    )
    parser.add_argument(
        "--vae_path",
        type=str,
        default=None,
        help="Path of VAE.",
    )
    parser.add_argument(
        "--dit_path",
        type=str,
        default=None,
        help="Path of DiT.",
    )
    parser.add_argument(
        "--tiled",
        default=False,
        action="store_true",
        help="Whether enable tile encode in VAE. This option can reduce VRAM required.",
    )
    parser.add_argument(
        "--tile_size_height",
        type=int,
        default=34,
        help="Tile size (height) in VAE.",
    )
    parser.add_argument(
        "--tile_size_width",
        type=int,
        default=34,
        help="Tile size (width) in VAE.",
    )
    parser.add_argument(
        "--tile_stride_height",
        type=int,
        default=18,
        help="Tile stride (height) in VAE.",
    )
    parser.add_argument(
        "--tile_stride_width",
        type=int,
        default=16,
        help="Tile stride (width) in VAE.",
    )
    parser.add_argument(
        "--steps_per_epoch",
        type=int,
        default=500,
        help="Number of steps per epoch.",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=81,
        help="Number of frames.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="Image height.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=832,
        help="Image width.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=4,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-5,
        help="Learning rate.",
    )
    parser.add_argument(
        "--accumulate_grad_batches",
        type=int,
        default=1,
        help="The number of batches in gradient accumulation.",
    )
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=1,
        help="Number of epochs.",
    )
    parser.add_argument(
        "--training_strategy",
        type=str,
        default="deepspeed_stage_1",
        choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
        help="Training strategy",
    )
    parser.add_argument(
        "--use_gradient_checkpointing",
        default=False,
        action="store_true",
        help="Whether to use gradient checkpointing.",
    )
    parser.add_argument(
        "--use_gradient_checkpointing_offload",
        default=False,
        action="store_true",
        help="Whether to use gradient checkpointing offload.",
    )
    parser.add_argument(
        "--use_swanlab",
        default=False,
        action="store_true",
        help="Whether to use SwanLab logger.",
    )
    parser.add_argument(
        "--swanlab_mode",
        default=None,
        help="SwanLab mode (cloud or local).",
    )
    parser.add_argument(
        "--metadata_file_name",
        type=str,
        default="metadata.csv",
    )
    parser.add_argument(
        "--resume_ckpt_path",
        type=str,
        default=None,
    )
    args = parser.parse_args()
    return args


def data_process(args):
    """Encode every video of the dataset and cache the tensors on disk."""
    dataset = TextVideoDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, args.metadata_file_name),
        max_num_frames=args.num_frames,
        frame_interval=1,
        num_frames=args.num_frames,
        height=args.height,
        width=args.width,
        is_i2v=args.image_encoder_path is not None
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )
    model = LightningModelForDataProcess(
        text_encoder_path=args.text_encoder_path,
        image_encoder_path=args.image_encoder_path,
        vae_path=args.vae_path,
        tiled=args.tiled,
        tile_size=(args.tile_size_height, args.tile_size_width),
        tile_stride=(args.tile_stride_height, args.tile_stride_width),
    )
    trainer = pl.Trainer(
        accelerator="gpu",
        devices="auto",
        default_root_dir=args.output_path,
    )
    trainer.test(model, dataloader)


def train(args):
    """Train the model on the cached tensors."""
    dataset = TensorDataset(
        args.dataset_path,
        # Same metadata file as data_process; the default is "metadata.csv".
        os.path.join(args.dataset_path, args.metadata_file_name),
        steps_per_epoch=args.steps_per_epoch,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )
    model = LightningModelForTrain(
        dit_path=args.dit_path,
        learning_rate=args.learning_rate,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        resume_ckpt_path=args.resume_ckpt_path,
    )

    if args.use_swanlab:
        from swanlab.integration.pytorch_lightning import SwanLabLogger
        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
        swanlab_config.update(vars(args))
        swanlab_logger = SwanLabLogger(
            project="wan",
            name="wan",
            config=swanlab_config,
            mode=args.swanlab_mode,
            logdir=os.path.join(args.output_path, "swanlog"),
        )
        logger = [swanlab_logger]
    else:
        logger = None
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision="bf16",
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
        logger=logger,
    )
    trainer.fit(model, dataloader)


if __name__ == '__main__':
    args = parse_args()
    os.makedirs(os.path.join(args.output_path, "checkpoints"), exist_ok=True)
    if args.task == "data_process":
        data_process(args)
    elif args.task == "train":
        train(args)