Showing posts with label artificial intelligence. Show all posts

Monday, November 13, 2017

Update to RL Code

Changed the method that computes a state's value:
 1) the separator used to split a transition key into its two states should be ':'.
 2) sub-state values are weighted by the probability of transitioning into that sub-state.
 3) the state's own value is then averaged with the weighted sub-state value.

    static double getTransitionChangeValue(String state, int i) {
        /* THIS METHOD IS RECURSIVE */
        if (stateTableCnt.get(state) == null | stateTablePriceChangeTotal.get(state) == null)
            // let calling method know that this state has not been recorded
            return Double.MAX_VALUE;

        double got;
        // start with what this state is worth
        got = (stateTablePriceChangeTotal.get(state) / stateTableCnt.get(state));

        if (i > 0) {
            double sum = 0;
            int cnt = 0;
            for (String key : stateTableTransition.keySet()) {
                // now get average of what all of the subsequent states.
                String split[] = key.split(":");
                if (split[0].compareTo(state) == 0) {
                    sum += getTransitionChangeValue(split[1], i - 1)
                            * (stateTableTransition.get(key)/stateTableCnt.get(state));
                    cnt++;
                }
            }
            if (cnt > 0) {
                got += sum;
                got /= 2;
            }
        }
        return got;

    }


Program: Reinforcement Learning Algorithm to Select the Optimal Buy/Sell Action.

Here's my first attempt at writing reinforcement-learning code.

The program looks at price changes in relation to three different moving averages. It backtests while building the state tables.

Input: a CSV file built from Yahoo historical quote data.

Output: prints to the console what a $10,000 investment (starting at the beginning of 2017) would be worth.

You will need the TA-Lib package from ta-lib.org.

 

package reinforcedLearning;

import java.io.BufferedReader;
import java.io.FileReader;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.TreeMap;

import com.tictactec.ta.lib.Core;
import com.tictactec.ta.lib.MInteger;

public class ReinforcedLearningWith3MovingAverages {

    /**
     * stateTableCnt - #  of times this state occurs
     * StateTablePriceTotal - sum of all prices for this state
     * using the combination of the 2 tables we get the average price change for the state.
     * using the average price we get the action: buy (> 1.0) or sell (< 1.0)
     * stateTableTransition - allows us to see where this state may go next.
     */

   
    static TreeMap<String, Double> stateTableCnt = new TreeMap<String, Double>();
   
    static TreeMap<String, Double> stateTablePriceChangeTotal = new TreeMap<String, Double>();

    static TreeMap<String, Double> stateTableTransition = new TreeMap<String, Double>();

    public static void main(String args[]) throws Exception {

        Core core = new Core();
        ArrayList<Double> closes = new ArrayList<>();

        ArrayList<String> dates = new ArrayList<String>();
        BufferedReader br = new BufferedReader(new FileReader("spy.csv"));
        String in = "";
        br.readLine(); // skip header
        // load data into arraylist
        while ((in = br.readLine()) != null) {
            String ins[] = in.split(",");
            closes.add(Double.parseDouble(ins[5]));
            dates.add(ins[0]);
        }
        br.close();

        // move data to array for ta-lib.
        double[] dclose = new double[closes.size()];
        for (int ix = 0; ix < closes.size(); ix++)
            dclose[ix] = closes.get(ix);

        int slowPeriod = 21;
        int mediumPeriod = 15;
        int fastPeriod = 9;

        int daysOut = 2;  // looking at a 2 day price differences.
       
        double[] slowerMA = new double[closes.size()];
        double[] fasterMA = new double[closes.size()];
        double[] middleMA = new double[closes.size()];
       
        MInteger outBegIdx = new MInteger();
        MInteger outNBElement = new MInteger();

        core.sma(0, dclose.length - 1, dclose, slowPeriod, outBegIdx, outNBElement, slowerMA);
        // realign the data because ta-lib always starts its output at 0.
        slowerMA = realign(slowerMA, outBegIdx.value);

        core.sma(0, dclose.length - 1, dclose, fastPeriod, outBegIdx, outNBElement, fasterMA);
        fasterMA = realign(fasterMA, outBegIdx.value);

        core.sma(0, dclose.length - 1, dclose, mediumPeriod, outBegIdx, outNBElement, middleMA);
        middleMA = realign(middleMA, outBegIdx.value);

        String previousState = "";

        double cashAccount = 10000; // starting cash
        double stockHoldings = 0;
        for (int ix = slowPeriod * 2; ix < (dclose.length - daysOut); ix++) {
            // * 2 because some data is not available yet.

            if (slowerMA[ix] == 0)  // such as here.
                continue;
           
            String resultState = buildResultState(dclose, fasterMA, middleMA, slowerMA, ix);

            double d = 0;
            if (stateTableCnt.containsKey(resultState) == false) {
                stateTableCnt.put(resultState, 0.);
            } else {
                d = stateTableCnt.get(resultState);
            }
            d++;
            stateTableCnt.replace(resultState, d);

            if (previousState.length() == 0) {
                previousState = resultState;
                continue;
            }

            d = 0;
            if (stateTablePriceChangeTotal.containsKey(resultState) == false) {
                stateTablePriceChangeTotal.put(resultState, 0.);
            } else {
                d = stateTablePriceChangeTotal.get(resultState);
            }
            d += (dclose[ix + daysOut] / dclose[ix]);
            stateTablePriceChangeTotal.replace(resultState, d);

            d = 0;

            String transitonState = previousState + ":" + resultState;
            if (stateTableTransition.containsKey(transitonState) == false)
                stateTableTransition.put(transitonState, 0.);
            else
                d = stateTableTransition.get(transitonState);
            d++;
            stateTableTransition.replace(transitonState, d);

            previousState = resultState;

            if (dates.get(ix).compareTo("2017") > 0) {
                // just work on data from 2017 and beyond.
                int ixx = ix + daysOut + 2;
                // go out at least 2 days so we don't use states that have
                // already been recorded.
                if (ixx < dates.size()) {

                    resultState = buildResultState(dclose, fasterMA, middleMA, slowerMA, ix);

                    double got = getTransitionChangeValue(resultState, 1);
                    // System.out.println(got);
                    if (got != Double.MAX_VALUE) // this state is not in table.
                        if (got > 1) { // buy situation
                            if (stockHoldings > 0) {
                                ; // can't buy, all money spent.
                            } else {
                                // System.out.println("buy " + dates.get(ixx) +
                                // " @ " + dclose[ixx]);
                                stockHoldings = cashAccount / dclose[ixx];
                                cashAccount = 0;
                            }
                        } else { // sell situation.
                            if (stockHoldings > 0) {
                                cashAccount = stockHoldings * dclose[ixx];
                                stockHoldings = 0;
                                // System.out.println("sell " + dates.get(ixx) +
                                // " @ " + dclose[ixx] + " total " +
                                // cashAccount);

                            } else {
                                ; // do nothing, no stocks to sell
                            }
                        }
                }
            }

        }

        if (stockHoldings > 0) {
            // this is the end of the run, if holding shares dump shares
            cashAccount = stockHoldings * dclose[dclose.length - 1];
            stockHoldings = 0;
            // System.out.println("sell at end of run" + dates.get(dclose.length
            // - 1) + " @ " + dclose[dclose.length - 1] + " total " +
            // df.format(cashAccount));

        }
        DecimalFormat df = new DecimalFormat("#.##");
        System.out.println(fastPeriod + ";" + mediumPeriod + ";" + slowPeriod + ";" + df.format(cashAccount));

    }

    private static double getTransitionChangeValue(String state, int i) {
        /* THIS METHOD IS RECURSIVE */
        if (stateTableCnt.get(state) == null | stateTablePriceChangeTotal.get(state) == null)
            // let calling method know that this state has not been recorded
            return Double.MAX_VALUE;

        double got;
        // start with what this state is worth
        got = (stateTablePriceChangeTotal.get(state) / stateTableCnt.get(state));

        if (i > 0) {
            double sum = 0;
            int cnt = 0;
            for (String key : stateTableTransition.keySet()) {
                // now get average of what all of the subsequent states.
                String split[] = key.split(">");
                if (split[0].compareTo(state) == 0) {
                    sum += getTransitionChangeValue(split[1], i - 1);
                    cnt++;
                }
            }
            if (cnt > 0)
                got += (sum / cnt);
        }
        return got;

    }


   
    static String buildResultState(double[] dclose, double[] fasterMA, double[] middleMA, double[] slowerMA, int ix) {
        String resultState = "";
        resultState += (dclose[ix] > fasterMA[ix] ? ">Fast" : "<Fast");
        resultState += (dclose[ix] > middleMA[ix] ? ">Middle" : "<Middle");
        resultState += (dclose[ix] > slowerMA[ix] ? ">Slow" : "<Slow");
        resultState += (fasterMA[ix] > slowerMA[ix] ? "_Fast>Slow" : "_Fast<Slow");
        resultState += (fasterMA[ix] > middleMA[ix] ? "_Fast>Middle" : "_Fast<Middle");
        resultState += (slowerMA[ix] > middleMA[ix] ? "_Slow>Middle" : "_Slow<Middle");
        return resultState;
    }

    static double[] realign(double in[], int offset) throws Exception {

        if (offset <= 0)
            throw new Exception("offset must be greater than 0");
        for (int ix = in.length - 1; ix > offset - 1; ix--)
            in[ix] = in[ix - offset];

        for (int ix = offset; ix >= 0; ix--)
            in[ix] = 0;

        return in;

    }

}