Monday, November 13, 2017

Program: Reinforced Machine Learning Algorithm To Select Optimal Buy/Sell Action.

Here's my first attempt to do reinforced machine learning code.

The program looks at price changes in relation to 3 different moving averages.  It backtests while building the state tables.

Input file format is a csv file built from Yahoo Quotes historical data.

Output: prints to the console what a 10,000 investment (starting at the beginning of 2017) would be worth at the end of the data.

You will need ta-lib package from ta-lib.org.

 

package reinforcedLearning;

import java.io.BufferedReader;
import java.io.FileReader;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.TreeMap;

import com.tictactec.ta.lib.Core;
import com.tictactec.ta.lib.MInteger;

public class ReinforcedLearningWith3MovingAverages {

    /**
     * stateTableCnt - #  of times this state occurs
     * StateTablePriceTotal - sum of all prices for this state
     * using the combination of the 2 tables we get the average price change for the state.
     * using the average price we get the action: buy (> 1.0) or sell (< 1.0)
     * stateTableTransition - allows us to see where this state may go next.
     */

   
    static TreeMap<String, Double> stateTableCnt = new TreeMap<String, Double>();
   
    static TreeMap<String, Double> stateTablePriceChangeTotal = new TreeMap<String, Double>();

    static TreeMap<String, Double> stateTableTransition = new TreeMap<String, Double>();

    public static void main(String args[]) throws Exception {

        Core core = new Core();
        ArrayList<Double> closes = new ArrayList<>();

        ArrayList<String> dates = new ArrayList<String>();
        BufferedReader br = new BufferedReader(new FileReader("spy.csv"));
        String in = "";
        br.readLine(); // skip header
        // load data into arraylist
        while ((in = br.readLine()) != null) {
            String ins[] = in.split(",");
            closes.add(Double.parseDouble(ins[5]));
            dates.add(ins[0]);
        }
        br.close();

        // move data to array for ta-lib.
        double[] dclose = new double[closes.size()];
        for (int ix = 0; ix < closes.size(); ix++)
            dclose[ix] = closes.get(ix);

        int slowPeriod = 21;
        int mediumPeriod = 15;
        int fastPeriod = 9;

        int daysOut = 2;  // looking at a 2 day price differences.
       
        double[] slowerMA = new double[closes.size()];
        double[] fasterMA = new double[closes.size()];
        double[] middleMA = new double[closes.size()];
       
        MInteger outBegIdx = new MInteger();
        MInteger outNBElement = new MInteger();

        core.sma(0, dclose.length - 1, dclose, slowPeriod, outBegIdx, outNBElement, slowerMA);
        // realign the data because ta-lib always starts its output at 0.
        slowerMA = realign(slowerMA, outBegIdx.value);

        core.sma(0, dclose.length - 1, dclose, fastPeriod, outBegIdx, outNBElement, fasterMA);
        fasterMA = realign(fasterMA, outBegIdx.value);

        core.sma(0, dclose.length - 1, dclose, mediumPeriod, outBegIdx, outNBElement, middleMA);
        middleMA = realign(middleMA, outBegIdx.value);

        String previousState = "";

        double cashAccount = 10000; // starting cash
        double stockHoldings = 0;
        for (int ix = slowPeriod * 2; ix < (dclose.length - daysOut); ix++) {
            // * 2 because some data is not available yet.

            if (slowerMA[ix] == 0)  // such as here.
                continue;
           
            String resultState = buildResultState(dclose, fasterMA, middleMA, slowerMA, ix);

            double d = 0;
            if (stateTableCnt.containsKey(resultState) == false) {
                stateTableCnt.put(resultState, 0.);
            } else {
                d = stateTableCnt.get(resultState);
            }
            d++;
            stateTableCnt.replace(resultState, d);

            if (previousState.length() == 0) {
                previousState = resultState;
                continue;
            }

            d = 0;
            if (stateTablePriceChangeTotal.containsKey(resultState) == false) {
                stateTablePriceChangeTotal.put(resultState, 0.);
            } else {
                d = stateTablePriceChangeTotal.get(resultState);
            }
            d += (dclose[ix + daysOut] / dclose[ix]);
            stateTablePriceChangeTotal.replace(resultState, d);

            d = 0;

            String transitonState = previousState + ":" + resultState;
            if (stateTableTransition.containsKey(transitonState) == false)
                stateTableTransition.put(transitonState, 0.);
            else
                d = stateTableTransition.get(transitonState);
            d++;
            stateTableTransition.replace(transitonState, d);

            previousState = resultState;

            if (dates.get(ix).compareTo("2017") > 0) {
                // just work on data from 2017 and beyond.
                int ixx = ix + daysOut + 2;
                // go out at least 2 days so we don't use states that have
                // already been recorded.
                if (ixx < dates.size()) {

                    resultState = buildResultState(dclose, fasterMA, middleMA, slowerMA, ix);

                    double got = getTransitionChangeValue(resultState, 1);
                    // System.out.println(got);
                    if (got != Double.MAX_VALUE) // this state is not in table.
                        if (got > 1) { // buy situation
                            if (stockHoldings > 0) {
                                ; // can't buy, all money spent.
                            } else {
                                // System.out.println("buy " + dates.get(ixx) +
                                // " @ " + dclose[ixx]);
                                stockHoldings = cashAccount / dclose[ixx];
                                cashAccount = 0;
                            }
                        } else { // sell situation.
                            if (stockHoldings > 0) {
                                cashAccount = stockHoldings * dclose[ixx];
                                stockHoldings = 0;
                                // System.out.println("sell " + dates.get(ixx) +
                                // " @ " + dclose[ixx] + " total " +
                                // cashAccount);

                            } else {
                                ; // do nothing, no stocks to sell
                            }
                        }
                }
            }

        }

        if (stockHoldings > 0) {
            // this is the end of the run, if holding shares dump shares
            cashAccount = stockHoldings * dclose[dclose.length - 1];
            stockHoldings = 0;
            // System.out.println("sell at end of run" + dates.get(dclose.length
            // - 1) + " @ " + dclose[dclose.length - 1] + " total " +
            // df.format(cashAccount));

        }
        DecimalFormat df = new DecimalFormat("#.##");
        System.out.println(fastPeriod + ";" + mediumPeriod + ";" + slowPeriod + ";" + df.format(cashAccount));

    }

    private static double getTransitionChangeValue(String state, int i) {
        /* THIS METHOD IS RECURSIVE */
        if (stateTableCnt.get(state) == null | stateTablePriceChangeTotal.get(state) == null)
            // let calling method know that this state has not been recorded
            return Double.MAX_VALUE;

        double got;
        // start with what this state is worth
        got = (stateTablePriceChangeTotal.get(state) / stateTableCnt.get(state));

        if (i > 0) {
            double sum = 0;
            int cnt = 0;
            for (String key : stateTableTransition.keySet()) {
                // now get average of what all of the subsequent states.
                String split[] = key.split(">");
                if (split[0].compareTo(state) == 0) {
                    sum += getTransitionChangeValue(split[1], i - 1);
                    cnt++;
                }
            }
            if (cnt > 0)
                got += (sum / cnt);
        }
        return got;

    }


   
    static String buildResultState(double[] dclose, double[] fasterMA, double[] middleMA, double[] slowerMA, int ix) {
        String resultState = "";
        resultState += (dclose[ix] > fasterMA[ix] ? ">Fast" : "<Fast");
        resultState += (dclose[ix] > middleMA[ix] ? ">Middle" : "<Middle");
        resultState += (dclose[ix] > slowerMA[ix] ? ">Slow" : "<Slow");
        resultState += (fasterMA[ix] > slowerMA[ix] ? "_Fast>Slow" : "_Fast<Slow");
        resultState += (fasterMA[ix] > middleMA[ix] ? "_Fast>Middle" : "_Fast<Middle");
        resultState += (slowerMA[ix] > middleMA[ix] ? "_Slow>Middle" : "_Slow<Middle");
        return resultState;
    }

    static double[] realign(double in[], int offset) throws Exception {

        if (offset <= 0)
            throw new Exception("offset must be greater than 0");
        for (int ix = in.length - 1; ix > offset - 1; ix--)
            in[ix] = in[ix - offset];

        for (int ix = offset; ix >= 0; ix--)
            in[ix] = 0;

        return in;

    }

}

4 comments:

usacoder said...

Comment on posting

usacoder said...

I got comments to work. Updates to RL Code see https://artificialinvestor.blogspot.com/2017/11/update-to-rl-code.html

usacoder said...

Update to RL Code
changed the method to get the value -
1) state string splitter should be ':'
2) substates are averaged by the probability of being in this state.
3) the final value is averaged with the substate value.

/**
 * Estimates the expected price ratio for {@code state}, blended with the expected ratio of the
 * states that have followed it.
 *
 * <p>The state's own value is its average recorded price ratio
 * ({@code stateTablePriceChangeTotal / stateTableCnt}). When {@code i > 0}, every recorded
 * successor (transition keys are {@code "previous:next"}) is evaluated with depth {@code i - 1},
 * the successor values are averaged weighted by how often each transition occurred, and the
 * result is the mean of the state's own value and that weighted average.
 *
 * @param state state string to evaluate
 * @param i how many transition levels to look ahead; 0 returns the state's own average
 * @return the blended expected ratio, or {@link Double#MAX_VALUE} if the state has not been
 *         recorded
 */
static double getTransitionChangeValue(String state, int i) {
    /* THIS METHOD IS RECURSIVE */
    Double count = stateTableCnt.get(state);
    Double total = stateTablePriceChangeTotal.get(state);
    if (count == null || total == null) {
        // let calling method know that this state has not been recorded
        return Double.MAX_VALUE;
    }

    // start with what this state is worth
    double got = total / count;

    if (i > 0) {
        // State strings never contain ':' so "state:" matches exactly the transitions leaving it.
        String prefix = state + ":";
        double weightedSum = 0;
        double totalWeight = 0;
        for (String key : stateTableTransition.keySet()) {
            if (!key.startsWith(prefix)) {
                continue;
            }
            double next = getTransitionChangeValue(key.substring(prefix.length()), i - 1);
            if (next == Double.MAX_VALUE) {
                continue; // unknown successor: keep the sentinel out of the average
            }
            double weight = stateTableTransition.get(key);
            weightedSum += next * weight;
            totalWeight += weight;
        }
        if (totalWeight > 0) {
            // Dividing by the outgoing-transition total makes the weights sum to 1. Dividing by
            // stateTableCnt instead under-weights the successors whenever the state has been
            // seen more often than it has been left, which drags the value below 1.
            got = (got + weightedSum / totalWeight) / 2;
        }
    }
    return got;
}

usacoder said...

Complete updated code at
http://artificialinvestor.blogspot.com/2017/11/source-code-reinforcedlearningwith3movi.html