diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3cc8ad6
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+checkpoints/waveunet/
+.cog
diff --git a/README.md b/README.md
index de8fd80..c2dc947 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
# Wave-U-Net (Pytorch)
-
+
Improved version of the [Wave-U-Net](https://arxiv.org/abs/1806.03185) for audio source separation, implemented in Pytorch.
@@ -97,6 +97,10 @@ We provide the default model in a pre-trained form as download so you can separa
Download our pretrained model [here](https://www.dropbox.com/s/r374hce896g4xlj/models.7z?dl=1).
Extract the archive into the ``checkpoints`` subfolder in this repository, so that you have one subfolder for each model (e.g. ``REPO/checkpoints/waveunet``)
+Alternatively, if you have Docker installed, you can run this script to download the weights from [Replicate](https://replicate.com/f90/wave-u-net-pytorch):
+
+ $ script/download-weights
+
## Run pretrained model
To apply our pretrained model to any of your own songs, simply point to its audio file path using the ``input_path`` parameter:
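(Note on the added download instructions: a sketch of what ``script/download-weights`` should leave behind once it has run; the ``waveunet`` folder name is taken from the ``.gitignore`` entry and the README text above, nothing else is implied.)

    $ script/download-weights
    $ ls checkpoints
    waveunet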
diff --git a/cog.yaml b/cog.yaml
index 83c6d2b..4e2bbf3 100644
--- a/cog.yaml
+++ b/cog.yaml
@@ -1,5 +1,5 @@
build:
- python_version: "3.6"
+ python_version: "3.7"
gpu: false
python_packages:
- future==0.18.2
@@ -17,4 +17,4 @@ build:
system_packages:
- libsndfile-dev
- ffmpeg
-predict: "cog_predict.py:waveunetPredictor"
+predict: "cog_predict.py:Predictor"
diff --git a/cog_predict.py b/cog_predict.py
index 483d3b4..a4b6fd9 100644
--- a/cog_predict.py
+++ b/cog_predict.py
@@ -1,16 +1,24 @@
+import argparse
import os
-import cog
import tempfile
import zipfile
-from pathlib import Path
-import argparse
+
+from cog import BasePredictor, Input, Path, BaseModel
+
import data.utils
import model.utils as model_utils
-from test import predict_song
from model.waveunet import Waveunet
+from test import predict_song
+
+
+class Output(BaseModel):
+ bass: Path
+ drums: Path
+ other: Path
+ vocals: Path
-class waveunetPredictor(cog.Predictor):
+class Predictor(BasePredictor):
def setup(self):
"""Init wave u net model"""
parser = argparse.ArgumentParser()
@@ -112,7 +120,7 @@ def setup(self):
)
if args.cuda:
- self.model = model_utils.DataParallel(model)
+ self.model = model_utils.DataParallel(self.model)
print("move model to gpu")
self.model.cuda()
@@ -120,25 +128,16 @@ def setup(self):
state = model_utils.load_model(self.model, None, args.load_model, args.cuda)
print("Step", state["step"])
- @cog.input("input", type=Path, help="audio mixture path")
- def predict(self, input):
+ def predict(self, mix: Path = Input(description="audio mixture path")) -> Output:
"""Separate tracks from input mixture audio"""
- out_path = Path(tempfile.mkdtemp())
- zip_path = Path(tempfile.mkdtemp()) / "output.zip"
+ tmpdir = Path(tempfile.mkdtemp())
- preds = predict_song(self.args, input, self.model)
-
- out_names = []
+ preds = predict_song(self.args, mix, self.model)
+ output = {}
for inst in preds.keys():
- temp_n = os.path.join(
- str(out_path), os.path.basename(str(input)) + "_" + inst + ".wav"
- )
- data.utils.write_wav(temp_n, preds[inst], self.args.sr)
- out_names.append(temp_n)
-
- with zipfile.ZipFile(str(zip_path), "w") as zf:
- for i in out_names:
- zf.write(str(i))
+ path = tmpdir / (inst + ".wav")
+ data.utils.write_wav(path, preds[inst], self.args.sr)
+ output[inst] = path
- return zip_path
+ return Output(**output)
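(Note on the reworked predictor: with the move to Cog's typed ``BasePredictor``/``Input``/``Output`` API, a local test run should look roughly like the sketch below. It assumes the standard Cog CLI and an arbitrary input file ``song.mp3``; the input name ``mix`` and the four stems follow the new ``predict`` signature and ``Output`` model above.)

    $ cog predict -i mix=@song.mp3
    # bass, drums, other and vocals are returned as separate .wav files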
diff --git a/script/download-weights b/script/download-weights
new file mode 100755
index 0000000..4a10162
--- /dev/null
+++ b/script/download-weights
@@ -0,0 +1,6 @@
+#!/bin/bash
+# Copy the pretrained Wave-U-Net checkpoints out of the published Replicate image.
+set -euo pipefail
+id=$(docker create r8.im/f90/wave-u-net-pytorch)
+docker cp "$id":/src/checkpoints ./
+docker rm "$id" > /dev/null