算法训练 | 使用JAX训练CLIP算法_支持fine-tuning
原创
©著作权归作者所有:来自51CTO博客作者极智视界的原创作品,请联系作者获取转载授权,否则将追究法律责任
- 面向 CLIP 算法训练场景,项目采用 JAX 来实现 CLIP 算法的训练,支持 fine-tuning。
- 项目细节 ==> 具体参见项目
README.md
# clone and install datacomp
# download data
python download_upstream.py \
--scale small --data_dir gs://my_bucket/datacomp/small metadata_dir metadata \
--image_size 256 --resize_mode center_crop --skip_bbox_blurring --no_resize_only_if_bigger \
--encode_format webp --output_format tfrecord
python train.py \
--assert_TPU_available \
--config_name ../configs/small-patch16.json --dtype float32 \
--do_train --train_folder gs://my_bucket/datacomp/small/shards \
--output_dir gs://my_bucket/clip_model/$(date +"%Y%m%d%H%M%S") \
--num_train_epochs 10 \
--tokenizer_name openai/clip-vit-base-patch32 \
--batch_size_per_node 4096 --gradient_accumulation_steps 1 \
--learning_rate 0.00001 --warmup_steps 2000 --lr_offset 0 \
--optim distributed_shampoo --beta1 0.9 --beta2 0.99 --weight_decay 0.0 \
--block_size_text 512 --block_size_vision 512 --nesterov \
--graft_type rmsprop_normalized --preconditioning_compute_steps 20 \
--mp_devices 1 --shard_shampoo_across 2d \
--activation_partitioning_dims 1 --parameter_partitioning_dims 1 \
--loss_type sigmoid \
--gradient_checkpointing \
--unroll 100 \
--logging_steps 100 --save_steps 5000