Created
October 29, 2014 15:19
-
-
Save ClickerMonkey/289af1bab22974bb9c42 to your computer and use it in GitHub Desktop.
Markov chain data structure.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Random;
public class Markov<T> | |
{ | |
protected final MarkovChain<T> root; | |
protected final MarkovChain<T> starters; | |
public Markov() | |
{ | |
root = new MarkovChain<T>( null, null ); | |
starters = new MarkovChain<T>( null, null ); | |
} | |
public void build( int depth, List<T> states ) | |
{ | |
final int stateCount = states.size(); | |
starters.addNext( states.get( 0 ) ); | |
for (int i = 0; i < stateCount; i++) | |
{ | |
final int chainDepth = Math.min( depth, stateCount - i ); | |
MarkovChain<T> chain = root; | |
for (int k = 0; k < chainDepth; k++) | |
{ | |
chain = chain.addNext( states.get( i + k ) ); | |
if (chainDepth != depth) | |
{ | |
chain.addEnd(); | |
} | |
} | |
} | |
} | |
public void build( int depth, T... states ) | |
{ | |
final int stateCount = states.length; | |
starters.addNext( states[0] ); | |
for (int i = 0; i < stateCount; i++) | |
{ | |
final int chainDepth = Math.min( depth, stateCount - i ); | |
MarkovChain<T> chain = root; | |
for (int k = 0; k < chainDepth; k++) | |
{ | |
chain = chain.addNext( states[i + k] ); | |
if (chainDepth != depth) | |
{ | |
chain.addEnd(); | |
} | |
} | |
} | |
} | |
public double probability( T ... states ) | |
{ | |
root.getNext( states[0] ); | |
return 0.0; | |
} | |
public int generateRandom( Random random, T initialState, int min, T[] out ) | |
{ | |
int chainRoot = 0; | |
int chainLength = 0; | |
MarkovChain<T> chain = root.getNext( initialState ); | |
while (chainLength < out.length) | |
{ | |
out[chainLength] = chain.getState(); | |
while (chain == null || !chain.hasNext()) | |
{ | |
T rootState = null; | |
if (chainRoot == chainLength + 1) | |
{ | |
rootState = getRandom( random ); | |
} | |
else | |
{ | |
rootState = out[++chainRoot]; | |
} | |
chain = root.getNext( rootState ); | |
for (int k = chainRoot + 1; k <= chainLength && chain != null; k++) | |
{ | |
chain = chain.getNext( out[k] ); | |
} | |
} | |
chainLength++; | |
if (chainLength > min && chain.isEnd() && chain.isEnd( random )) | |
{ | |
break; | |
} | |
chain = chain.getRandom( random ); | |
} | |
return chainLength; | |
} | |
public T[] generate( Random random, T[] out ) | |
{ | |
return generate( random, getRandom( random ), out ); | |
} | |
public T[] generate( Random random, T initialState, T[] out ) | |
{ | |
int chainRoot = 0; | |
MarkovChain<T> chain = root.getNext( initialState ); | |
for (int i = 0; i < out.length; i++) | |
{ | |
out[i] = chain.getState(); | |
while (chain == null || !chain.hasNext()) | |
{ | |
T rootState = null; | |
if (chainRoot == i + 1) | |
{ | |
rootState = getRandom( random ); | |
} | |
else | |
{ | |
rootState = out[++chainRoot]; | |
} | |
chain = root.getNext( rootState ); | |
for (int k = chainRoot + 1; k <= i && chain != null; k++) | |
{ | |
chain = chain.getNext( out[k] ); | |
} | |
} | |
chain = chain.getRandom( random ); | |
} | |
return out; | |
} | |
public <C extends Collection<T>> C generateRandomSize( Random random, int min, int max, C destination ) | |
{ | |
MarkovChain<T> start = starters.getRandom( random ); | |
T state = start.getState(); | |
destination.add( state ); | |
state = getRandom( start, random ); | |
if (start.hasNext()) | |
{ | |
for (int i = 0; i < max; i++) | |
{ | |
destination.add( state ); | |
MarkovChain<T> next = start.getNext( state ); | |
if (next == null || !next.hasNext()) | |
{ | |
start = root; | |
} | |
else | |
{ | |
start = next; | |
} | |
if (i >= min && next.isEnd() && next.isEnd( random )) | |
{ | |
break; | |
} | |
state = getRandom( start, random ); | |
} | |
} | |
return destination; | |
} | |
public T getRandom( Random random ) | |
{ | |
return getRandom( root, random ); | |
} | |
public T getRandom( Random random, T[] previousStates, int previousStateCount ) | |
{ | |
MarkovChain<T> r = root; | |
for (int i = 0; i < previousStateCount; i++) | |
{ | |
MarkovChain<T> n = r.getNext( previousStates[i] ); | |
if (n == null || !n.hasNext()) | |
{ | |
break; | |
} | |
r = n; | |
} | |
return getRandom( r, random ); | |
} | |
protected T getRandom( MarkovChain<T> chain, Random random ) | |
{ | |
MarkovChain<T> r = chain.getRandom( random ); | |
return (r != null ? r.getState() : null); | |
} | |
protected T getRandomStart( Random random ) | |
{ | |
return getRandom( starters, random ); | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.HashMap; | |
import java.util.Map; | |
import java.util.Random; | |
/**
 * One node of a Markov trie. A node holds a state, how often it was reached from its
 * parent, the states observed to follow it (with counts), and an end counter.
 * <p>
 * No synchronization is performed.
 *
 * @param <T> the state type; used as a {@link HashMap} key, so it should implement
 *            {@code equals}/{@code hashCode} consistently
 */
public class MarkovChain<T>
{
    /** The state this node represents; may be {@code null} (e.g. for a root node). */
    protected final T state;

    /** The node this one was created under by {@link #addNext}, or {@code null}. */
    protected final MarkovChain<T> parent;

    /** How many times this node was reached through its parent's {@link #addNext}. */
    protected int occurrences;

    /** Successor nodes keyed by state; {@code null} until the first {@link #addNext}. */
    protected Map<T, MarkovChain<T>> next;

    /** Number of {@link #addNext} calls; the total weight used by {@link #getRandom}. */
    protected int nextTotalOccurrences;

    /** Number of {@link #addEnd} calls. */
    protected int ends;

    /**
     * @param state  the state this node represents
     * @param parent the node this one hangs under, or {@code null}
     */
    public MarkovChain( T state, MarkovChain<T> parent )
    {
        this.state = state;
        this.parent = parent;
    }

    /** @return the state this node represents */
    public T getState()
    {
        return state;
    }

    /** @return how many times this node was reached from its parent */
    public int getOccurrences()
    {
        return occurrences;
    }

    /** Increments this node's occurrence count. */
    public void addOccurrence()
    {
        occurrences++;
    }

    /**
     * Records one observation of {@code nextValue} following this node.
     *
     * @param nextValue the successor state
     * @return the (possibly newly created) successor node
     */
    public MarkovChain<T> addNext( T nextValue )
    {
        if (next == null)
        {
            next = new HashMap<T, MarkovChain<T>>();
        }
        MarkovChain<T> chain = next.get( nextValue );
        if (chain == null)
        {
            chain = new MarkovChain<T>( nextValue, this );
            next.put( nextValue, chain );
        }
        chain.addOccurrence();
        nextTotalOccurrences++;
        return chain;
    }

    /**
     * Draws a successor node with probability proportional to its occurrence count.
     *
     * @param random source of randomness
     * @return the drawn successor, or {@code null} if there are none
     */
    public MarkovChain<T> getRandom( Random random )
    {
        if (next == null || nextTotalOccurrences <= 0)
        {
            return null;
        }
        // remaining is uniform in [0, total); successor n owns exactly n.occurrences values.
        int remaining = random.nextInt( nextTotalOccurrences );
        for (MarkovChain<T> n : next.values())
        {
            remaining -= n.occurrences;
            if (remaining < 0)
            {
                return n;
            }
        }
        return null;
    }

    /**
     * @param value a candidate successor state
     * @return the fraction of recorded successors equal to {@code value}; 0.0 if none recorded
     */
    public double getProbability( T value )
    {
        if (next == null || nextTotalOccurrences <= 0)
        {
            return 0.0;
        }
        MarkovChain<T> n = next.get( value );
        return (n == null ? 0.0 : (double)n.occurrences / (double)nextTotalOccurrences );
    }

    /** Records one sequence end at this node. */
    public void addEnd()
    {
        ends++;
    }

    /**
     * Samples whether to end here, with probability {@code ends / (ends + successors)}.
     *
     * @param random source of randomness
     * @return {@code true} if an end was sampled; always {@code false} when no end was recorded
     */
    public boolean isEnd( Random random )
    {
        if (ends <= 0)
        {
            // Also avoids nextInt(0), which throws, on a node with no ends and no successors.
            return false;
        }
        return random.nextInt( ends + nextTotalOccurrences ) < ends;
    }

    /** @return {@code true} if at least one end was recorded at this node */
    public boolean isEnd()
    {
        return (ends > 0);
    }

    /**
     * @param value a successor state
     * @return the successor node for {@code value}, or {@code null} if it was never recorded
     */
    public MarkovChain<T> getNext( T value )
    {
        return (next == null ? null : next.get( value ));
    }

    /** @return {@code true} if at least one successor was recorded */
    public boolean hasNext()
    {
        return (next != null && !next.isEmpty());
    }

    /**
     * @param value a successor state
     * @return {@code true} if {@code value} was recorded as a successor
     */
    public boolean hasNext( T value )
    {
        return (next != null && next.containsKey( value ));
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment