Here's my first attempt at reinforcement learning code.
The program looks at price changes relative to 3 different moving averages, backtesting while it builds the state tables.
Input: a CSV file of historical data downloaded from Yahoo Finance quotes.
Output: prints to the console what a $10,000 investment, trading from the beginning of 2017, would end up worth.
You will need the ta-lib Java package from ta-lib.org.
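For reference, the Yahoo historical download is a plain CSV whose columns the loader below relies on: column 0 is the date and column 5 is the adjusted close. A typical file starts like the sample below (the numbers are made up for illustration):

Date,Open,High,Low,Close,Adj Close,Volume
2017-01-03,225.04,225.83,223.88,225.24,211.75,91366500
2017-01-04,225.62,226.76,225.61,226.58,213.01,78744400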
package reinforcedLearning;
import java.io.BufferedReader;
import java.io.FileReader;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.TreeMap;
import com.tictactec.ta.lib.Core;
import com.tictactec.ta.lib.MInteger;
public class ReinforcedLearningWith3MovingAverages {
    /**
     * stateTableCnt - number of times each state occurs.
     * stateTablePriceChangeTotal - sum of all price-change ratios for each state.
     * Dividing the second table by the first gives the average price change for the state,
     * and the average price change gives the action: buy (> 1.0) or sell (< 1.0).
     * stateTableTransition - counts where each state goes next.
     * (A small illustration of these tables follows the class below.)
     */
    static TreeMap<String, Double> stateTableCnt = new TreeMap<String, Double>();
    static TreeMap<String, Double> stateTablePriceChangeTotal = new TreeMap<String, Double>();
    static TreeMap<String, Double> stateTableTransition = new TreeMap<String, Double>();
    public static void main(String args[]) throws Exception {
        Core core = new Core();
        ArrayList<Double> closes = new ArrayList<>();
        ArrayList<String> dates = new ArrayList<String>();
        BufferedReader br = new BufferedReader(new FileReader("spy.csv"));
        String in = "";
        br.readLine(); // skip header
        // load data into arraylists
        while ((in = br.readLine()) != null) {
            String ins[] = in.split(",");
            closes.add(Double.parseDouble(ins[5])); // column 5 is the adjusted close
            dates.add(ins[0]); // column 0 is the date
        }
        br.close();
        // move data to an array for ta-lib.
        double[] dclose = new double[closes.size()];
        for (int ix = 0; ix < closes.size(); ix++)
            dclose[ix] = closes.get(ix);
        int slowPeriod = 21;
        int mediumPeriod = 15;
        int fastPeriod = 9;
        int daysOut = 2; // looking at 2-day price differences.
        double[] slowerMA = new double[closes.size()];
        double[] fasterMA = new double[closes.size()];
        double[] middleMA = new double[closes.size()];
        MInteger outBegIdx = new MInteger();
        MInteger outNBElement = new MInteger();
        core.sma(0, dclose.length - 1, dclose, slowPeriod, outBegIdx, outNBElement, slowerMA);
        // realign the data because ta-lib always starts its output at index 0.
        slowerMA = realign(slowerMA, outBegIdx.value);
        core.sma(0, dclose.length - 1, dclose, fastPeriod, outBegIdx, outNBElement, fasterMA);
        fasterMA = realign(fasterMA, outBegIdx.value);
        core.sma(0, dclose.length - 1, dclose, mediumPeriod, outBegIdx, outNBElement, middleMA);
        middleMA = realign(middleMA, outBegIdx.value);
        String previousState = "";
        double cashAccount = 10000; // starting cash
        double stockHoldings = 0;
        // start at slowPeriod * 2 because the earliest bars have no moving-average data yet.
        for (int ix = slowPeriod * 2; ix < (dclose.length - daysOut); ix++) {
            if (slowerMA[ix] == 0) // skip bars where the slow MA is still unavailable.
                continue;
            String resultState = buildResultState(dclose, fasterMA, middleMA, slowerMA, ix);
            // count how often this state occurs.
            double d = 0;
            if (stateTableCnt.containsKey(resultState) == false) {
                stateTableCnt.put(resultState, 0.);
            } else {
                d = stateTableCnt.get(resultState);
            }
            d++;
            stateTableCnt.replace(resultState, d);
            if (previousState.length() == 0) {
                previousState = resultState;
                continue;
            }
            // accumulate the daysOut-ahead price-change ratio for this state.
            d = 0;
            if (stateTablePriceChangeTotal.containsKey(resultState) == false) {
                stateTablePriceChangeTotal.put(resultState, 0.);
            } else {
                d = stateTablePriceChangeTotal.get(resultState);
            }
            d += (dclose[ix + daysOut] / dclose[ix]);
            stateTablePriceChangeTotal.replace(resultState, d);
            // count the transition from the previous state to this one.
            d = 0;
            String transitionState = previousState + ":" + resultState;
            if (stateTableTransition.containsKey(transitionState) == false)
                stateTableTransition.put(transitionState, 0.);
            else
                d = stateTableTransition.get(transitionState);
            d++;
            stateTableTransition.replace(transitionState, d);
            previousState = resultState;
if (dates.get(ix).compareTo("2017") > 0) {
// just work on data from 2017 and beyond.
int ixx = ix + daysOut + 2;
// go out at least 2 days so we don't use states that have
// already been recorded.
if (ixx < dates.size()) {
resultState = buildResultState(dclose, fasterMA, middleMA, slowerMA, ix);
double got = getTransitionChangeValue(resultState, 1);
// System.out.println(got);
if (got != Double.MAX_VALUE) // this state is not in table.
if (got > 1) { // buy situation
if (stockHoldings > 0) {
; // can't buy, all money spent.
} else {
// System.out.println("buy " + dates.get(ixx) +
// " @ " + dclose[ixx]);
stockHoldings = cashAccount / dclose[ixx];
cashAccount = 0;
}
} else { // sell situation.
if (stockHoldings > 0) {
cashAccount = stockHoldings * dclose[ixx];
stockHoldings = 0;
// System.out.println("sell " + dates.get(ixx) +
// " @ " + dclose[ixx] + " total " +
// cashAccount);
} else {
; // do nothing, no stocks to sell
}
}
}
}
}
        if (stockHoldings > 0) {
            // end of the run: if still holding shares, sell them at the last close.
            cashAccount = stockHoldings * dclose[dclose.length - 1];
            stockHoldings = 0;
            // System.out.println("sell at end of run " + dates.get(dclose.length
            // - 1) + " @ " + dclose[dclose.length - 1] + " total " +
            // df.format(cashAccount));
        }
        DecimalFormat df = new DecimalFormat("#.##");
        System.out.println(fastPeriod + ";" + mediumPeriod + ";" + slowPeriod + ";" + df.format(cashAccount));
    }
    private static double getTransitionChangeValue(String state, int i) {
        /* THIS METHOD IS RECURSIVE */
        if (stateTableCnt.get(state) == null || stateTablePriceChangeTotal.get(state) == null)
            // let the calling method know that this state has not been recorded.
            return Double.MAX_VALUE;
        double got;
        // start with what this state is worth.
        got = (stateTablePriceChangeTotal.get(state) / stateTableCnt.get(state));
        if (i > 0) {
            double sum = 0;
            int cnt = 0;
            for (String key : stateTableTransition.keySet()) {
                // now average the values of all the subsequent states.
                String split[] = key.split(">");
                if (split[0].compareTo(state) == 0) {
                    sum += getTransitionChangeValue(split[1], i - 1);
                    cnt++;
                }
            }
            if (cnt > 0)
                got += (sum / cnt);
        }
        return got;
    }
    static String buildResultState(double[] dclose, double[] fasterMA, double[] middleMA, double[] slowerMA, int ix) {
        String resultState = "";
        resultState += (dclose[ix] > fasterMA[ix] ? ">Fast" : "<Fast");
        resultState += (dclose[ix] > middleMA[ix] ? ">Middle" : "<Middle");
        resultState += (dclose[ix] > slowerMA[ix] ? ">Slow" : "<Slow");
        resultState += (fasterMA[ix] > slowerMA[ix] ? "_Fast>Slow" : "_Fast<Slow");
        resultState += (fasterMA[ix] > middleMA[ix] ? "_Fast>Middle" : "_Fast<Middle");
        resultState += (slowerMA[ix] > middleMA[ix] ? "_Slow>Middle" : "_Slow<Middle");
        return resultState;
    }
    static double[] realign(double in[], int offset) throws Exception {
        // shift the ta-lib output forward by offset so it lines up with the price array.
        if (offset <= 0)
            throw new Exception("offset must be greater than 0");
        for (int ix = in.length - 1; ix > offset - 1; ix--)
            in[ix] = in[ix - offset];
        // zero out the leading slots that have no moving-average value.
        for (int ix = offset - 1; ix >= 0; ix--)
            in[ix] = 0;
        return in;
    }
}
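To make the state tables concrete, here is a minimal standalone sketch (not part of the program above; the state string and price ratios are invented) showing how one entry in stateTableCnt and stateTablePriceChangeTotal turns into a buy or sell action:

import java.util.TreeMap;

public class StateTableExample {
    public static void main(String[] args) {
        TreeMap<String, Double> cnt = new TreeMap<>();
        TreeMap<String, Double> priceChangeTotal = new TreeMap<>();

        // pretend this state was seen 4 times with these 2-day price ratios.
        String state = ">Fast>Middle>Slow_Fast>Slow_Fast>Middle_Slow<Middle";
        cnt.put(state, 4.0);
        priceChangeTotal.put(state, 1.01 + 1.02 + 0.99 + 1.03); // sum of close[ix+2]/close[ix]

        double avgChange = priceChangeTotal.get(state) / cnt.get(state); // about 1.0125
        String action = avgChange > 1.0 ? "buy" : "sell";
        System.out.println(state + " -> average change " + avgChange + " -> " + action);
    }
}

The actual program does the same division inside getTransitionChangeValue and then compares the result against 1.0 to decide whether to buy or sell.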
Updates to the RL code: see https://artificialinvestor.blogspot.com/2017/11/update-to-rl-code.html
Update to RL Code
Changed the method that gets the value:
1) the state string splitter should be ':' (the transition keys are joined with ':').
2) substates are weighted by the probability of transitioning into them.
3) the final value is the average of the state's own value and the substate value.
(A small worked example follows the method below.)
    static double getTransitionChangeValue(String state, int i) {
        /* THIS METHOD IS RECURSIVE */
        if (stateTableCnt.get(state) == null || stateTablePriceChangeTotal.get(state) == null)
            // let the calling method know that this state has not been recorded.
            return Double.MAX_VALUE;
        double got;
        // start with what this state is worth.
        got = (stateTablePriceChangeTotal.get(state) / stateTableCnt.get(state));
        if (i > 0) {
            double sum = 0;
            int cnt = 0;
            for (String key : stateTableTransition.keySet()) {
                // weight each subsequent state by the probability of transitioning into it.
                String split[] = key.split(":");
                if (split[0].compareTo(state) == 0) {
                    sum += getTransitionChangeValue(split[1], i - 1)
                            * (stateTableTransition.get(key) / stateTableCnt.get(state));
                    cnt++;
                }
            }
            if (cnt > 0) {
                // average the state's own value with the weighted substate value.
                got += sum;
                got /= 2;
            }
        }
        return got;
    }
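As a rough worked example of the new weighting (all numbers invented): suppose state A has been seen 10 times with an average change of 1.02, and the transition table records 6 moves from A to state B (own value 1.01) and 4 moves from A to state C (own value 0.98). The weighted substate sum is 1.01 * (6/10) + 0.98 * (4/10) = 0.998, so the method returns (1.02 + 0.998) / 2 = 1.009, which is still a buy signal.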
Complete updated code at
http://artificialinvestor.blogspot.com/2017/11/source-code-reinforcedlearningwith3movi.html