AI Choreographer多模式内容创建模型训练基础设施

联合创作 · 2023-09-29 19:08

这个包包含 AI Choreographer 的模型实现和训练基础设施,包括 FACT 模型实现。


拉取代码



git clone https://github.com/liruilong940607/mint --recursive

注意这里 --recursive 很重要,因为它也会自动克隆子模块。 


安装依赖




conda create -n mint python=3.7
conda activate mint
conda install protobuf numpy
pip install tensorflow absl-py tensorflow-datasets librosa

sudo apt-get install libopenexr-dev
pip install --upgrade OpenEXR
pip install tensorflow-graphics tensorflow-graphics-gpu

git clone https://github.com/arogozhnikov/einops /tmp/einops
cd /tmp/einops/ && pip install . -U

git clone https://github.com/google/aistplusplus_api /tmp/aistplusplus_api
cd /tmp/aistplusplus_api && pip install -r requirements.txt && pip install . -U

注意如果遇到 numpy 的环境冲突,可以试试 pip install numpy==1.20


获取数据


数据在该网站


运行代码



  • 编译协议



protoc ./mint/protos/*.proto


  • 将数据集预处理为 tfrecord 




python tools/preprocessing.py \
--anno_dir="/mnt/data/aist_plusplus_final/" \
--audio_dir="/mnt/data/AIST/music/" \
--split=train
python tools/preprocessing.py \
--anno_dir="/mnt/data/aist_plusplus_final/" \
--audio_dir="/mnt/data/AIST/music/" \
--split=testval


  • 训练




python trainer.py --config_path ./configs/fact_v5_deeper_t10_cm12.config --model_dir ./checkpoints


  • 运行测试和评估 




# caching the generated motions (seed included) to `./outputs`
python evaluator.py --config_path ./configs/fact_v5_deeper_t10_cm12.config --model_dir ./checkpoints
# calculate FIDs
python tools/calculate_scores.py


 


 

浏览 26
点赞
评论
收藏
分享

手机扫一扫分享

编辑 分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

编辑 分享
举报