Here's my second attempt at writing reinforcement machine learning code.
The program looks at price changes in relation to 3 different moving averages. It back-tests while building the state tables.
Input file format is a csv file built from Yahoo Quotes historical data.
Output: prints to the console what a 10,000 investment (starting at the beginning of 2017) would be worth.
You will need the ta-lib package from ta-lib.org.
changes: - removed the stateTableTransitionPriceChangeTotal table
- corrected the iteration logic when getting the value for a state in the getTransitionChangeValue method.
- look ahead period changed to 1 (daysOut variable).
- initialize the static tables inside main method, makes it easier to implement for loops when testing to find best periods.
package reinforcedLearning;
import java.io.BufferedReader;
import java.io.FileReader;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.TreeMap;
import com.tictactec.ta.lib.Core;
import com.tictactec.ta.lib.MInteger;
/**
 * Back-tests a table-driven trading rule built from three simple moving averages.
 * <p>
 * Every bar is reduced to a state string that records how the price and the fast, medium and
 * slow moving averages are ordered (see {@link #buildResultState}). While walking the price
 * history the program accumulates, per state, how often the state occurred and the sum of the
 * price ratios observed {@code daysOut} bars later; the average ratio of a state decides whether
 * the back-test buys (average &gt; 1) or sells (otherwise).
 *
 * @author joe mcverry
 * http://mcverryreport.com
 * http://americancoders.com
 * http://artificialinvestor.blogspot.com/
 * usacoder@gmail...
 * @version 1.1 http://artificialinvestor.blogspot.com/
 *
 */
public class ReinforcedLearningWith3MovingAverages {

    /** Number of times each state has been observed (state -&gt; count). */
    static TreeMap<String, Double> stateTableCnt;

    /**
     * Sum of the look-ahead price ratios recorded for each state. Divided by the matching entry
     * of {@link #stateTableCnt} it yields the average price ratio for the state.
     */
    static TreeMap<String, Double> stateTablePriceChangeTotal;

    /** Number of times each {@code "previousState:currentState"} pair has been observed. */
    static TreeMap<String, Double> stateTableTransitionCnt;

    /**
     * Loads the price history, builds the state tables and prints the back-test result as
     * {@code fastPeriod;mediumPeriod;slowPeriod;finalCash}.
     *
     * @param args optional; {@code args[0]} is the path of the CSV file to read. When absent the
     *             file {@code spy.csv} in the working directory is used.
     * @throws Exception if the file cannot be read, a price cannot be parsed, or a moving
     *                   average cannot be realigned
     */
    public static void main(String args[]) throws Exception {
        String csvPath = (args != null && args.length > 0) ? args[0] : "spy.csv";

        ArrayList<Double> closes = new ArrayList<>();
        ArrayList<String> dates = new ArrayList<>();
        // try-with-resources: the reader is closed even when a row fails to parse.
        try (BufferedReader reader = new BufferedReader(new FileReader(csvPath))) {
            reader.readLine(); // skip header
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split(",");
                // Column 0 is the date, column 5 the price used for the test (presumably the
                // adjusted close of a Yahoo export - confirm against the input file).
                // Blank lines, short rows and rows whose price is the literal "null" are skipped.
                if (fields.length < 6 || "null".equals(fields[5].trim())) {
                    continue;
                }
                closes.add(Double.parseDouble(fields[5]));
                dates.add(fields[0]);
            }
        }

        double[] dclose = new double[closes.size()];
        for (int i = 0; i < dclose.length; i++) {
            dclose[i] = closes.get(i);
        }

        int slowPeriod = 21;
        int mediumPeriod = 15;
        int fastPeriod = 2;
        int daysOut = 1; // look-ahead distance, in bars, for the recorded price ratio

        Core core = new Core();
        double[] slowerMA = new double[dclose.length];
        double[] fasterMA = new double[dclose.length];
        double[] middleMA = new double[dclose.length];
        MInteger outBegIdx = new MInteger();
        MInteger outNBElement = new MInteger();

        // ta-lib writes its output starting at index 0, so each average is shifted by the
        // reported begin index to line it up with dclose.
        core.sma(0, dclose.length - 1, dclose, slowPeriod, outBegIdx, outNBElement, slowerMA);
        slowerMA = realign(slowerMA, outBegIdx.value);
        core.sma(0, dclose.length - 1, dclose, fastPeriod, outBegIdx, outNBElement, fasterMA);
        fasterMA = realign(fasterMA, outBegIdx.value);
        core.sma(0, dclose.length - 1, dclose, mediumPeriod, outBegIdx, outNBElement, middleMA);
        middleMA = realign(middleMA, outBegIdx.value);

        // The tables are (re)created here so that a caller can loop over different periods.
        stateTableCnt = new TreeMap<String, Double>();
        stateTablePriceChangeTotal = new TreeMap<String, Double>();
        stateTableTransitionCnt = new TreeMap<String, Double>();

        String previousState = "";
        double cashAccount = 10000; // starting cash
        double stockHoldings = 0;

        // Start at twice the slow period so every average is available.
        for (int ix = slowPeriod * 2; ix < (dclose.length - daysOut); ix++) {
            if (slowerMA[ix] == 0) {
                continue; // no slow average for this bar
            }

            String resultState = buildResultState(dclose, fasterMA, middleMA, slowerMA, ix);

            // Record the observation: one more occurrence plus the price ratio daysOut bars on.
            stateTableCnt.merge(resultState, 1.0, Double::sum);
            stateTablePriceChangeTotal.merge(resultState, dclose[ix + daysOut] / dclose[ix], Double::sum);

            if (previousState.isEmpty()) {
                // First recorded state: there is no transition yet and no trading on this bar.
                previousState = resultState;
                continue;
            }
            stateTableTransitionCnt.merge(previousState + ":" + resultState, 1.0, Double::sum);
            previousState = resultState;

            /* this is where back testing takes place */
            // Plain string comparison; assumes the date column starts with the year.
            if (dates.get(ix).compareTo("2017") <= 0) {
                continue;
            }
            // Trade two bars past the look-ahead bar so the fill price is not one that has
            // already been recorded in the tables.
            int ixx = ix + daysOut + 2;
            if (ixx >= dates.size()) {
                continue;
            }
            double got = getTransitionChangeValue(resultState, 0);
            if (got == Double.MAX_VALUE) {
                continue; // state is not in the table
            }
            if (got > 1) {
                // Buy situation; only possible while nothing is held (all money is spent otherwise).
                if (!(stockHoldings > 0)) {
                    stockHoldings = cashAccount / dclose[ixx];
                    cashAccount = 0;
                }
            } else if (stockHoldings > 0) {
                // Sell situation; nothing to do when no stock is held.
                cashAccount = stockHoldings * dclose[ixx];
                stockHoldings = 0;
            }
        }

        if (stockHoldings > 0) {
            // End of the run: liquidate at the last price.
            cashAccount = stockHoldings * dclose[dclose.length - 1];
            stockHoldings = 0;
        }
        DecimalFormat df = new DecimalFormat("#.##");
        System.out.println(fastPeriod + ";" + mediumPeriod + ";" + slowPeriod + ";" + df.format(cashAccount));
    }

    /**
     * Computes what a state is worth: its average recorded price ratio, optionally blended with
     * the count-weighted average worth of the states that have followed it.
     *
     * @param state          state string as produced by {@link #buildResultState}
     * @param recursionLevel how many transition steps to look ahead; {@code 0} or less returns
     *                       the state's own average only
     * @return the value of the state, or {@link Double#MAX_VALUE} when the state has never been
     *         recorded
     */
    static double getTransitionChangeValue(String state, int recursionLevel) {
        Double observations = stateTableCnt.get(state);
        Double ratioTotal = stateTablePriceChangeTotal.get(state);
        if (observations == null || ratioTotal == null) {
            return Double.MAX_VALUE;
        }
        double ret = ratioTotal / observations;
        // "<= 0" rather than "== 0": a negative level must not recurse without bound.
        if (recursionLevel <= 0) {
            return ret;
        }

        double weightedSum = 0;
        double totalWeight = 0;
        for (java.util.Map.Entry<String, Double> transition : stateTableTransitionCnt.entrySet()) {
            String[] subKeys = transition.getKey().split(":");
            if (subKeys.length < 2 || !subKeys[0].equals(state)) {
                continue;
            }
            // What the next state is expected to yield.
            double successorValue = getTransitionChangeValue(subKeys[1], recursionLevel - 1);
            if (successorValue == Double.MAX_VALUE) {
                continue; // unknown successor: the sentinel must not enter the average
            }
            double weight = transition.getValue();
            weightedSum += successorValue * weight;
            totalWeight += weight;
        }
        if (totalWeight == 0) {
            return ret; // no usable successors; avoids 0/0 = NaN
        }
        return (ret + (weightedSum / totalWeight)) / 2; // average of the 2 outcomes. the reward.
    }

    /**
     * Builds the state string for one bar from six pairwise comparisons. Each comparison uses a
     * strict {@code >}; equal values therefore produce the {@code "<"} label.
     *
     * @param dclose   prices
     * @param fasterMA fast moving average, aligned with {@code dclose}
     * @param middleMA medium moving average, aligned with {@code dclose}
     * @param slowerMA slow moving average, aligned with {@code dclose}
     * @param ix       index of the bar to describe
     * @return the concatenated state label, e.g.
     *         {@code ">Fast>Middle>Slow_Fast>Slow_Fast>Middle_Slow<Middle"}
     */
    static String buildResultState(double[] dclose, double[] fasterMA, double[] middleMA, double[] slowerMA, int ix) {
        double price = dclose[ix];
        double fast = fasterMA[ix];
        double middle = middleMA[ix];
        double slow = slowerMA[ix];

        StringBuilder state = new StringBuilder();
        state.append(price > fast ? ">Fast" : "<Fast");
        state.append(price > middle ? ">Middle" : "<Middle");
        state.append(price > slow ? ">Slow" : "<Slow");
        state.append(fast > slow ? "_Fast>Slow" : "_Fast<Slow");
        state.append(fast > middle ? "_Fast>Middle" : "_Fast<Middle");
        state.append(slow > middle ? "_Slow>Middle" : "_Slow<Middle");
        return state.toString();
    }

    /**
     * Shifts the contents of {@code in} towards the end by {@code offset} positions, in place,
     * and zero-fills the vacated leading positions {@code 0 .. offset-1}. The value originally at
     * index 0 ends up at index {@code offset}. When {@code offset} is at least the array length
     * the whole array is zeroed.
     *
     * @param in     array to shift; modified in place
     * @param offset number of positions to shift; must be greater than 0
     * @return the same array instance, {@code in}
     * @throws Exception (an {@link IllegalArgumentException}) if {@code offset} is not positive
     */
    static double[] realign(double in[], int offset) throws Exception {
        if (offset <= 0) {
            throw new IllegalArgumentException("offset must be greater than 0");
        }
        int lead = Math.min(offset, in.length);
        // arraycopy copes with the overlapping source and destination ranges.
        System.arraycopy(in, 0, in, lead, in.length - lead);
        // Clear only the vacated slots; in[offset] now holds the first real value.
        for (int i = 0; i < lead; i++) {
            in[i] = 0;
        }
        return in;
    }
}
No comments:
Post a Comment