Skip to content

Commit

Permalink
Added multichannel support.
Browse files Browse the repository at this point in the history
  • Loading branch information
sandrohanea committed May 20, 2023
1 parent 47e4cfd commit ad161d4
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 159 deletions.
35 changes: 35 additions & 0 deletions Whisper.net.Tests/ProcessorE2ETests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,39 @@ public async Task ProcessAsync_WhenJunkChunkExists_ProcessCorrectly()
segments.Should().HaveCountGreaterThanOrEqualTo(1);
}

[Test]
public async Task ProcessAsync_WhenMultichannel_ProcessCorrectly()
{
var segments = new List<SegmentData>();

using var factory = WhisperFactory.FromPath(ggmlModelPath);
await using var processor = factory.CreateBuilder()
.WithLanguage("en")
.Build();

using var fileReader = File.OpenRead("multichannel.wav");
await foreach (var segment in processor.ProcessAsync(fileReader))
{
segments.Add(segment);
}

segments.Should().HaveCountGreaterThanOrEqualTo(1);
}

[Test]
public async Task Process_WhenMultichannel_ProcessCorrectly()
{
var segments = new List<SegmentData>();

using var factory = WhisperFactory.FromPath(ggmlModelPath);
await using var processor = factory.CreateBuilder()
.WithLanguage("en")
.WithSegmentEventHandler(segments.Add)
.Build();

using var fileReader = File.OpenRead("multichannel.wav");
processor.Process(fileReader);

segments.Should().HaveCountGreaterThanOrEqualTo(1);
}
}
3 changes: 3 additions & 0 deletions Whisper.net.Tests/Whisper.net.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
<None Update="kennedy.wav">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="multichannel.wav">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="romana.wav">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
Expand Down
Binary file added Whisper.net.Tests/multichannel.wav
Binary file not shown.
4 changes: 2 additions & 2 deletions Whisper.net.sln
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{2C8F58A8-37D9-4E96-AE73-B3B37AEC10C1}"
ProjectSection(SolutionItems) = preProject
.editorconfig = .editorconfig
..\readme.md = ..\readme.md
readme.md = readme.md
EndProjectSection
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Whisper.net", "Whisper.net\Whisper.net.csproj", "{5D1F57EB-AF79-419C-8DCB-E23F7BEC3E8C}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Whisper.net.Tests", "Whisper.net.Tests\Whisper.net.Tests.csproj", "{333764F7-81DC-4425-AF5C-B438D973EE13}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Whisper.net.Tests", "Whisper.net.Tests\Whisper.net.Tests.csproj", "{333764F7-81DC-4425-AF5C-B438D973EE13}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "github", "github", "{5C865649-4AED-4BBF-A7CA-BD68D21EF48F}"
ProjectSection(SolutionItems) = preProject
Expand Down
203 changes: 46 additions & 157 deletions Whisper.net/Wave/WaveParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ namespace Whisper.net.Wave;

public sealed class WaveParser
{
private static readonly byte[] expectedSubFormatForPcm = new byte[] { 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71 };
private readonly Stream waveStream;
private ushort channels;
private uint sampleRate;
Expand Down Expand Up @@ -46,7 +47,7 @@ public async Task<float[]> GetAvgSamplesAsync(CancellationToken cancellationToke
var samplesCount = GetSamplesCount();
var samples = new float[samplesCount];

var buffer = new byte[4096];
var buffer = new byte[2048 * channels];

var sampleIndex = 0;
int bytesRead;
Expand Down Expand Up @@ -131,14 +132,27 @@ public float[] GetChannelSamples(int channelIndex = 0)
}

private void Initialize()
{
InitializeCore(useAsync: false).GetAwaiter().GetResult();
}

private Task InitializeAsync()
{
return InitializeCore(useAsync: true);
}

private async Task InitializeCore(bool useAsync)
{
if (wasInitialized)
{
return;
}

var buffer = new byte[12];
var actualRead = waveStream.Read(buffer, 0, 12);
var actualRead = useAsync
? await waveStream.ReadAsync(buffer, 0, 12)
: waveStream.Read(buffer, 0, 12);

if (actualRead != 12)
{
throw new CorruptedWaveException("Invalid wave file, the size is too small.");
Expand All @@ -163,7 +177,10 @@ private void Initialize()
while (true)
{
var nextChunkHeader = new byte[8];
actualRead = waveStream.Read(nextChunkHeader, 0, 8);
actualRead = useAsync
? await waveStream.ReadAsync(nextChunkHeader, 0, 8)
: waveStream.Read(nextChunkHeader, 0, 8);

if (actualRead != 8)
{
throw new CorruptedWaveException("Invalid wave file, cannot read next chunk.");
Expand All @@ -188,7 +205,8 @@ private void Initialize()
else
{
var restOfChunk = new byte[chunkSize];
waveStream.Read(restOfChunk, 0, chunkSize);
_ = useAsync ? await waveStream.ReadAsync(restOfChunk, 0, chunkSize)
: waveStream.Read(restOfChunk, 0, chunkSize);
}
}

Expand All @@ -198,162 +216,40 @@ private void Initialize()
}

var fmtBuffer = new byte[fmtChunkSize];
actualRead = waveStream.Read(fmtBuffer, 0, fmtChunkSize);
actualRead = useAsync ? await waveStream.ReadAsync(fmtBuffer, 0, fmtChunkSize)
: waveStream.Read(fmtBuffer, 0, fmtChunkSize);
if (actualRead != fmtChunkSize)
{
throw new CorruptedWaveException("Invalid wave file, cannot read format chunk.");
}

// Read Format
var format = BitConverter.ToUInt16(fmtBuffer, 0);
if (format != 1)
if (format != 1 && format != 65534) // Allow both standard PCM and WAVE_FORMAT_EXTENSIBLE
{
throw new CorruptedWaveException("Unsupported wave file");
}

channels = BitConverter.ToUInt16(fmtBuffer, 2);
sampleRate = BitConverter.ToUInt32(fmtBuffer, 4);
if (sampleRate != 16000)
{
throw new NotSupportedWaveException("Only 16KHz sample rate is supported.");
}

// Skip Average bytes rate 8 => 12

// Skip Block Allign 12 => 14

bitsPerSample = BitConverter.ToUInt16(fmtBuffer, 14);
if (bitsPerSample != 16)
{
throw new NotSupportedWaveException("Only 16 bits per sample is supported.");
}
// Until now we have read 18 bytes in format, the rest is cbSize, averageBytesRate, and is ignored for now.
if (fmtChunkSize > 18)
{
if (waveStream.CanSeek)
{
waveStream.Seek(fmtChunkSize - 18, SeekOrigin.Current);
}
else
{
var restOfBuffer = new byte[fmtChunkSize - 18];
waveStream.Read(restOfBuffer, 0, fmtChunkSize - 18);
}
}

// Seek data chuunk
// Read chunk name and size
waveStream.Read(buffer, 0, 8);
while (buffer[0] != 'd' || buffer[1] != 'a' || buffer[2] != 't' || buffer[3] != 'a')
{
var chunkSize = BitConverter.ToInt32(buffer, 4);
if (chunkSize < 0)
{
throw new CorruptedWaveException("Invalid wave chunk size.");
}
if (waveStream.CanSeek)
{
waveStream.Seek(chunkSize, SeekOrigin.Current);
}
else
{
var restOfChunk = new byte[chunkSize];
waveStream.Read(restOfChunk, 0, chunkSize);
}

actualRead = waveStream.Read(buffer, 0, 8);
if (actualRead != 8)
{
throw new CorruptedWaveException("Invalid wave chunk size.");
}
}

dataChunkSize = BitConverter.ToUInt32(buffer, 4);
dataChunkPosition = waveStream.Position;
wasInitialized = true;
}

private async Task InitializeAsync()
{
if (wasInitialized)
{
return;
}

var buffer = new byte[12];
var actualRead = await waveStream.ReadAsync(buffer, 0, 12);
if (actualRead != 12)
{
throw new CorruptedWaveException("Invalid wave file, the size is too small.");
}

//Read RIFF Header
if (buffer[0] != 'R' || buffer[1] != 'I' || buffer[2] != 'F' || buffer[3] != 'F')
{
throw new CorruptedWaveException("Invalid wave file RIFF header.");
}

// Skip FileSize 4 => 8

// Read Wave and Fmt tags
if (buffer[8] != 'W' || buffer[9] != 'A' || buffer[10] != 'V' || buffer[11] != 'E')
{
throw new CorruptedWaveException("Invalid wave file header.");
}

// Search for format chunk
int fmtChunkSize;
while (true)
// If the file is in WAVE_FORMAT_EXTENSIBLE format, we'll need to read the SubFormat field
if (format == 65534)
{
var nextChunkHeader = new byte[8];
actualRead = await waveStream.ReadAsync(nextChunkHeader, 0, 8);
if (actualRead != 8)
{
throw new CorruptedWaveException("Invalid wave file, cannot read next chunk.");
}

var chunkSize = BitConverter.ToInt32(nextChunkHeader, 4);
if (chunkSize < 0)
// Verify that fmtChunkSize is at least 40, which is required for WAVE_FORMAT_EXTENSIBLE
if (fmtChunkSize < 40)
{
throw new CorruptedWaveException("Invalid wave chunk size.");
throw new CorruptedWaveException("Invalid wave format size.");
}

if (nextChunkHeader[0] == 'f' && nextChunkHeader[1] == 'm' && nextChunkHeader[2] == 't' && nextChunkHeader[3] == ' ')
{
fmtChunkSize = chunkSize;
break;
}

if (waveStream.CanSeek)
// The SubFormat field is a GUID, but for PCM data it will be {00000001-0000-0010-8000-00aa00389b71}
// Check this manually, byte by byte
for (var i = 0; i < 16; i++)
{
waveStream.Seek(chunkSize, SeekOrigin.Current);
}
else
{
var restOfChunk = new byte[chunkSize];
await waveStream.ReadAsync(restOfChunk, 0, chunkSize);
if (fmtBuffer[24 + i] != expectedSubFormatForPcm[i])
{
throw new CorruptedWaveException("Unsupported wave file format. Only PCM is supported.");
}
}
}

if (fmtChunkSize < 16)
{
throw new CorruptedWaveException("Invalid wave format size.");
}

var fmtBuffer = new byte[fmtChunkSize];
actualRead = await waveStream.ReadAsync(fmtBuffer, 0, fmtChunkSize);
if (actualRead != fmtChunkSize)
{
throw new CorruptedWaveException("Invalid wave file, cannot read format chunk.");
}

// Read Format
var format = BitConverter.ToUInt16(fmtBuffer, 0);
if (format != 1)
{
throw new CorruptedWaveException("Unsupported wave file");
}

channels = BitConverter.ToUInt16(fmtBuffer, 2);
sampleRate = BitConverter.ToUInt32(fmtBuffer, 4);
if (sampleRate != 16000)
Expand All @@ -370,23 +266,13 @@ private async Task InitializeAsync()
{
throw new NotSupportedWaveException("Only 16 bits per sample is supported.");
}
// Until now we have read 18 bytes in format, the rest is cbSize, averageBytesRate, and is ignored for now.
if (fmtChunkSize > 18)
{
if (waveStream.CanSeek)
{
waveStream.Seek(fmtChunkSize - 18, SeekOrigin.Current);
}
else
{
var restOfBuffer = new byte[fmtChunkSize - 18];
await waveStream.ReadAsync(restOfBuffer, 0, fmtChunkSize - 18);
}
}

// Seek data chuunk
// Read chunk name and size
await waveStream.ReadAsync(buffer, 0, 8);

_ = useAsync ? await waveStream.ReadAsync(buffer, 0, 8)
: waveStream.Read(buffer, 0, 8);

while (buffer[0] != 'd' || buffer[1] != 'a' || buffer[2] != 't' || buffer[3] != 'a')
{
var chunkSize = BitConverter.ToInt32(buffer, 4);
Expand All @@ -401,10 +287,13 @@ private async Task InitializeAsync()
else
{
var restOfChunk = new byte[chunkSize];
await waveStream.ReadAsync(restOfChunk, 0, chunkSize);
_ = useAsync ? await waveStream.ReadAsync(restOfChunk, 0, chunkSize)
: waveStream.Read(restOfChunk, 0, chunkSize);
}

actualRead = await waveStream.ReadAsync(buffer, 0, 8);
actualRead = useAsync ? await waveStream.ReadAsync(buffer, 0, 8)
: waveStream.Read(buffer, 0, 8);

if (actualRead != 8)
{
throw new CorruptedWaveException("Invalid wave chunk size.");
Expand Down
4 changes: 4 additions & 0 deletions examples/DiarizationWhisperExample/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@

await foreach (var result in processor.ProcessAsync(fileStream))
{
// TODO: here, check the wave stream to see in which channel the diff is the highest for the specified time interval
// 1. Get the wave position for the specified time interval
// 2. Get the wave data for the specified time interval
// 3. Iterate in the wave data to find the channel with the highest diff
Console.WriteLine($"{result.Start}->{result.End}: {result.Text}");
}

Expand Down

0 comments on commit ad161d4

Please sign in to comment.