/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.sting.gatk.walkers.variantrecalibration;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.TreeSet;
import org.broadinstitute.sting.commandline.Advanced;
import org.broadinstitute.sting.commandline.Argument;
import org.broadinstitute.sting.commandline.ArgumentCollection;
import org.broadinstitute.sting.commandline.Input;
import org.broadinstitute.sting.commandline.Output;
import org.broadinstitute.sting.commandline.RodBinding;
import org.broadinstitute.sting.gatk.CommandLineGATK;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.PartitionBy;
import org.broadinstitute.sting.gatk.walkers.PartitionType;
import org.broadinstitute.sting.gatk.walkers.RodWalker;
import org.broadinstitute.sting.gatk.walkers.TreeReducible;
import org.broadinstitute.sting.gatk.walkers.variantrecalibration.ApplyRecalibration;
import org.broadinstitute.sting.gatk.walkers.variantrecalibration.GaussianMixtureModel;
import org.broadinstitute.sting.gatk.walkers.variantrecalibration.TrainingSet;
import org.broadinstitute.sting.gatk.walkers.variantrecalibration.Tranche;
import org.broadinstitute.sting.gatk.walkers.variantrecalibration.TrancheManager;
import org.broadinstitute.sting.gatk.walkers.variantrecalibration.VariantDataManager;
import org.broadinstitute.sting.gatk.walkers.variantrecalibration.VariantDatum;
import org.broadinstitute.sting.gatk.walkers.variantrecalibration.VariantRecalibratorArgumentCollection;
import org.broadinstitute.sting.gatk.walkers.variantrecalibration.VariantRecalibratorEngine;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.QualityUtils;
import org.broadinstitute.sting.utils.R.RScriptExecutor;
import org.broadinstitute.sting.utils.Utils;
import org.broadinstitute.sting.utils.collections.ExpandingArrayList;
import org.broadinstitute.sting.utils.exceptions.UserException;
import org.broadinstitute.sting.utils.help.DocumentedGATKFeature;
import org.broadinstitute.sting.utils.io.Resource;
import org.broadinstitute.sting.utils.variant.GATKVariantContextUtils;
import org.broadinstitute.variant.variantcontext.VariantContext;
import org.broadinstitute.variant.variantcontext.writer.VariantContextWriter;
import org.broadinstitute.variant.vcf.VCFHeader;
import org.broadinstitute.variant.vcf.VCFHeaderLine;

@DocumentedGATKFeature(groupName="Variant Discovery Tools", extraDocs={CommandLineGATK.class})
@PartitionBy(value=PartitionType.NONE)
public class VariantRecalibrator
extends RodWalker<ExpandingArrayList<VariantDatum>, ExpandingArrayList<VariantDatum>>
implements TreeReducible<ExpandingArrayList<VariantDatum>> {
    public static final String VQS_LOD_KEY = "VQSLOD";
    public static final String CULPRIT_KEY = "culprit";
    private static final String PLOT_TRANCHES_RSCRIPT = "plot_Tranches.R";
    @ArgumentCollection
    private VariantRecalibratorArgumentCollection VRAC = new VariantRecalibratorArgumentCollection();
    @Input(fullName="input", shortName="input", doc="The raw input variants to be recalibrated", required=true)
    public List<RodBinding<VariantContext>> input;
    @Input(fullName="resource", shortName="resource", doc="A list of sites for which to apply a prior probability of being correct but which aren't used by the algorithm", required=false)
    public List<RodBinding<VariantContext>> resource = Collections.emptyList();
    @Output(fullName="recal_file", shortName="recalFile", doc="The output recal file used by ApplyRecalibration", required=true)
    protected VariantContextWriter recalWriter = null;
    @Output(fullName="tranches_file", shortName="tranchesFile", doc="The output tranches file used by ApplyRecalibration", required=true)
    protected File TRANCHES_FILE;
    @Argument(fullName="target_titv", shortName="titv", doc="The expected novel Ti/Tv ratio to use when calculating FDR tranches and for display on the optimization curve output figures. (approx 2.15 for whole genome experiments). ONLY USED FOR PLOTTING PURPOSES!", required=false)
    protected double TARGET_TITV = 2.15;
    @Argument(fullName="use_annotation", shortName="an", doc="The names of the annotations which should used for calculations", required=true)
    private String[] USE_ANNOTATIONS = null;
    @Argument(fullName="TStranche", shortName="tranche", doc="The levels of novel false discovery rate (FDR, implied by ti/tv) at which to slice the data. (in percent, that is 1.0 for 1 percent)", required=false)
    private double[] TS_TRANCHES = new double[]{100.0, 99.9, 99.0, 90.0};
    @Argument(fullName="ignore_filter", shortName="ignoreFilter", doc="If specified the variant recalibrator will use variants even if the specified filter name is marked in the input VCF file", required=false)
    private String[] IGNORE_INPUT_FILTERS = null;
    @Output(fullName="rscript_file", shortName="rscriptFile", doc="The output rscript file generated by the VQSR to aid in visualization of the input data and learned model", required=false)
    private File RSCRIPT_FILE = null;
    @Argument(fullName="ts_filter_level", shortName="ts_filter_level", doc="The truth sensitivity level at which to start filtering, used here to indicate filtered variants in the model reporting plots", required=false)
    protected double TS_FILTER_LEVEL = 99.0;
    @Advanced
    @Argument(fullName="trustAllPolymorphic", shortName="allPoly", doc="Trust that all the input training sets' unfiltered records contain only polymorphic sites to drastically speed up the computation.", required=false)
    protected Boolean TRUST_ALL_POLYMORPHIC = false;
    private VariantDataManager dataManager;
    private PrintStream tranchesStream;
    private final Set<String> ignoreInputFilterSet = new TreeSet<String>();
    private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine(this.VRAC);

    @Override
    public void initialize() {
        this.dataManager = new VariantDataManager(new ArrayList<String>(Arrays.asList(this.USE_ANNOTATIONS)), this.VRAC);
        if (this.RSCRIPT_FILE != null && !RScriptExecutor.RSCRIPT_EXISTS) {
            Utils.warnUser(logger, String.format("Rscript not found in environment path. %s will be generated but PDF plots will not.", this.RSCRIPT_FILE));
        }
        if (this.IGNORE_INPUT_FILTERS != null) {
            this.ignoreInputFilterSet.addAll(Arrays.asList(this.IGNORE_INPUT_FILTERS));
        }
        try {
            this.tranchesStream = new PrintStream(this.TRANCHES_FILE);
        }
        catch (FileNotFoundException e) {
            throw new UserException.CouldNotCreateOutputFile(this.TRANCHES_FILE, (Exception)e);
        }
        for (RodBinding<VariantContext> rod : this.resource) {
            this.dataManager.addTrainingSet(new TrainingSet(rod));
        }
        if (!this.dataManager.checkHasTrainingSet()) {
            throw new UserException.CommandLineException("No training set found! Please provide sets of known polymorphic loci marked with the training=true ROD binding tag. For example, -resource:hapmap,VCF,known=false,training=true,truth=true,prior=12.0 hapmapFile.vcf");
        }
        if (!this.dataManager.checkHasTruthSet()) {
            throw new UserException.CommandLineException("No truth set found! Please provide sets of known polymorphic loci marked with the truth=true ROD binding tag. For example, -resource:hapmap,VCF,known=false,training=true,truth=true,prior=12.0 hapmapFile.vcf");
        }
        HashSet<VCFHeaderLine> hInfo = new HashSet<VCFHeaderLine>();
        ApplyRecalibration.addVQSRStandardHeaderLines(hInfo);
        this.recalWriter.writeHeader(new VCFHeader(hInfo));
    }

    @Override
    public ExpandingArrayList<VariantDatum> map(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context) {
        ExpandingArrayList<VariantDatum> mapList = new ExpandingArrayList<VariantDatum>();
        if (tracker == null) {
            return mapList;
        }
        for (VariantContext vc : tracker.getValues(this.input, context.getLocation())) {
            if (vc == null || !vc.isNotFiltered() && !this.ignoreInputFilterSet.containsAll(vc.getFilters()) || !VariantDataManager.checkVariationClass(vc, this.VRAC.MODE)) continue;
            VariantDatum datum = new VariantDatum();
            this.dataManager.decodeAnnotations(datum, vc, true);
            datum.loc = this.getToolkit().getGenomeLocParser().createGenomeLoc(vc);
            datum.originalQual = vc.getPhredScaledQual();
            datum.isSNP = vc.isSNP() && vc.isBiallelic();
            datum.isTransition = datum.isSNP && GATKVariantContextUtils.isTransition(vc);
            this.dataManager.parseTrainingSets(tracker, context.getLocation(), vc, datum, this.TRUST_ALL_POLYMORPHIC);
            double priorFactor = QualityUtils.qualToProb(datum.prior);
            datum.prior = Math.log10(priorFactor) - Math.log10(1.0 - priorFactor);
            mapList.add(datum);
        }
        return mapList;
    }

    @Override
    public ExpandingArrayList<VariantDatum> reduceInit() {
        return new ExpandingArrayList<VariantDatum>();
    }

    @Override
    public ExpandingArrayList<VariantDatum> reduce(ExpandingArrayList<VariantDatum> mapValue, ExpandingArrayList<VariantDatum> reduceSum) {
        reduceSum.addAll(mapValue);
        return reduceSum;
    }

    @Override
    public ExpandingArrayList<VariantDatum> treeReduce(ExpandingArrayList<VariantDatum> lhs, ExpandingArrayList<VariantDatum> rhs) {
        rhs.addAll(lhs);
        return rhs;
    }

    @Override
    public void onTraversalDone(ExpandingArrayList<VariantDatum> reduceSum) {
        this.dataManager.setData(reduceSum);
        this.dataManager.normalizeData();
        GaussianMixtureModel goodModel = this.engine.generateModel(this.dataManager.getTrainingData());
        this.engine.evaluateData(this.dataManager.getData(), goodModel, false);
        ExpandingArrayList<VariantDatum> negativeTrainingData = this.dataManager.selectWorstVariants(this.VRAC.PERCENT_BAD_VARIANTS, this.VRAC.MIN_NUM_BAD_VARIANTS);
        GaussianMixtureModel badModel = this.engine.generateModel(negativeTrainingData);
        this.engine.evaluateData(this.dataManager.getData(), badModel, true);
        while (badModel.failedToConverge && this.VRAC.MAX_GAUSSIANS > 4) {
            logger.info("Negative model failed to converge. Retrying...");
            --this.VRAC.MAX_GAUSSIANS;
            badModel = this.engine.generateModel(negativeTrainingData);
            this.engine.evaluateData(this.dataManager.getData(), goodModel, false);
            this.engine.evaluateData(this.dataManager.getData(), badModel, true);
        }
        if (badModel.failedToConverge || goodModel.failedToConverge) {
            throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider raising the number of variants used to train the negative model (via --percentBadVariants 0.05, for example) or lowering the maximum number of Gaussians to use in the model (via --maxGaussians 4, for example)");
        }
        this.engine.calculateWorstPerformingAnnotation(this.dataManager.getData(), goodModel, badModel);
        int nCallsAtTruth = TrancheManager.countCallsAtTruth(this.dataManager.getData(), Double.NEGATIVE_INFINITY);
        TrancheManager.TruthSensitivityMetric metric = new TrancheManager.TruthSensitivityMetric(nCallsAtTruth);
        List<Tranche> tranches = TrancheManager.findTranches(this.dataManager.getData(), this.TS_TRANCHES, metric, this.VRAC.MODE);
        this.tranchesStream.print(Tranche.tranchesString(tranches));
        double lodCutoff = 0.0;
        for (Tranche tranche : tranches) {
            if (MathUtils.compareDoubles(tranche.ts, this.TS_FILTER_LEVEL, 1.0E-4) != 0) continue;
            lodCutoff = tranche.minVQSLod;
        }
        logger.info("Writing out recalibration table...");
        this.dataManager.writeOutRecalibrationTable(this.recalWriter);
        if (this.RSCRIPT_FILE != null) {
            logger.info("Writing out visualization Rscript file...");
            this.createVisualizationScript(this.dataManager.getRandomDataForPlotting(6000), goodModel, badModel, lodCutoff);
        }
        RScriptExecutor executor = new RScriptExecutor();
        executor.addScript(new Resource(PLOT_TRANCHES_RSCRIPT, VariantRecalibrator.class));
        executor.addArgs(this.TRANCHES_FILE.getAbsoluteFile(), this.TARGET_TITV);
        logger.info("Executing: " + executor.getApproximateCommandLine());
        executor.exec();
    }

    private void createVisualizationScript(ExpandingArrayList<VariantDatum> randomData, GaussianMixtureModel goodModel, GaussianMixtureModel badModel, double lodCutoff) {
        PrintStream stream;
        try {
            stream = new PrintStream(this.RSCRIPT_FILE);
        }
        catch (FileNotFoundException e) {
            throw new UserException.CouldNotCreateOutputFile(this.RSCRIPT_FILE, (Exception)e);
        }
        stream.println("library(ggplot2)");
        stream.println("library(tools)");
        stream.println("library(grid)");
        this.createArrangeFunction(stream);
        stream.println("outputPDF <- \"" + this.RSCRIPT_FILE + ".pdf\"");
        stream.println("pdf(outputPDF)");
        for (int iii = 0; iii < this.USE_ANNOTATIONS.length; ++iii) {
            for (int jjj = iii + 1; jjj < this.USE_ANNOTATIONS.length; ++jjj) {
                logger.info("Building " + this.USE_ANNOTATIONS[iii] + " x " + this.USE_ANNOTATIONS[jjj] + " plot...");
                ExpandingArrayList<VariantDatum> fakeData = new ExpandingArrayList<VariantDatum>();
                double minAnn1 = 100.0;
                double maxAnn1 = -100.0;
                double minAnn2 = 100.0;
                double maxAnn2 = -100.0;
                for (VariantDatum datum : randomData) {
                    minAnn1 = Math.min(minAnn1, datum.annotations[iii]);
                    maxAnn1 = Math.max(maxAnn1, datum.annotations[iii]);
                    minAnn2 = Math.min(minAnn2, datum.annotations[jjj]);
                    maxAnn2 = Math.max(maxAnn2, datum.annotations[jjj]);
                }
                for (double ann1 = minAnn1; ann1 <= maxAnn1; ann1 += 0.1) {
                    for (double ann2 = minAnn2; ann2 <= maxAnn2; ann2 += 0.1) {
                        VariantDatum datum = new VariantDatum();
                        datum.prior = 0.0;
                        datum.annotations = new double[randomData.get((int)0).annotations.length];
                        datum.isNull = new boolean[randomData.get((int)0).annotations.length];
                        for (int ann = 0; ann < datum.annotations.length; ++ann) {
                            datum.annotations[ann] = 0.0;
                            datum.isNull[ann] = true;
                        }
                        datum.annotations[iii] = ann1;
                        datum.annotations[jjj] = ann2;
                        datum.isNull[iii] = false;
                        datum.isNull[jjj] = false;
                        fakeData.add(datum);
                    }
                }
                this.engine.evaluateData(fakeData, goodModel, false);
                this.engine.evaluateData(fakeData, badModel, true);
                stream.print("surface <- c(");
                for (VariantDatum datum : fakeData) {
                    stream.print(String.format("%.3f, %.3f, %.3f, ", datum.annotations[iii], datum.annotations[jjj], Math.min(4.0, Math.max(-4.0, datum.lod))));
                }
                stream.println("NA,NA,NA)");
                stream.println("s <- matrix(surface,ncol=3,byrow=T)");
                stream.print("data <- c(");
                for (VariantDatum datum : randomData) {
                    stream.print(String.format("%.3f, %.3f, %.3f, %d, %d,", datum.annotations[iii], datum.annotations[jjj], datum.lod < lodCutoff ? -1.0 : 1.0, datum.atAntiTrainingSite ? -1 : (datum.atTrainingSite ? 1 : 0), datum.isKnown ? 1 : -1));
                }
                stream.println("NA,NA,NA,NA,1)");
                stream.println("d <- matrix(data,ncol=5,byrow=T)");
                String surfaceFrame = "sf." + this.USE_ANNOTATIONS[iii] + "." + this.USE_ANNOTATIONS[jjj];
                String dataFrame = "df." + this.USE_ANNOTATIONS[iii] + "." + this.USE_ANNOTATIONS[jjj];
                stream.println(surfaceFrame + " <- data.frame(x=s[,1], y=s[,2], lod=s[,3])");
                stream.println(dataFrame + " <- data.frame(x=d[,1], y=d[,2], retained=d[,3], training=d[,4], novelty=d[,5])");
                stream.println("dummyData <- " + dataFrame + "[1,]");
                stream.println("dummyData$x <- NaN");
                stream.println("dummyData$y <- NaN");
                stream.println("p <- ggplot(data=" + surfaceFrame + ", aes(x=x, y=y)) + opts(panel.background = theme_rect(colour = NA), panel.grid.minor = theme_line(colour = NA), panel.grid.major = theme_line(colour = NA))");
                stream.println("p1 = p + opts(title=\"model PDF\") + labs(x=\"" + this.USE_ANNOTATIONS[iii] + "\", y=\"" + this.USE_ANNOTATIONS[jjj] + "\") + geom_tile(aes(fill = lod)) + scale_fill_gradient(high=\"green\", low=\"red\")");
                stream.println("p <- qplot(x,y,data=" + dataFrame + ", color=retained, alpha=I(1/7),legend=FALSE) + opts(panel.background = theme_rect(colour = NA), panel.grid.minor = theme_line(colour = NA), panel.grid.major = theme_line(colour = NA))");
                stream.println("q <- geom_point(aes(x=x,y=y,color=retained),data=dummyData, alpha=1.0, na.rm=TRUE)");
                stream.println("p2 = p + q + labs(x=\"" + this.USE_ANNOTATIONS[iii] + "\", y=\"" + this.USE_ANNOTATIONS[jjj] + "\") + scale_colour_gradient(name=\"outcome\", high=\"black\", low=\"red\",breaks=c(-1,1),labels=c(\"filtered\",\"retained\"))");
                stream.println("p <- qplot(x,y,data=" + dataFrame + "[" + dataFrame + "$training != 0,], color=training, alpha=I(1/7)) + opts(panel.background = theme_rect(colour = NA), panel.grid.minor = theme_line(colour = NA), panel.grid.major = theme_line(colour = NA))");
                stream.println("q <- geom_point(aes(x=x,y=y,color=training),data=dummyData, alpha=1.0, na.rm=TRUE)");
                stream.println("p3 = p + q + labs(x=\"" + this.USE_ANNOTATIONS[iii] + "\", y=\"" + this.USE_ANNOTATIONS[jjj] + "\") + scale_colour_gradient(high=\"green\", low=\"purple\",breaks=c(-1,1), labels=c(\"neg\", \"pos\"))");
                stream.println("p <- qplot(x,y,data=" + dataFrame + ", color=novelty, alpha=I(1/7)) + opts(panel.background = theme_rect(colour = NA), panel.grid.minor = theme_line(colour = NA), panel.grid.major = theme_line(colour = NA))");
                stream.println("q <- geom_point(aes(x=x,y=y,color=novelty),data=dummyData, alpha=1.0, na.rm=TRUE)");
                stream.println("p4 = p + q + labs(x=\"" + this.USE_ANNOTATIONS[iii] + "\", y=\"" + this.USE_ANNOTATIONS[jjj] + "\") + scale_colour_gradient(name=\"novelty\", high=\"blue\", low=\"red\",breaks=c(-1,1), labels=c(\"novel\",\"known\"))");
                stream.println("arrange(p1, p2, p3, p4, ncol=2)");
            }
        }
        stream.println("dev.off()");
        stream.println("if (exists(\"compactPDF\")) {");
        stream.println("compactPDF(outputPDF)");
        stream.println("}");
        stream.close();
        RScriptExecutor executor = new RScriptExecutor();
        executor.addScript(this.RSCRIPT_FILE);
        logger.info("Executing: " + executor.getApproximateCommandLine());
        executor.exec();
    }

    private void createArrangeFunction(PrintStream stream) {
        stream.println("vp.layout <- function(x, y) viewport(layout.pos.row=x, layout.pos.col=y)");
        stream.println("arrange <- function(..., nrow=NULL, ncol=NULL, as.table=FALSE) {");
        stream.println("dots <- list(...)");
        stream.println("n <- length(dots)");
        stream.println("if(is.null(nrow) & is.null(ncol)) { nrow = floor(n/2) ; ncol = ceiling(n/nrow)}");
        stream.println("if(is.null(nrow)) { nrow = ceiling(n/ncol)}");
        stream.println("if(is.null(ncol)) { ncol = ceiling(n/nrow)}");
        stream.println("grid.newpage()");
        stream.println("pushViewport(viewport(layout=grid.layout(nrow,ncol) ) )");
        stream.println("ii.p <- 1");
        stream.println("for(ii.row in seq(1, nrow)){");
        stream.println("ii.table.row <- ii.row ");
        stream.println("if(as.table) {ii.table.row <- nrow - ii.table.row + 1}");
        stream.println("for(ii.col in seq(1, ncol)){");
        stream.println("ii.table <- ii.p");
        stream.println("if(ii.p > n) break");
        stream.println("print(dots[[ii.table]], vp=vp.layout(ii.table.row, ii.col))");
        stream.println("ii.p <- ii.p + 1");
        stream.println("}");
        stream.println("}");
        stream.println("}");
    }
}

