Tuesday, November 14, 2017

Source Code: ReinforcedLearningWith3MovingAverages Version 1.1

Here's my second attempt at writing reinforced machine learning code.

The program looks at price changes in relation to 3 different moving averages.  The program backtests while it builds the state tables.

Input file format is a csv file built from Yahoo Quotes historical data.

Output: prints to the console what a 10,000 investment (starting at the beginning of 2017) would be worth.

You will need the ta-lib package from ta-lib.org.

changes:
  1. removed the stateTableTransitionPriceChangeTotal table.
  2. corrected the iteration logic used when getting the value for a state in the getTransitionChangeValue method.
  3. look ahead period changed to 1 (daysOut variable).
  4. initialize the static tables inside main method, makes it easier to implement for loops when testing to find best periods.


package reinforcedLearning;

import java.io.BufferedReader;
import java.io.FileReader;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.TreeMap;

import com.tictactec.ta.lib.Core;
import com.tictactec.ta.lib.MInteger;

/**
 * @author joe mcverry
 *         http://mcverryreport.com
 *         http://americancoders.com
 *         http://artificialinvestor.blogspot.com/
 *           usacoder@gmail...
 * @version 1.1  http://artificialinvestor.blogspot.com/
 *
 */

public class ReinforcedLearningWith3MovingAverages {

    /**
     * stateTableCnt - #  of times this state occurs
     * StateTablePriceChangeTotal - sum of all price changes for this state
     * using the combination of the 2 tables we get the average price change for the state.
     * using the average price we get the action: buy (> 1.0) or sell (< 1.0)
     * stateTableTransition - allows us to see where this state may go next.
      */

   
    static TreeMap<String, Double> stateTableCnt;
   
    static TreeMap<String, Double> stateTablePriceChangeTotal;

    static TreeMap<String, Double> stateTableTransitionCnt;

   

    public static void main(String args[]) throws Exception {

        Core core = new Core();
        ArrayList<Double> closes = new ArrayList<>();

        ArrayList<String> dates = new ArrayList<String>();
        BufferedReader br = new BufferedReader(new FileReader("spy.csv"));
        String in = "";
        br.readLine(); // skip header

        while ((in = br.readLine()) != null) {
            String ins[] = in.split(",");
            closes.add(Double.parseDouble(ins[5]));
            dates.add(ins[0]);
        }
        br.close();

        double[] dclose = new double[closes.size()];
        for (int ix = 0; ix < closes.size(); ix++)
            dclose[ix] = closes.get(ix);

        int slowPeriod = 21;
        int mediumPeriod = 15;
        int fastPeriod = 2;

         int daysOut = 1;  // looking at a 1 day price differences.
       
        double[] slowerMA = new double[closes.size()];
        double[] fasterMA = new double[closes.size()];
        double[] middleMA = new double[closes.size()];
       
        MInteger outBegIdx = new MInteger();
        MInteger outNBElement = new MInteger();

        core.sma(0, dclose.length - 1, dclose, slowPeriod, outBegIdx, outNBElement, slowerMA);
        // realign the data because ta-lib always starts its output at 0.
        slowerMA = realign(slowerMA, outBegIdx.value);

        core.sma(0, dclose.length - 1, dclose, fastPeriod, outBegIdx, outNBElement, fasterMA);
        fasterMA = realign(fasterMA, outBegIdx.value);

        core.sma(0, dclose.length - 1, dclose, mediumPeriod, outBegIdx, outNBElement, middleMA);
        middleMA = realign(middleMA, outBegIdx.value);

        String previousState = "";

        stateTableCnt = new TreeMap<String, Double>();
        stateTablePriceChangeTotal = new TreeMap<String, Double>();
        stateTableTransitionCnt = new TreeMap<String, Double>();
       
        double cashAccount = 10000; // starting cash
        double stockHoldings = 0;
        for (int ix = slowPeriod * 2; ix < (dclose.length - daysOut); ix++) {
            // skip ahead by a factor of 2 because some data may not available yet.

            if (slowerMA[ix] == 0)  // such as here.
                continue;
           
            String resultState = buildResultState(dclose, fasterMA, middleMA, slowerMA, ix);

            Double d;
            if (stateTableCnt.containsKey(resultState) == false) {
                stateTableCnt.put(resultState, 0.);
                stateTablePriceChangeTotal.put(resultState, 0.);
            }
           
            d = stateTableCnt.get(resultState);
            d++;
            stateTableCnt.replace(resultState, d);

            d = stateTablePriceChangeTotal.get(resultState);
            d += (dclose[ix + daysOut] / dclose[ix]);  // price change in x days
            stateTablePriceChangeTotal.replace(resultState, d);

            if (previousState.length() == 0) {
                previousState = resultState;
                continue;
            }

            String transitonState = previousState + ":" + resultState;
            if (stateTableTransitionCnt.containsKey(transitonState) == false) {
                stateTableTransitionCnt.put(transitonState, 0.);
            }
           
            d = stateTableTransitionCnt.get(transitonState);
            d++;
            stateTableTransitionCnt.replace(transitonState, d);

       

            previousState = resultState;

            /* this is where back testing takes place */
            if (dates.get(ix).compareTo("2017") > 0) { 
                // just work on data from 2017 and beyond.

                int ixx = ix + daysOut + 2;
                // go out at least 2 days so we don't test states that have
                // already been recorded.  AVOID THE LOOK-AHEAD PROBLEM.
                if (ixx < dates.size()) { 
                    // may get an out of bounds array exception; so test for it.
                    resultState = buildResultState(dclose, fasterMA, middleMA, slowerMA, ix);

                    double got = getTransitionChangeValue(resultState, 0);
                    // System.out.println(got);
                    if (got != Double.MAX_VALUE) // this state is not in table.
                        if (got > 1) { // buy situation
                            if (stockHoldings > 0) {
                                ; // can't buy, all money spent.
                            } else {
//                                System.out.println("buy " + dates.get(ixx) + " @ " + dclose[ixx]);
                                stockHoldings = cashAccount / dclose[ixx];
                                cashAccount = 0;
                            }
                        } else { // sell situation.
                            if (stockHoldings > 0) {
                                cashAccount = stockHoldings * dclose[ixx];
                                stockHoldings = 0;
//                                System.out.println("sell " + dates.get(ixx) + " @ " + dclose[ixx] + " total " + cashAccount);

                            } else {
                                ; // do nothing, no stocks to sell
                            }
                        }
                }
            }

        }
        DecimalFormat df = new DecimalFormat("#.##");

        if (stockHoldings > 0) {
            // this is the end of the run, if holding shares dump shares
            cashAccount = stockHoldings * dclose[dclose.length - 1];
            stockHoldings = 0;
            //System.out.println("sell at end of run" + dates.get(dclose.length - 1) + " @ " + dclose[dclose.length - 1] + " total " + df.format(cashAccount));

        }
       
        System.out.println(fastPeriod + ";" + mediumPeriod + ";" + slowPeriod + ";" + df.format(cashAccount));
   

    }

    /**
     * @param state
     * @param recursionLevel
     * @return
     */
    static double getTransitionChangeValue(String state, int recursionLevel) {

        // computes what this state is worth
        // if iterating include sum of results for succeding transition states.
       
        if (stateTableCnt.get(state) == null | stateTablePriceChangeTotal.get(state) == null)
            return Double.MAX_VALUE;

            double ret = stateTablePriceChangeTotal.get(state) /  stateTableCnt.get(state);

            if (recursionLevel == 0)
                return ret;
           
            double sum = 0;
            double cnt = 0;
           
            for (String key : stateTableTransitionCnt.keySet()) {

                String subKeys[] = key.split(":");

                if (subKeys[0].compareTo(state) == 0) {
                    double presum = getTransitionChangeValue(subKeys[1], recursionLevel-1)   ;  // what will the next event get us.
                    presum *= stateTableTransitionCnt.get(key); //its weight
                    sum += presum;
                    cnt += stateTableTransitionCnt.get(key); // weight for averaging.
                }
            }

        return (ret+(sum/cnt))/2; // average of the 2 outcomes. the reward.
    }


   
    /**
     * @param dclose
     * @param fasterMA
     * @param middleMA
     * @param slowerMA
     * @param ix
     * @return
     */
    static String buildResultState(double[] dclose, double[] fasterMA, double[] middleMA, double[] slowerMA, int ix) {
        String resultState = "";
        resultState += (dclose[ix] > fasterMA[ix] ? ">Fast" : "<Fast");
        resultState += (dclose[ix] > middleMA[ix] ? ">Middle" : "<Middle");
        resultState += (dclose[ix] > slowerMA[ix] ? ">Slow" : "<Slow");
        resultState += (fasterMA[ix] > slowerMA[ix] ? "_Fast>Slow" : "_Fast<Slow");
        resultState += (fasterMA[ix] > middleMA[ix] ? "_Fast>Middle" : "_Fast<Middle");
        resultState += (slowerMA[ix] > middleMA[ix] ? "_Slow>Middle" : "_Slow<Middle");
        return resultState;
    }

    /**
     * @param in
     * @param offset
     * @return
     * @throws Exception
     */
    static double[] realign(double in[], int offset) throws Exception {

        if (offset <= 0)
            throw new Exception("offset must be greater than 0");
        for (int ix = in.length - 1; ix > offset - 1; ix--)
            in[ix] = in[ix - offset];

        for (int ix = offset; ix >= 0; ix--)
            in[ix] = 0;

        return in;

    }

}

No comments: