diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3cc8ad6
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+checkpoints/waveunet/
+.cog
diff --git a/README.md b/README.md
index de8fd80..c2dc947 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
 # Wave-U-Net (Pytorch)
-
+
 Improved version of the [Wave-U-Net](https://arxiv.org/abs/1806.03185) for audio source separation, implemented in Pytorch.
 
@@ -97,6 +97,10 @@ We provide the default model in a pre-trained form as download so you can separa
 
 Download our pretrained model [here](https://www.dropbox.com/s/r374hce896g4xlj/models.7z?dl=1). Extract the archive into the ``checkpoints`` subfolder in this repository, so that you have one subfolder for each model (e.g. ``REPO/checkpoints/waveunet``)
 
+If you have Docker installed, you can run this script to download the weights from [Replicate](https://replicate.com/f90/wave-u-net-pytorch):
+
+    $ script/download-weights
+
 ## Run pretrained model
 
 To apply our pretrained model to any of your own songs, simply point to its audio file path using the ``input_path`` parameter:
diff --git a/cog.yaml b/cog.yaml
index 83c6d2b..4e2bbf3 100644
--- a/cog.yaml
+++ b/cog.yaml
@@ -1,5 +1,5 @@
 build:
-  python_version: "3.6"
+  python_version: "3.7"
   gpu: false
   python_packages:
     - future==0.18.2
@@ -17,4 +17,4 @@ build:
   system_packages:
     - libsndfile-dev
     - ffmpeg
-predict: "cog_predict.py:waveunetPredictor"
+predict: "cog_predict.py:Predictor"
diff --git a/cog_predict.py b/cog_predict.py
index 483d3b4..a4b6fd9 100644
--- a/cog_predict.py
+++ b/cog_predict.py
@@ -1,16 +1,24 @@
+import argparse
 import os
-import cog
 import tempfile
 import zipfile
-from pathlib import Path
-import argparse
+
+from cog import BasePredictor, Input, Path, BaseModel
+
 import data.utils
 import model.utils as model_utils
-from test import predict_song
 from model.waveunet import Waveunet
+from test import predict_song
+
+
+class Output(BaseModel):
+    bass: Path
+    drums: Path
+    other: Path
+    vocals: Path
 
 
-class waveunetPredictor(cog.Predictor):
+class Predictor(BasePredictor):
     def setup(self):
         """Init wave u net model"""
         parser = argparse.ArgumentParser()
@@ -112,7 +120,7 @@ def setup(self):
         )
 
         if args.cuda:
-            self.model = model_utils.DataParallel(model)
+            self.model = model_utils.DataParallel(self.model)
             print("move model to gpu")
             self.model.cuda()
 
@@ -120,25 +128,16 @@ def setup(self):
 
         state = model_utils.load_model(self.model, None, args.load_model, args.cuda)
         print("Step", state["step"])
 
-    @cog.input("input", type=Path, help="audio mixture path")
-    def predict(self, input):
+    def predict(self, mix: Path = Input(description="audio mixture path")) -> Output:
         """Separate tracks from input mixture audio"""
-        out_path = Path(tempfile.mkdtemp())
-        zip_path = Path(tempfile.mkdtemp()) / "output.zip"
+        tmpdir = Path(tempfile.mkdtemp())
 
-        preds = predict_song(self.args, input, self.model)
-
-        out_names = []
+        preds = predict_song(self.args, mix, self.model)
+        output = {}
         for inst in preds.keys():
-            temp_n = os.path.join(
-                str(out_path), os.path.basename(str(input)) + "_" + inst + ".wav"
-            )
-            data.utils.write_wav(temp_n, preds[inst], self.args.sr)
-            out_names.append(temp_n)
-
-        with zipfile.ZipFile(str(zip_path), "w") as zf:
-            for i in out_names:
-                zf.write(str(i))
+            path = tmpdir / (inst + ".wav")
+            data.utils.write_wav(path, preds[inst], self.args.sr)
+            output[inst] = path
 
-        return zip_path
+        return Output(**output)
diff --git a/script/download-weights b/script/download-weights
new file mode 100755
index 0000000..4a10162
--- /dev/null
+++ b/script/download-weights
@@ -0,0 +1,3 @@
+#!/bin/bash
+id=$(docker create r8.im/f90/wave-u-net-pytorch)
+docker cp $id:/src/checkpoints ./
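
With this patch applied, the predictor can be exercised locally roughly as follows (a sketch, assuming the Cog CLI is installed, the commands are run from the repository root, and `song.mp3` stands in for any mixture file; `mix` matches the `Input` parameter name in `cog_predict.py`):

    $ script/download-weights
    $ cog predict -i mix=@song.mp3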