Skip to content

Commit

Permalink
Added sequence generation for Markov Chains. Added static utility fun…
Browse files Browse the repository at this point in the history
…ctions for finding grams and manipulating sequences. Added tests for new methods and increased overall coverage.
  • Loading branch information
abrisene committed Sep 2, 2021
1 parent 74513df commit 2c24a24
Show file tree
Hide file tree
Showing 3 changed files with 478 additions and 106 deletions.
3 changes: 2 additions & 1 deletion examples/basic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ const chain = new MarkovChain({
});

// chain.addSequences(source);
const pickA = chain.pick();
// const pickA = chain.pick();
const pickA = chain.generate({ start: ['a'], trim: false });

console.log(chain.sequences);
console.log(pickA);
211 changes: 184 additions & 27 deletions src/__tests__/markov.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/

import { MarkovChain, MarkovChainDTO, MarkovChainGramDTO, GramDictionary, Random, CONSTANTS } from '..';
import { MCGeneratorOptions, MCDirectionOption, MCGeneratorStaticOptions } from '../structures';
// import { MarkovChainSequenceDTO } from '../structures';

/**
Expand All @@ -21,6 +22,14 @@ const defaultOptions = {
endDelimiter: CONSTANTS.MC_END_DELIMITER,
};

// Baseline generator options for these tests. This object is the default value of
// validateGen's `options` parameter, and validateGen also reads `min` / `max` from it
// as fallbacks when the options under test leave those fields unset.
const defaultGenOptions = {
min: 1,
max: 100,
// Asserted so the literal is typed as MCDirectionOption rather than widened to string.
direction: 'next' as MCDirectionOption,
strict: true,
trim: true,
};

const defaultDTO: MarkovChainDTO = { ...defaultOptions, sequences: [], grams: {} };
const defaultDTO1 = { ...defaultDTO, maxOrder: 1, sequences: [], grams: {} };
const defaultDTO2 = { ...defaultDTO, maxOrder: 2, sequences: [], grams: {} };
Expand Down Expand Up @@ -50,6 +59,8 @@ function validateDTO(m: MarkovChainDTO, ref = defaultDTO) {

function validateInstance(m: MarkovChain, ref = defaultDTO) {
const data = m.serialize();
expect(m.dto).toEqual(data);
expect(m.model).toEqual(data);
expect(m.maxOrder).toEqual(ref.maxOrder);
expect(m.delimiter).toEqual(ref.delimiter);
expect(m.startDelimiter).toEqual(ref.startDelimiter);
Expand Down Expand Up @@ -84,6 +95,21 @@ function validateGrams(m: MarkovChainDTO) {
});
}

/**
 * Asserts that a generated sequence is consistent with the generator options used to produce it.
 *
 * Checks performed:
 * - `output` is defined;
 * - its length lies strictly between the effective `min` and `max`;
 * - when trimming is in effect, it contains no start/end delimiter of `model`;
 *   otherwise it contains at most two of them.
 *
 * Any of `min`, `max` or `trim` that the caller leaves `undefined` is resolved against
 * `defaultGenOptions`, so a call such as `validateGen(dto, out, {})` is checked as a
 * trimmed generation rather than silently falling into the looser untrimmed branch.
 * Explicit falsy values (`min: 0`, `trim: false`) are honoured as given.
 *
 * NOTE(review): the length bounds are exclusive (`toBeGreaterThan` / `toBeLessThan`), as in
 * the original helper — confirm this matches how the generator interprets `min` / `max`.
 *
 * @param model - Model DTO whose `startDelimiter` / `endDelimiter` are looked for in `output`.
 * @param output - The generated sequence under test.
 * @param options - The options the sequence was generated with.
 */
function validateGen(model: MarkovChainDTO, output: string[], options: MCGeneratorOptions = defaultGenOptions) {
  expect(output).toBeDefined(); // If we're testing this, we expect it to be defined.
  if (output === undefined) return;

  // Resolve each field individually; only `undefined` means "use the default".
  const min = options.min !== undefined ? options.min : defaultGenOptions.min;
  const max = options.max !== undefined ? options.max : defaultGenOptions.max;
  const trim = options.trim !== undefined ? options.trim : defaultGenOptions.trim;

  expect(output.length).toBeGreaterThan(min);
  expect(output.length).toBeLessThan(max);

  const delimiterCount = output.filter(v => v === model.startDelimiter || v === model.endDelimiter).length;

  if (trim) {
    expect(delimiterCount).toEqual(0);
  } else {
    // No lower bound is asserted here: an untrimmed result is only required to carry
    // at most one start and one end delimiter.
    expect(delimiterCount).toBeLessThanOrEqual(2);
  }
}

/**
# Test Constants
*/
Expand Down Expand Up @@ -396,6 +422,71 @@ describe('Markov Chain', () => {
});
// Static API: MarkovChain.generate receives the model (and optionally an engine) through
// its options object.
it('can generate sequences a markov chain', () => {
// NOTE(review): presumably `eng` is a seeded random engine, which is what makes the
// exact-output assertions below reproducible — confirm. If it is stateful, the order of
// the generate calls that share it matters.
const eng = engine.clone();

// Default
// No engine is passed in this group, so results are only checked for membership in the
// source sequences (sA3), never against one exact value.
const optD0: MCGeneratorStaticOptions = { model: dtoA3 };
const optD1: MCGeneratorStaticOptions = { model: dtoA3, direction: 'last' };
const optD2: MCGeneratorStaticOptions = { model: dtoA3, strict: false };
const optD3: MCGeneratorStaticOptions = { model: dtoA3, trim: false };

const genD0 = MarkovChain.generate(optD0);
const genD1 = MarkovChain.generate(optD1); // Backward
const genD2 = MarkovChain.generate(optD2); // Unstrict
const genD3 = MarkovChain.generate(optD3); // Untrimmed

// genD3 is left out of the membership checks; it is only passed through validateGen.
expect(sA3.map(e => e.join())).toContain(genD0.join());
expect(sA3.map(e => e.join())).toContain(genD1.join());
expect(sA3.map(e => e.join())).toContain(genD2.join());
validateGen(dtoA3, genD0, optD0);
validateGen(dtoA3, genD1, optD1);
validateGen(dtoA3, genD2, optD2);
validateGen(dtoA3, genD3, optD3);

// Starting Values
// These all share `eng` and (except the masked case) assert one exact result.
const optS1: MCGeneratorStaticOptions = { model: dtoA3, engine: eng, start: ['a', 'n'] };
const optS2: MCGeneratorStaticOptions = { model: dtoA3, engine: eng, start: ['n', 'a'], direction: 'last' };
const optS3: MCGeneratorStaticOptions = { model: dtoA3, engine: eng, start: ['a'], mask: ['l'] };
const optS4: MCGeneratorStaticOptions = { model: dtoA3, engine: eng, start: ['a', 'n'], order: 2 };

const genS1 = MarkovChain.generate(optS1); // Forward
const genS2 = MarkovChain.generate(optS2); // Backward
const genS3 = MarkovChain.generate(optS3); // Masked
const genS4 = MarkovChain.generate(optS4); // Order

expect(genS1.join('')).toEqual('anna');
expect(genS2.join('')).toEqual('anna');
// The exact-output check for the masked case is disabled; genS3 is only validated structurally.
// expect(genS3.join('')).toEqual('anna');
expect(genS4.join('')).toEqual('anna');
validateGen(dtoA3, genS1, optS1);
validateGen(dtoA3, genS2, optS2);
validateGen(dtoA3, genS3, optS3);
validateGen(dtoA3, genS4, optS4);

// Non-Strict Cases
// NOTE(review): `order: 10` looks like it exceeds what the model can serve, relying on
// `strict: false` to still produce output — confirm. Only structural validation is done here.
const optN1: MCGeneratorStaticOptions = { model: dtoA3, engine: eng, start: ['a', 'a', 'a', 'n'], strict: false, order: 10 };
const optN2: MCGeneratorStaticOptions = { model: dtoA3, engine: eng, start: ['n', 'a', 'a', 'a'], strict: false, order: 10, direction: 'last' };
const optN3: MCGeneratorStaticOptions = { model: dtoA3, engine: eng, start: ['a', 'a', 'a', 'n'], strict: false };
const optN4: MCGeneratorStaticOptions = { model: dtoA3, engine: eng, start: ['n', 'a', 'a', 'a'], strict: false, direction: 'last' };

const genN1 = MarkovChain.generate(optN1); // Forward
const genN2 = MarkovChain.generate(optN2); // Last
const genN3 = MarkovChain.generate(optN3); // Forward (No Order)
const genN4 = MarkovChain.generate(optN4); // Last (No Order)

validateGen(dtoA3, genN1, optN1);
validateGen(dtoA3, genN2, optN2);
validateGen(dtoA3, genN3, optN3);
validateGen(dtoA3, genN4, optN4);

// Failure Cases
const optF0: MCGeneratorStaticOptions = { model: dtoA3, start: ['x', 'y', 'z'] }; // Shouldn't find a Gram
const optF1: MCGeneratorStaticOptions = { model: dtoA3, mask: ['a'] }; // All values should be masked.

const genF0 = MarkovChain.generate(optF0);
const genF1 = MarkovChain.generate(optF1);

// Expected degradation: an unmatched start is returned as given, and a fully masked
// generation yields an empty sequence.
expect(genF0).toEqual(optF0.start);
expect(genF1).toEqual([]);
});
it('are immutable', () => {
const mOriginal = MarkovChain.clone(dtoA3);
Expand Down Expand Up @@ -461,13 +552,13 @@ describe('Markov Chain', () => {
const mA5 = new MarkovChain(dtoGU).clone();
const mA6 = new MarkovChain(dto6GU).clone();

expect(mA0.serialize()).toEqual(defaultDTO);
expect(mA1.serialize()).toEqual(defaultDTO);
expect(mA2.serialize()).toEqual(defaultDTO); // Differs from Static
expect(mA3.serialize()).toEqual(dtoU);
expect(mA4.serialize()).toEqual(dto6U);
expect(mA5.serialize()).toEqual(dtoGU);
expect(mA6.serialize()).toEqual(dto6GU);
expect(mA0.dto).toEqual(defaultDTO);
expect(mA1.dto).toEqual(defaultDTO);
expect(mA2.dto).toEqual(defaultDTO); // Differs from Static
expect(mA3.dto).toEqual(dtoU);
expect(mA4.dto).toEqual(dto6U);
expect(mA5.dto).toEqual(dtoGU);
expect(mA6.dto).toEqual(dto6GU);

// Clones with Sequences Stripped
const mB0 = new MarkovChain({}).clone(true); // This won't work.
Expand All @@ -478,24 +569,24 @@ describe('Markov Chain', () => {
const mB5 = new MarkovChain(dtoGU).clone(true);
const mB6 = new MarkovChain(dto6GU).clone(true);

// expect(mB0.serialize()).toEqual(stripSequences(defaultDTO));
// expect(mB1.serialize()).toEqual(stripSequences(defaultDTO));
// expect(mB2.serialize()).toEqual(stripSequences(defaultGramDTO));
expect(mB3.serialize()).toEqual(stripSequences(dtoU));
expect(mB4.serialize()).toEqual(stripSequences(dto6U));
expect(mB5.serialize()).toEqual(stripSequences(dtoGU));
expect(mB6.serialize()).toEqual(stripSequences(dto6GU));
// expect(mB0.dto).toEqual(stripSequences(defaultDTO));
// expect(mB1.dto).toEqual(stripSequences(defaultDTO));
// expect(mB2.dto).toEqual(stripSequences(defaultGramDTO));
expect(mB3.dto).toEqual(stripSequences(dtoU));
expect(mB4.dto).toEqual(stripSequences(dto6U));
expect(mB5.dto).toEqual(stripSequences(dtoGU));
expect(mB6.dto).toEqual(stripSequences(dto6GU));
});
// Clones must not share state with their source: mB is mutated after mC has been cloned
// from it, and both mA (mB's source) and mC (mB's clone) are still expected to equal dtoU.
it('create immutable clones', () => {
const mA = new MarkovChain({ sequences: sU });
const mB = mA.clone();
const mC = mB.clone();
// The mutation has to happen before any assertion below.
mB.addSequences(sC2);
expect(mA.serialize()).toEqual(dtoU);
expect(mB.serialize()).not.toEqual(dtoU);
expect(mC.serialize()).toEqual(dtoU);
expect(mB.serialize()).not.toEqual(mA.serialize());
expect(mB.serialize()).not.toEqual(mC.serialize());
// The same checks, made through the `dto` accessor.
expect(mA.dto).toEqual(dtoU);
expect(mB.dto).not.toEqual(dtoU);
expect(mC.dto).toEqual(dtoU);
expect(mB.dto).not.toEqual(mA.dto);
expect(mB.dto).not.toEqual(mC.dto);
});
it('can add an edge to an existing markov chain', () => {
const m1 = new MarkovChain({ maxOrder: 2 });
Expand Down Expand Up @@ -528,10 +619,10 @@ describe('Markov Chain', () => {
const mA2 = new MarkovChain({ sequences: [] }).addSequence(gA1, false).addSequence(gA2);
const mA = new MarkovChain({ sequences: [] }).addSequence(gA1, false).addSequence(gA2).addSequence(gA3);

expect(mA0.serialize()).toEqual(dtoA1);
expect(mA1.serialize()).toEqual(dtoA1);
expect(mA2.serialize()).toEqual(dtoA2);
expect(mA.serialize()).toEqual(dtoA3);
expect(mA0.dto).toEqual(dtoA1);
expect(mA1.dto).toEqual(dtoA1);
expect(mA2.dto).toEqual(dtoA2);
expect(mA.dto).toEqual(dtoA3);
});
it('can insert a sequence into an existing markov chain', () => {
expect(MarkovChain.addSequence(defaultGramDTO2, gU3, true)).toEqual(dtoGU3IExpected);
Expand All @@ -549,9 +640,9 @@ describe('Markov Chain', () => {
const m0 = new MarkovChain({ sequences: [] }).addSequences(sA3);
const mA = new MarkovChain({ sequences: [] }).addSequences(sA3, false);
const mB = new MarkovChain({ sequences: [] }).addSequences(sB3, false);
expect(m0.serialize()).toEqual(dtoA3);
expect(mA.serialize()).toEqual(dtoA3);
expect(mB.serialize()).toEqual(dtoB3);
expect(m0.dto).toEqual(dtoA3);
expect(mA.dto).toEqual(dtoA3);
expect(mB.dto).toEqual(dtoB3);

// Insertion
const mIB1 = new MarkovChain({ sequences: [] }).addSequences(sB3, 'start');
Expand Down Expand Up @@ -595,6 +686,72 @@ describe('Markov Chain', () => {
expect(pickMask3).toEqual('b');
}
});
it('can generate sequences a markov chain', () => {});
// Instance API: the same scenarios as the static generate test, driven through
// `mA.generate(...)` and validated against `mA.dto`.
it('can generate sequences a markov chain', () => {
// NOTE(review): presumably passing `seed` / `uses` makes the chain's random engine
// reproducible, which the exact-output assertions below rely on — confirm. If so, the
// order of the generate calls matters.
const mA = new MarkovChain({ ...dtoA3, seed: engine.seed, uses: engine.uses });

// Default
const optD0: MCGeneratorOptions = {};
const optD1: MCGeneratorOptions = { direction: 'last' };
const optD2: MCGeneratorOptions = { strict: false };
const optD3: MCGeneratorOptions = { trim: false };
const genD0 = mA.generate(optD0);
const genD1 = mA.generate(optD1); // Backward
const genD2 = mA.generate(optD2); // Unstrict
const genD3 = mA.generate(optD3); // Untrimmed

// Only membership in the source sequences (sA3) is asserted here; genD3 is left out of
// these checks and is only passed through validateGen.
expect(sA3.map(e => e.join())).toContain(genD0.join());
expect(sA3.map(e => e.join())).toContain(genD1.join());
expect(sA3.map(e => e.join())).toContain(genD2.join());
validateGen(mA.dto, genD0, optD0);
validateGen(mA.dto, genD1, optD1);
validateGen(mA.dto, genD2, optD2);
validateGen(mA.dto, genD3, optD3);

// Starting Values
const optS1: MCGeneratorOptions = { start: ['a', 'n'] };
const optS2: MCGeneratorOptions = { start: ['n', 'a'], direction: 'last' };
const optS3: MCGeneratorOptions = { start: ['a'], mask: ['l'] };
const optS4: MCGeneratorOptions = { start: ['a', 'n'], order: 2 };

const genS1 = mA.generate(optS1); // Forward
const genS2 = mA.generate(optS2); // Backward
const genS3 = mA.generate(optS3); // Masked
const genS4 = mA.generate(optS4); // Order

expect(genS1.join('')).toEqual('anna');
expect(genS2.join('')).toEqual('anna');
// The exact-output check for the masked case is disabled; genS3 is only validated structurally.
// expect(genS3.join('')).toEqual('anna');
expect(genS4.join('')).toEqual('anna');
validateGen(mA.dto, genS1, optS1);
validateGen(mA.dto, genS2, optS2);
validateGen(mA.dto, genS3, optS3);
validateGen(mA.dto, genS4, optS4);

// Non-Strict Cases
// NOTE(review): `order: 10` looks like it exceeds what the model can serve, relying on
// `strict: false` to still produce output — confirm. Only structural validation is done here.
const optN1: MCGeneratorOptions = { start: ['a', 'a', 'a', 'n'], strict: false, order: 10 };
const optN2: MCGeneratorOptions = { start: ['n', 'a', 'a', 'a'], strict: false, order: 10, direction: 'last' };
const optN3: MCGeneratorOptions = { start: ['a', 'a', 'a', 'n'], strict: false };
const optN4: MCGeneratorOptions = { start: ['n', 'a', 'a', 'a'], strict: false, direction: 'last' };

const genN1 = mA.generate(optN1); // Forward
const genN2 = mA.generate(optN2); // Last
const genN3 = mA.generate(optN3); // Forward (No Order)
const genN4 = mA.generate(optN4); // Last (No Order)

validateGen(mA.dto, genN1, optN1);
validateGen(mA.dto, genN2, optN2);
validateGen(mA.dto, genN3, optN3);
validateGen(mA.dto, genN4, optN4);

// Failure Cases
const optF0: MCGeneratorOptions = { start: ['x', 'y', 'z'] }; // Shouldn't find a Gram
const optF1: MCGeneratorOptions = { mask: ['a'] }; // All values should be masked.

const genF0 = mA.generate(optF0);
const genF1 = mA.generate(optF1);

// Expected degradation: an unmatched start is returned as given, and a fully masked
// generation yields an empty sequence.
expect(genF0).toEqual(optF0.start);
expect(genF1).toEqual([]);
});
});
});
Loading

0 comments on commit 2c24a24

Please sign in to comment.