Skip to content

Commit

Permalink
Fix new rust feature issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
twitchax committed Dec 8, 2023
1 parent 9de1c23 commit fa9bb97
Show file tree
Hide file tree
Showing 17 changed files with 41 additions and 45 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ jobs:
- uses: actions-rs/cargo@v1
with:
command: test
args: --features ml
codecov:
needs: test
name: Code Coverage
Expand Down
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ wasm = ["rodio/wasm-bindgen", "wasm-bindgen", "wasm-bindgen-futures", "js-sys",
plot = ["plotters"]

[dependencies]
proc-macro2 = "1.0.60"
async-trait = "0.1.64"
once_cell = "1.16.0"
paste = "1.0.9"
Expand Down
4 changes: 2 additions & 2 deletions src/analyze/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ pub fn get_audio_data_from_file(file: impl AsRef<Path>, start: Option<Duration>,
}

/// Play the given segment of an audio file. Used to preview a clip before guessing notes from it.
#[no_coverage]
#[coverage(off)]
pub fn preview_audio_file_clip(file: impl AsRef<Path>, start: Option<Duration>, end: Option<Duration>) -> Res<()> {
let file = File::open(file)?;
preview_audio_clip(file, start, end)
}

/// Play the given segment of an audio stream. Used to preview a clip before guessing notes from it.
#[no_coverage]
#[coverage(off)]

Check failure on line 53 in src/analyze/file.rs

View workflow job for this annotation

GitHub Actions / Build Linux

cannot find attribute `coverage` in this scope

Check failure on line 53 in src/analyze/file.rs

View workflow job for this annotation

GitHub Actions / Build Windows

cannot find attribute `coverage` in this scope

Check failure on line 53 in src/analyze/file.rs

View workflow job for this annotation

GitHub Actions / Build MacOS

cannot find attribute `coverage` in this scope
pub fn preview_audio_clip(stream: impl Read + Seek + Send + Sync + 'static, start: Option<Duration>, end: Option<Duration>) -> Res<()> {
let start = start.unwrap_or_default();
let decoder = Decoder::new(stream)?.skip_duration(start).convert_samples();
Expand Down
8 changes: 4 additions & 4 deletions src/analyze/mic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::core::{base::Res, note::Note};
use super::base::get_notes_from_audio_data;

/// Gets notes from the microphone input over the specified period of time.
#[no_coverage]
#[coverage(off)]
pub async fn get_notes_from_microphone(length_in_seconds: u8) -> Res<Vec<Note>> {
// Get data.

Expand All @@ -30,7 +30,7 @@ pub async fn get_notes_from_microphone(length_in_seconds: u8) -> Res<Vec<Note>>
}

/// Gets audio data from the microphone.
#[no_coverage]
#[coverage(off)]
pub async fn get_audio_data_from_microphone(length_in_seconds: u8) -> Res<Vec<f32>> {
if length_in_seconds < 1 {
return Err(anyhow::Error::msg("Listening length in seconds must be greater than 1."));
Expand All @@ -48,7 +48,7 @@ pub async fn get_audio_data_from_microphone(length_in_seconds: u8) -> Res<Vec<f3
}

/// Gets the system device, and config.
#[no_coverage]
#[coverage(off)]
fn get_device_and_config() -> Res<(cpal::Device, cpal::SupportedStreamConfig)> {
let host = cpal::default_host();

Expand All @@ -60,7 +60,7 @@ fn get_device_and_config() -> Res<(cpal::Device, cpal::SupportedStreamConfig)> {
}

/// Records audio data from the device.
#[no_coverage]
#[coverage(off)]
async fn record_from_device(device: cpal::Device, config: cpal::SupportedStreamConfig, length_in_seconds: u8) -> Res<Vec<f32>> {
// Set up recording.

Expand Down
2 changes: 1 addition & 1 deletion src/core/chord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,7 @@ use super::base::{Playable, PlaybackHandle};

#[cfg(feature = "audio")]
impl Playable for Chord {
#[no_coverage]
#[coverage(off)]

Check failure on line 1233 in src/core/chord.rs

View workflow job for this annotation

GitHub Actions / Build Linux

cannot find attribute `coverage` in this scope

Check failure on line 1233 in src/core/chord.rs

View workflow job for this annotation

GitHub Actions / Build Windows

cannot find attribute `coverage` in this scope

Check failure on line 1233 in src/core/chord.rs

View workflow job for this annotation

GitHub Actions / Build MacOS

cannot find attribute `coverage` in this scope
fn play(&self, delay: Duration, length: Duration, fade_in: Duration) -> Res<PlaybackHandle> {
use rodio::{source::SineWave, OutputStream, Sink, Source};

Expand Down
6 changes: 3 additions & 3 deletions src/core/modifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ impl HasIsDominant for Modifier {
}

impl HasStaticName for Degree {
#[no_coverage]
#[coverage(off)]

Check failure on line 120 in src/core/modifier.rs

View workflow job for this annotation

GitHub Actions / Build Linux

cannot find attribute `coverage` in this scope

Check failure on line 120 in src/core/modifier.rs

View workflow job for this annotation

GitHub Actions / Build Windows

cannot find attribute `coverage` in this scope

Check failure on line 120 in src/core/modifier.rs

View workflow job for this annotation

GitHub Actions / Build MacOS

cannot find attribute `coverage` in this scope
fn static_name(&self) -> &'static str {
match self {
Degree::Seven => "7",
Expand All @@ -129,7 +129,7 @@ impl HasStaticName for Degree {
}

impl HasStaticName for Modifier {
#[no_coverage]
#[coverage(off)]

Check failure on line 132 in src/core/modifier.rs

View workflow job for this annotation

GitHub Actions / Build Linux

cannot find attribute `coverage` in this scope

Check failure on line 132 in src/core/modifier.rs

View workflow job for this annotation

GitHub Actions / Build Windows

cannot find attribute `coverage` in this scope

Check failure on line 132 in src/core/modifier.rs

View workflow job for this annotation

GitHub Actions / Build MacOS

cannot find attribute `coverage` in this scope
fn static_name(&self) -> &'static str {
match self {
Modifier::Minor => "m",
Expand All @@ -151,7 +151,7 @@ impl HasStaticName for Modifier {
}

impl HasStaticName for Extension {
#[no_coverage]
#[coverage(off)]

Check failure on line 154 in src/core/modifier.rs

View workflow job for this annotation

GitHub Actions / Build Linux

cannot find attribute `coverage` in this scope

Check failure on line 154 in src/core/modifier.rs

View workflow job for this annotation

GitHub Actions / Build Windows

cannot find attribute `coverage` in this scope

Check failure on line 154 in src/core/modifier.rs

View workflow job for this annotation

GitHub Actions / Build MacOS

cannot find attribute `coverage` in this scope
fn static_name(&self) -> &'static str {
match self {
Extension::Sus2 => "sus2",
Expand Down
6 changes: 3 additions & 3 deletions src/core/named_pitch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ impl HasNamedPitch for NamedPitch {
}

impl HasLetter for NamedPitch {
#[no_coverage]
#[coverage(off)]

Check failure on line 153 in src/core/named_pitch.rs

View workflow job for this annotation

GitHub Actions / Build Linux

cannot find attribute `coverage` in this scope

Check failure on line 153 in src/core/named_pitch.rs

View workflow job for this annotation

GitHub Actions / Build Windows

cannot find attribute `coverage` in this scope

Check failure on line 153 in src/core/named_pitch.rs

View workflow job for this annotation

GitHub Actions / Build MacOS

cannot find attribute `coverage` in this scope
fn letter(&self) -> &'static str {
match self {
NamedPitch::FTripleFlat => "F",
Expand Down Expand Up @@ -213,7 +213,7 @@ impl HasLetter for NamedPitch {
}

impl HasStaticName for NamedPitch {
#[no_coverage]
#[coverage(off)]

Check failure on line 216 in src/core/named_pitch.rs

View workflow job for this annotation

GitHub Actions / Build Linux

cannot find attribute `coverage` in this scope

Check failure on line 216 in src/core/named_pitch.rs

View workflow job for this annotation

GitHub Actions / Build Windows

cannot find attribute `coverage` in this scope

Check failure on line 216 in src/core/named_pitch.rs

View workflow job for this annotation

GitHub Actions / Build MacOS

cannot find attribute `coverage` in this scope
fn static_name(&self) -> &'static str {
match self {
NamedPitch::FTripleFlat => "F♭𝄫",
Expand Down Expand Up @@ -276,7 +276,7 @@ impl HasStaticName for NamedPitch {
}

impl HasPitch for NamedPitch {
#[no_coverage]
#[coverage(off)]

Check failure on line 279 in src/core/named_pitch.rs

View workflow job for this annotation

GitHub Actions / Build Linux

cannot find attribute `coverage` in this scope

Check failure on line 279 in src/core/named_pitch.rs

View workflow job for this annotation

GitHub Actions / Build Windows

cannot find attribute `coverage` in this scope

Check failure on line 279 in src/core/named_pitch.rs

View workflow job for this annotation

GitHub Actions / Build MacOS

cannot find attribute `coverage` in this scope
fn pitch(&self) -> Pitch {
match self {
NamedPitch::FTripleFlat => Pitch::D,
Expand Down
4 changes: 2 additions & 2 deletions src/core/note.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ impl Note {
/// to identify the notes in the recorded audio.
///
/// Currently, this does not work with WASM.
#[no_coverage]
#[coverage(off)]

Check failure on line 194 in src/core/note.rs

View workflow job for this annotation

GitHub Actions / Build Linux

cannot find attribute `coverage` in this scope

Check failure on line 194 in src/core/note.rs

View workflow job for this annotation

GitHub Actions / Build Windows

cannot find attribute `coverage` in this scope

Check failure on line 194 in src/core/note.rs

View workflow job for this annotation

GitHub Actions / Build MacOS

cannot find attribute `coverage` in this scope
#[cfg(feature = "analyze_mic")]
pub async fn try_from_mic(length_in_seconds: u8) -> Res<Vec<Note>> {
use crate::analyze::mic::get_notes_from_microphone;
Expand All @@ -211,7 +211,7 @@ impl Note {
/// to identify the notes in the recorded audio using ML.
///
/// Currently, this does not work with WASM.
#[no_coverage]
#[coverage(off)]
#[cfg(all(feature = "ml_infer", feature = "analyze_mic"))]
pub async fn try_from_mic_ml(length_in_seconds: u8) -> Res<Vec<Self>> {
use crate::{analyze::mic::get_audio_data_from_microphone, ml::infer::infer};
Expand Down
4 changes: 2 additions & 2 deletions src/core/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub struct ChordParser;
// Helpers.

/// Parses a [`Note`] [`str`] into a [`Note`].
#[no_coverage]
#[coverage(off)]
pub fn note_str_to_note(note_str: &str) -> Res<Note> {
let chord = match note_str {
"A" => note::A,
Expand Down Expand Up @@ -63,7 +63,7 @@ pub fn note_str_to_note(note_str: &str) -> Res<Note> {
}

/// Parses an [`Octave`] [`str`] into an [`Octave`].
#[no_coverage]
#[coverage(off)]
pub fn octave_str_to_octave(note_str: &str) -> Res<Octave> {
let octave = match note_str {
"0" => Octave::Zero,
Expand Down
2 changes: 1 addition & 1 deletion src/core/pitch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pub enum Pitch {
// Pitch impls.

impl HasBaseFrequency for Pitch {
#[no_coverage]
#[coverage(off)]

Check failure on line 96 in src/core/pitch.rs

View workflow job for this annotation

GitHub Actions / Build Linux

cannot find attribute `coverage` in this scope

Check failure on line 96 in src/core/pitch.rs

View workflow job for this annotation

GitHub Actions / Build Windows

cannot find attribute `coverage` in this scope

Check failure on line 96 in src/core/pitch.rs

View workflow job for this annotation

GitHub Actions / Build MacOS

cannot find attribute `coverage` in this scope
fn base_frequency(&self) -> f32 {
match self {
Pitch::C => 16.35,
Expand Down
2 changes: 0 additions & 2 deletions src/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
//! Global helper functions.
/// Plot the frequency space of the microphone input using plotters.
#[cfg(feature = "plot")]
pub fn plot_frequency_space(frequency_space: &[(f32, f32)], title: &str, file_name: &str, x_min: f32, x_max: f32) {
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
#![feature(specialization)]
#![feature(concat_idents)]
#![feature(iter_advance_by)]
#![feature(no_coverage)]
#![feature(int_roundings)]
#![feature(coverage_attribute)]

pub mod core;
pub mod helpers;

Check warning on line 53 in src/lib.rs

View workflow job for this annotation

GitHub Actions / Test

missing documentation for a module
Expand Down
2 changes: 1 addition & 1 deletion src/ml/base/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::analyze::mic::get_audio_data_from_microphone;
use super::helpers::save_kord_item;

/// Gather a sample from the microphone and save it to disk.
#[no_coverage]
#[coverage(off)]
pub fn gather_sample(destination: impl AsRef<Path>, length_in_seconds: u8) -> Void {
println!("Listening ...");

Expand Down
18 changes: 9 additions & 9 deletions src/ml/base/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use burn::tensor::{backend::Backend, Tensor};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};

use crate::{
analyze::base::get_notes_from_smoothed_frequency_space,
analyze::base::{get_notes_from_smoothed_frequency_space},
core::{
base::Res,
helpers::{inv_mel, mel},
Expand Down Expand Up @@ -111,7 +111,13 @@ pub fn mel_filter_banks_from(spectrum: &[f32]) -> [f32; MEL_SPACE_SIZE] {
pub fn harmonic_convolution(spectrum: &[f32]) -> [f32; FREQUENCY_SPACE_SIZE] {
let mut harmonic_convolution = [0f32; FREQUENCY_SPACE_SIZE];

let (peak, _) = spectrum.iter().enumerate().fold((0usize, 0f32), |(k, max), (j, x)| if *x > max { (j, *x) } else { (k, max) });
let (peak, _) = spectrum.iter().enumerate().fold((0usize, 0f32), |(k, max), (j, x)| {
if *x > max {
(j, *x)
} else {
(k, max)
}
});

for center in (peak / 2)..4000 {
let mut sum = spectrum[center];
Expand Down Expand Up @@ -171,9 +177,6 @@ pub fn binary_to_u128(binary: &[f32]) -> u128 {
num
}

/// Folds a 128 element array of 0s and 1s into a 12 element array of 0s and 1s.
///
/// Essentially, this is useful if we want to do inference on the pitches without octaves.
#[allow(dead_code)]
pub fn fold_binary(binary: &[f32; 128]) -> [f32; 12] {

Check warning on line 181 in src/ml/base/helpers.rs

View workflow job for this annotation

GitHub Actions / Test

missing documentation for a function
let mut folded = [0f32; 12];
Expand All @@ -191,21 +194,18 @@ pub fn fold_binary(binary: &[f32; 128]) -> [f32; 12] {

// Common tensor operations.

/// A sigmoid activation function.
#[derive(Debug, Clone)]
pub struct Sigmoid {

Check warning on line 198 in src/ml/base/helpers.rs

View workflow job for this annotation

GitHub Actions / Test

missing documentation for a struct
scale: f32,
}

impl Sigmoid {
/// Create a new sigmoid activation function.
pub fn new(scale: f32) -> Self {

Check warning on line 203 in src/ml/base/helpers.rs

View workflow job for this annotation

GitHub Actions / Test

missing documentation for an associated function
Self { scale }
}

/// Forward pass.
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {

Check warning on line 207 in src/ml/base/helpers.rs

View workflow job for this annotation

GitHub Actions / Test

missing documentation for a method
let scaled = input.mul_scalar(self.scale);
scaled.clone().exp().div(scaled.exp().add_scalar(1.0))
}
}
}
15 changes: 5 additions & 10 deletions src/ml/base/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use crate::ml::train::{
helpers::{KordClassificationOutput, MeanSquareLoss},
};

/// The primary model type for identifying notes / chords.
#[derive(Module, Debug)]
pub struct KordModel<B: Backend> {

Check warning on line 20 in src/ml/base/model.rs

View workflow job for this annotation

GitHub Actions / Test

missing documentation for a struct
input: Param<nn::Linear<B>>,
Expand All @@ -26,7 +25,6 @@ pub struct KordModel<B: Backend> {
}

impl<B: Backend> KordModel<B> {
/// Create a new model with the given parameters.
pub fn new(mlp_layers: usize, mlp_size: usize, mlp_dropout: f64, sigmoid_strength: f32) -> Self {

Check warning on line 28 in src/ml/base/model.rs

View workflow job for this annotation

GitHub Actions / Test

missing documentation for an associated function
let input = nn::Linear::new(&nn::LinearConfig::new(INPUT_SPACE_SIZE, mlp_size));
let mlp = Mlp::new(mlp_layers, mlp_size, mlp_dropout);
Expand All @@ -41,7 +39,6 @@ impl<B: Backend> KordModel<B> {
}
}

/// Forward pass through the model.
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {

Check warning on line 42 in src/ml/base/model.rs

View workflow job for this annotation

GitHub Actions / Test

missing documentation for a method
let mut x = input;

Expand All @@ -54,8 +51,8 @@ impl<B: Backend> KordModel<B> {
}

#[cfg(feature = "ml_train")]
/// Forward pass through the model, with loss calculation.
pub fn forward_classification(&self, item: KordBatch<B>) -> KordClassificationOutput<B> {

Check warning on line 54 in src/ml/base/model.rs

View workflow job for this annotation

GitHub Actions / Test

missing documentation for a method

let targets = item.targets;
let output = self.forward(item.samples);

Expand All @@ -80,28 +77,26 @@ impl<B: Backend> KordModel<B> {
}
}

/// A convolutional block.
#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
conv: Param<nn::conv::Conv1d<B>>,
activation: nn::ReLU,
}

/// A convolutional block.
impl<B: Backend> ConvBlock<B> {
/// Create a new convolutional block.
pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
let conv = nn::conv::Conv1d::new(&nn::conv::Conv1dConfig::new(in_channels, out_channels, kernel_size).with_bias(false));
let conv = nn::conv::Conv1d::new(
&nn::conv::Conv1dConfig::new(in_channels, out_channels, kernel_size).with_bias(false),
);

Self {
conv: Param::from(conv),
activation: nn::ReLU::new(),
}
}

/// Forward pass through the block.
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let x = self.conv.forward(input);
self.activation.forward(x)
}
}
}
4 changes: 2 additions & 2 deletions src/ml/train/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ where
Ok(accuracy)
}

#[no_coverage]
#[coverage(off)]
pub fn compute_overall_accuracy<B: Backend>(model_trained: &KordModel<B>, device: &B::Device) -> f32 {
let dataset = KordDataset::from_folder_and_simulation("samples", 0, 0.0, 0.0, 0.0);

Expand Down Expand Up @@ -192,7 +192,7 @@ pub fn compute_overall_accuracy<B: Backend>(model_trained: &KordModel<B>, device
inference_accuracy
}

#[no_coverage]
#[coverage(off)]
pub fn hyper_parameter_tuning(source: String, destination: String, log: String, device: String) -> Void {
let peak_radiuses = [1.0];
let harmonic_decays = [0.1];
Expand Down

0 comments on commit fa9bb97

Please sign in to comment.