Skip to content

Commit

Permalink
Improved WaveParser, Added support for any bitspersample + fixed sand…
Browse files Browse the repository at this point in the history
  • Loading branch information
sandrohanea committed Jul 1, 2023
1 parent 5124598 commit 441433d
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 101 deletions.
1 change: 0 additions & 1 deletion Whisper.net.Tests/FactoryTests.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Licensed under the MIT license: https://opensource.org/licenses/MIT
using FluentAssertions;
using NUnit.Framework;
using Whisper.net.Ggml;

namespace Whisper.net.Tests;

Expand Down
12 changes: 0 additions & 12 deletions Whisper.net/GlobalSuppressions.cs

This file was deleted.

235 changes: 147 additions & 88 deletions Whisper.net/Wave/WaveParser.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Licensed under the MIT license: https://opensource.org/licenses/MIT

using System.Runtime.CompilerServices;

namespace Whisper.net.Wave;

public sealed class WaveParser
Expand Down Expand Up @@ -80,45 +82,36 @@ public WaveParser(Stream waveStream)
public int FrameSize => bitsPerSample / 8 * channels;

/// <summary>
/// Returns the average samples from all channels.
/// Gets the value to divide the sample by to get the actual float value.
/// </summary>
public async Task<float[]> GetAvgSamplesAsync(CancellationToken cancellationToken)
public float ValueToDivide => bitsPerSample switch
{
if (!isInitialized)
{
await InitializeAsync();
}
8 => 128.0f,
16 => 32768.0f,
24 => 8388608.0f,
_ => 2147483648.0f
};

if (channels == 0)
{
throw new InvalidOperationException("Channel count is set to 0");
}
/// <summary>
/// Returns the average samples from all channels.
/// </summary>
public async Task<float[]> GetAvgSamplesAsync(CancellationToken cancellationToken = default)
{
await InitializeAsync(cancellationToken);

var samples = new float[SamplesCount];

var buffer = new byte[2048 * channels];

var sampleIndex = 0;
int bytesRead;

do
await foreach (var sampleFrame in InternalReadSamples(useAsync: true, cancellationToken))
{
bytesRead = await waveStream.ReadAsync(buffer, 0, buffer.Length, cancellationToken);

for (var i = 0; i < bytesRead;)
var sampleSum = 0L;
for (var i = 0; i < sampleFrame.Length; i++)
{
long sampleSum = 0;

for (var currentChannel = 0; currentChannel < channels; currentChannel++)
{
sampleSum += BitConverter.ToInt16(buffer, i);
i += 2;
}

samples[sampleIndex++] = sampleSum / (float)channels / 32768.0f;
sampleSum += sampleFrame[i];
}
} while (bytesRead > 0);

samples[sampleIndex++] = (sampleSum / ValueToDivide) / channels;
}
return samples;
}

Expand All @@ -128,52 +121,65 @@ public async Task<float[]> GetAvgSamplesAsync(CancellationToken cancellationToke
/// <returns></returns>
public float[] GetAvgSamples()
{
if (!isInitialized)
{
Initialize();
}
Initialize();

var reader = GetDataReader();
var asyncEnumerator = InternalReadSamples(useAsync: false, CancellationToken.None).GetAsyncEnumerator();
var samples = new float[SamplesCount];

for (var i = 0; i < SamplesCount; i++)
var sampleIndex = 0;
// CA2012 is disabled here because the async enumerator is guaranteed to complete synchronously when useAsync is false.
#pragma warning disable CA2012 // Use ValueTasks correctly
while (asyncEnumerator.MoveNextAsync().GetAwaiter().GetResult())
{
var sampleFrame = asyncEnumerator.Current;
var sampleSum = 0L;

for (var currentChannel = 0; currentChannel < channels; currentChannel++)
for (var i = 0; i < sampleFrame.Length; i++)
{
sampleSum += reader.ReadInt16();
sampleSum += sampleFrame[i];
}

samples[i] = (sampleSum / 32768.0f) / channels;
samples[sampleIndex++] = (sampleSum / ValueToDivide) / channels;
}

#pragma warning restore CA2012 // Use ValueTasks correctly
return samples;
}

public float[] GetChannelSamples(int channelIndex = 0)
{
if (!isInitialized)
Initialize();
if (channelIndex >= channels)
{
Initialize();
throw new ArgumentOutOfRangeException(nameof(channelIndex));
}

var reader = GetDataReader();
var samples = new float[SamplesCount];
var sampleIndex = 0;

for (var i = 0; i < SamplesCount; i++)
var asyncEnumerator = InternalReadSamples(useAsync: false, CancellationToken.None).GetAsyncEnumerator();
// CA2012 is disabled here because the async enumerator is guaranteed to complete synchronously when useAsync is false.
#pragma warning disable CA2012 // Use ValueTasks correctly
while (asyncEnumerator.MoveNextAsync().GetAwaiter().GetResult())
{
for (var currentChannel = 0; currentChannel < channels; currentChannel++)
{
if (channelIndex == currentChannel)
{
samples[i] = reader.ReadInt16() / 32768.0f;
}
else
{
_ = reader.ReadInt16();
}
}
var sampleFrame = asyncEnumerator.Current;
samples[sampleIndex++] = sampleFrame[channelIndex] / ValueToDivide;
}
#pragma warning restore CA2012 // Use ValueTasks correctly
return samples;
}

/// <summary>
/// Asynchronously reads every frame of the data chunk and returns the samples of a single channel,
/// each divided by <see cref="ValueToDivide"/>.
/// </summary>
/// <param name="channelIndex">Zero-based index of the channel to extract.</param>
/// <param name="cancellationToken">Token forwarded to initialization and to the stream reads.</param>
/// <returns>An array with one float per frame, sized by <c>SamplesCount</c>.</returns>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="channelIndex"/> is negative, or is not smaller than the channel count read from the header.
/// </exception>
public async Task<float[]> GetChannelSamplesAsync(int channelIndex = 0, CancellationToken cancellationToken = default)
{
    // The channel count is only known once the header has been parsed.
    await InitializeAsync(cancellationToken);

    // Reject negative indices too; otherwise they surface later as an
    // IndexOutOfRangeException from inside the read loop.
    if (channelIndex < 0 || channelIndex >= channels)
    {
        throw new ArgumentOutOfRangeException(nameof(channelIndex));
    }

    var samples = new float[SamplesCount];
    var sampleIndex = 0;

    // bitsPerSample does not change after initialization, so the divisor is loop-invariant.
    var divisor = ValueToDivide;

    await foreach (var sampleFrame in InternalReadSamples(useAsync: true, cancellationToken))
    {
        samples[sampleIndex++] = sampleFrame[channelIndex] / divisor;
    }

    return samples;
}
Expand All @@ -183,28 +189,92 @@ public float[] GetChannelSamples(int channelIndex = 0)
/// </summary>
public void Initialize()
{
InitializeCore(useAsync: false).GetAwaiter().GetResult();
InternalInitialize(useAsync: false, CancellationToken.None).GetAwaiter().GetResult();
}

/// <summary>
/// Initializes the wave parser, by reading the header and the format chunk in an async manner.
/// </summary>
public Task InitializeAsync()
public Task InitializeAsync(CancellationToken cancellationToken = default)
{
return InternalInitialize(useAsync: true, cancellationToken);
}

/// <summary>
/// Streams the data chunk frame by frame, yielding one array per frame that holds
/// the raw integer value of each channel (not yet divided by <see cref="ValueToDivide"/>).
/// </summary>
/// <param name="useAsync">
/// When <c>true</c> the stream is read with <c>ReadAsync</c>; when <c>false</c> only the synchronous
/// <c>Read</c> is used and no await is executed, so the enumerator completes synchronously.
/// </param>
/// <param name="cancellationToken">Token passed to the asynchronous stream reads.</param>
/// <returns>An async sequence of per-frame channel values.</returns>
/// <exception cref="CorruptedWaveException">
/// The stream ended before <c>SamplesCount</c> complete frames could be read.
/// </exception>
private async IAsyncEnumerable<long[]> InternalReadSamples(bool useAsync, [EnumeratorCancellation] CancellationToken cancellationToken)
{
    var bytesPerSample = bitsPerSample / 8;
    var frameSize = bytesPerSample * channels;

    // Size the buffer as a whole number of frames so that a frame never straddles two
    // buffer fills (2048 * channels is not a multiple of the frame size for 24-bit audio).
    var buffer = new byte[1024 * frameSize];

    var sampleIndex = 0;

    while (sampleIndex < SamplesCount)
    {
        // Never read past the end of the data announced by the samples count.
        var bytesWanted = (int)Math.Min(buffer.Length, (SamplesCount - sampleIndex) * frameSize);

        // A single Read/ReadAsync may return fewer bytes than requested without being at
        // the end of the stream, so keep reading until the request is satisfied or EOF.
        var bytesAvailable = 0;
        while (bytesAvailable < bytesWanted)
        {
            int bytesRead;
            if (useAsync)
            {
#if NET6_0_OR_GREATER
                bytesRead = await waveStream.ReadAsync(buffer.AsMemory(bytesAvailable, bytesWanted - bytesAvailable), cancellationToken);
#else
                bytesRead = await waveStream.ReadAsync(buffer, bytesAvailable, bytesWanted - bytesAvailable, cancellationToken);
#endif
            }
            else
            {
                bytesRead = waveStream.Read(buffer, bytesAvailable, bytesWanted - bytesAvailable);
            }

            if (bytesRead <= 0)
            {
                break;
            }

            bytesAvailable += bytesRead;
        }

        // Only decode complete frames; a trailing partial frame means the file is truncated.
        var decodableBytes = bytesAvailable - (bytesAvailable % frameSize);

        for (var offset = 0; offset < decodableBytes;)
        {
            var currentSamples = new long[channels];

            for (var currentChannel = 0; currentChannel < channels; currentChannel++)
            {
                currentSamples[currentChannel] = bitsPerSample switch
                {
                    // 8-bit PCM is unsigned, centered on 128.
                    8 => buffer[offset] - 128,
                    16 => BitConverter.ToInt16(buffer, offset),
                    // Assemble exactly the three little-endian bytes of this sample in the top
                    // 24 bits, then shift back arithmetically to sign-extend. Reading four bytes
                    // with ToInt32 would drop the low byte and pull in a byte of the next sample.
                    24 => ((buffer[offset] << 8) | (buffer[offset + 1] << 16) | (buffer[offset + 2] << 24)) >> 8,
                    _ => BitConverter.ToInt32(buffer, offset),
                };
                offset += bytesPerSample;
            }

            yield return currentSamples;
            sampleIndex++;
        }

        if (bytesAvailable < bytesWanted)
        {
            // End of stream reached before the expected amount of data was available.
            break;
        }
    }

    if (sampleIndex < SamplesCount)
    {
        throw new CorruptedWaveException("Invalid wave file, the size is too small and couldn't read all the samples.");
    }
}

private async Task InitializeCore(bool useAsync)
private async Task InternalInitialize(bool useAsync, CancellationToken cancellationToken)
{
if (isInitialized)
{
return;
}

async Task<int> ReadBytesAsync(byte[] buffer, int offset, int count)
{
if (useAsync)
{
#if NET6_0_OR_GREATER
return await waveStream.ReadAsync(buffer.AsMemory(offset, count), cancellationToken);
#else
return await waveStream.ReadAsync(buffer, offset, count, cancellationToken);
#endif
}

return waveStream.Read(buffer, offset, count);
}

var buffer = new byte[12];
var actualRead = useAsync
? await waveStream.ReadAsync(buffer, 0, 12)
: waveStream.Read(buffer, 0, 12);
var actualRead = await ReadBytesAsync(buffer, 0, 12);

if (actualRead != 12)
{
Expand All @@ -230,9 +300,8 @@ private async Task InitializeCore(bool useAsync)
while (true)
{
var nextChunkHeader = new byte[8];
actualRead = useAsync
? await waveStream.ReadAsync(nextChunkHeader, 0, 8)
: waveStream.Read(nextChunkHeader, 0, 8);

actualRead = await ReadBytesAsync(nextChunkHeader, 0, 8);

if (actualRead != 8)
{
Expand All @@ -258,8 +327,7 @@ private async Task InitializeCore(bool useAsync)
else
{
var restOfChunk = new byte[chunkSize];
_ = useAsync ? await waveStream.ReadAsync(restOfChunk, 0, chunkSize)
: waveStream.Read(restOfChunk, 0, chunkSize);
await ReadBytesAsync(restOfChunk, 0, chunkSize);
}
}

Expand All @@ -269,8 +337,7 @@ private async Task InitializeCore(bool useAsync)
}

var fmtBuffer = new byte[fmtChunkSize];
actualRead = useAsync ? await waveStream.ReadAsync(fmtBuffer, 0, fmtChunkSize)
: waveStream.Read(fmtBuffer, 0, fmtChunkSize);
actualRead = await ReadBytesAsync(fmtBuffer, 0, fmtChunkSize);
if (actualRead != fmtChunkSize)
{
throw new CorruptedWaveException("Invalid wave file, cannot read format chunk.");
Expand Down Expand Up @@ -304,6 +371,11 @@ private async Task InitializeCore(bool useAsync)
}

channels = BitConverter.ToUInt16(fmtBuffer, 2);
if (channels == 0)
{
throw new NotSupportedWaveException("Cannot read wave file with 0 channels.");
}

sampleRate = BitConverter.ToUInt32(fmtBuffer, 4);
if (sampleRate != 16000)
{
Expand All @@ -315,16 +387,16 @@ private async Task InitializeCore(bool useAsync)
// Skip Block Align 12 => 14

bitsPerSample = BitConverter.ToUInt16(fmtBuffer, 14);
if (bitsPerSample != 16)

if (bitsPerSample != 8 && bitsPerSample != 16 && bitsPerSample != 24 && bitsPerSample != 32)
{
throw new NotSupportedWaveException("Only 16 bits per sample is supported.");
throw new NotSupportedWaveException($"Bits per sample {bitsPerSample} is not supported.");
}

// Seek data chunk
// Read chunk name and size

_ = useAsync ? await waveStream.ReadAsync(buffer, 0, 8)
: waveStream.Read(buffer, 0, 8);
await ReadBytesAsync(buffer, 0, 8);

while (buffer[0] != 'd' || buffer[1] != 'a' || buffer[2] != 't' || buffer[3] != 'a')
{
Expand All @@ -340,12 +412,10 @@ private async Task InitializeCore(bool useAsync)
else
{
var restOfChunk = new byte[chunkSize];
_ = useAsync ? await waveStream.ReadAsync(restOfChunk, 0, chunkSize)
: waveStream.Read(restOfChunk, 0, chunkSize);
await ReadBytesAsync(restOfChunk, 0, chunkSize);
}

actualRead = useAsync ? await waveStream.ReadAsync(buffer, 0, 8)
: waveStream.Read(buffer, 0, 8);
actualRead = await ReadBytesAsync(buffer, 0, 8);

if (actualRead != 8)
{
Expand All @@ -364,15 +434,4 @@ private async Task InitializeCore(bool useAsync)
dataChunkPosition = waveStream.Position;
isInitialized = true;
}

private BinaryReader GetDataReader()
{
var reader = new BinaryReader(waveStream);
if (waveStream.Position != dataChunkPosition)
{
waveStream.Seek(dataChunkPosition, SeekOrigin.Begin);
}
return reader;
}

}

0 comments on commit 441433d

Please sign in to comment.