/*
 * Decompiled with CFR 0.152.
 */
package ghidra.app.plugin.assembler.sleigh.sem;

import ghidra.app.plugin.assembler.sleigh.expr.MaskedLong;
import ghidra.app.plugin.assembler.sleigh.expr.NeedsBackfillException;
import ghidra.app.plugin.assembler.sleigh.expr.RecursiveDescentSolver;
import ghidra.app.plugin.assembler.sleigh.grammars.AssemblyGrammar;
import ghidra.app.plugin.assembler.sleigh.grammars.AssemblyProduction;
import ghidra.app.plugin.assembler.sleigh.sem.AbstractAssemblyStateGenerator;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyConstructStateGenerator;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyConstructorSemantic;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyContextGraph;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyDefaultContext;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyGeneratedPrototype;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyHiddenConstructStateGenerator;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyNopStateGenerator;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyOperandStateGenerator;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyPatternBlock;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyResolution;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyResolutionResults;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyResolvedBackfill;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyResolvedPatterns;
import ghidra.app.plugin.assembler.sleigh.symbol.AssemblyNonTerminal;
import ghidra.app.plugin.assembler.sleigh.tree.AssemblyParseBranch;
import ghidra.app.plugin.assembler.sleigh.tree.AssemblyParseNumericToken;
import ghidra.app.plugin.assembler.sleigh.tree.AssemblyParseTreeNode;
import ghidra.app.plugin.assembler.sleigh.util.DbgTimer;
import ghidra.app.plugin.processors.sleigh.Constructor;
import ghidra.app.plugin.processors.sleigh.SleighInstructionPrototype;
import ghidra.app.plugin.processors.sleigh.SleighLanguage;
import ghidra.app.plugin.processors.sleigh.expression.PatternExpression;
import ghidra.app.plugin.processors.sleigh.symbol.OperandSymbol;
import ghidra.app.plugin.processors.sleigh.symbol.SubtableSymbol;
import ghidra.app.plugin.processors.sleigh.symbol.TripleSymbol;
import ghidra.program.model.address.Address;
import ghidra.program.model.lang.InsufficientBytesException;
import ghidra.program.model.lang.UnknownInstructionException;
import ghidra.program.model.mem.ByteMemBufferImpl;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class AssemblyTreeResolver {
    /** Solver shared by every resolver, obtained from {@link RecursiveDescentSolver#getSolver()}. */
    protected static final RecursiveDescentSolver SOLVER = RecursiveDescentSolver.getSolver();
    /** Debug sink used for all tracing in this class; bound to the inactive timer here. */
    protected static final DbgTimer DBG = DbgTimer.INACTIVE;
    /** Key in {@link #vals}; the constructor stores the addressable word offset of {@link #at} under it. */
    public static final String INST_START = "inst_start";
    /** Key in {@link #vals}; set during pending-backfill resolution to the offset of {@code at + instruction length}. */
    public static final String INST_NEXT = "inst_next";
    /** Key in {@link #vals}; set during pending-backfill resolution to the same value as {@link #INST_NEXT}. */
    public static final String INST_NEXT2 = "inst_next2";
    /** Language passed to the constructor; used to disassemble candidate encodings. */
    protected final SleighLanguage lang;
    /** Address passed to the constructor; base for the {@code inst_*} values. */
    protected final Address at;
    /** Named values handed to the solver and to context-change solving; mutated while resolving. */
    protected final Map<String, Long> vals = new HashMap<String, Long>();
    /** Parse tree being resolved. */
    protected final AssemblyParseBranch tree;
    /** Grammar taken from {@link #tree} via {@code getGrammar()}. */
    protected final AssemblyGrammar grammar;
    /** Result of {@code fillMask()} on the context block passed to the constructor. */
    protected final AssemblyPatternBlock context;
    /** Context graph queried for constructor paths when resolving root recursion. */
    protected final AssemblyContextGraph ctxGraph;

    /**
     * Create a resolver for a single parse tree located at a single address.
     *
     * <p>Besides storing its arguments, the constructor seeds {@link #vals} with
     * {@link #INST_START}, takes the grammar from the tree, and keeps the
     * {@code fillMask()} form of the supplied context rather than the block itself.
     *
     * @param lang the language whose constructors are being resolved
     * @param at the address of the instruction being assembled
     * @param tree the parse tree to resolve
     * @param context the context block; its {@code fillMask()} result is retained
     * @param ctxGraph the context graph used for root-recursion resolution
     */
    public AssemblyTreeResolver(SleighLanguage lang, Address at, AssemblyParseBranch tree, AssemblyPatternBlock context, AssemblyContextGraph ctxGraph) {
        // Plain reference copies first.
        this.lang = lang;
        this.at = at;
        this.tree = tree;
        this.ctxGraph = ctxGraph;

        // Derived state, evaluated in the same order as before: address, grammar, context.
        this.vals.put(INST_START, at.getAddressableWordOffset());
        this.grammar = tree.getGrammar();
        this.context = context.fillMask();
    }

    /**
     * Resolve the parse tree into a set of results.
     *
     * <p>Prototypes are generated from the root of the tree, their distinct states are
     * resolved, and the collected resolutions are then passed, in order, through
     * {@link #resolveRootRecursion}, {@link #selectContext},
     * {@link #resolvePendingBackfills}, {@link #filterForbidden} and
     * {@link #filterByDisassembly}. Errors gathered while resolving states are appended
     * to the final results.
     *
     * @return the resolutions that survived every stage, plus the collected errors
     */
    public AssemblyResolutionResults resolve() {
        AssemblyResolvedPatterns empty = AssemblyResolution.nop("Empty");
        AssemblyConstructStateGenerator rootGen = new AssemblyConstructStateGenerator(this, this.tree, empty);
        // NOTE(review): the element type expected by the states' resolve(...) method is not
        // visible in this file, so this sink stays raw; parameterize it once confirmed.
        ArrayList errors = new ArrayList();
        Stream<AssemblyGeneratedPrototype> protStream = rootGen.generate(new AbstractAssemblyStateGenerator.GeneratorContext(List.of(), 0));
        if (DBG == DbgTimer.ACTIVE) {
            // Materialize the stream so every prototype is printed inside the timed section.
            try (DbgTimer.DbgCtx dc = DBG.start("Prototypes:")) {
                protStream = protStream.map(prot -> {
                    DBG.println(prot);
                    return prot;
                }).collect(Collectors.toList()).stream();
            }
        }
        // A parameterized stream is required here: with a raw Stream the method reference
        // below would be matched against a raw Consumer taking Object.
        Stream<AssemblyResolution> patStream = protStream.map(p -> p.state).distinct().flatMap(s -> s.resolve(empty, errors));
        AssemblyResolutionResults results = new AssemblyResolutionResults();
        patStream.forEach(results::add);
        results = this.resolveRootRecursion(results);
        results = this.selectContext(results);
        results = this.resolvePendingBackfills(results);
        results = this.filterForbidden(results);
        results = this.filterByDisassembly(results);
        results.addAll(errors);
        return results;
    }

    /**
     * Look up the purely recursive production for the root's left-hand side.
     *
     * <p>Asserts that {@link #tree} has no parent, then asks the grammar for the pure
     * recursion of the non-terminal on the left of the root production.
     *
     * @return whatever {@code getPureRecursion} yields; {@link #resolveRootRecursion}
     *         treats {@code null} as "nothing to do"
     */
    protected AssemblyProduction getRootRecursion() {
        assert (this.tree.getParent() == null);
        AssemblyNonTerminal startSymbol = (AssemblyNonTerminal) this.tree.getProduction().getLHS();
        return this.grammar.getPureRecursion(startSymbol);
    }

    /**
     * Apply root-level recursive constructors to the given results.
     *
     * <p>If {@link #getRootRecursion()} returns {@code null}, the input is returned
     * unchanged. Otherwise error entries are carried over as they are. For every other
     * entry the context graph is asked for constructor paths in the
     * {@code "instruction"} table from this resolver's context to the entry's context,
     * and each path is applied via {@link #applyRecursionPath}.
     *
     * @param temp the results to process
     * @return {@code temp} itself when there is no root recursion, otherwise a new result set
     */
    public AssemblyResolutionResults resolveRootRecursion(AssemblyResolutionResults temp) {
        AssemblyProduction rootRec = this.getRootRecursion();
        if (rootRec == null) {
            return temp;
        }
        try (DbgTimer.DbgCtx dc = DBG.start("Resolving root recursion:")) {
            AssemblyResolutionResults result = new AssemblyResolutionResults();
            // Source context and table name are the same for every entry.
            AssemblyPatternBlock src = this.context;
            String table = "instruction";
            for (AssemblyResolution ar : temp) {
                if (ar.isError()) {
                    result.add(ar);
                    continue;
                }
                AssemblyResolvedPatterns rc = (AssemblyResolvedPatterns) ar;
                AssemblyPatternBlock dst = rc.getContext();
                DBG.println("Finding paths from " + src + " to " + ar.lineToString());
                Collection<Deque<AssemblyConstructorSemantic>> paths = this.ctxGraph.computeOptimalApplications(src, table, dst, table);
                DBG.println("Found " + paths.size());
                for (Deque<AssemblyConstructorSemantic> path : paths) {
                    DBG.println("  " + path);
                    result.absorb(this.applyRecursionPath(path, this.tree, rootRec, rc));
                }
            }
            return result;
        }
    }

    /**
     * Make a last attempt at the backfills still outstanding, then reject what remains.
     *
     * <p>First pass: for each entry that has backfills, {@link #INST_NEXT} and
     * {@link #INST_NEXT2} are stored in {@link #vals} and the entry is backfilled with
     * {@link #SOLVER}. Second pass: any entry that still has backfills is replaced by a
     * "Solution is incomplete" error.
     *
     * @param temp the results to process
     * @return the results after both passes
     */
    protected AssemblyResolutionResults resolvePendingBackfills(AssemblyResolutionResults temp) {
        AssemblyResolutionResults attempted = temp.apply(rc -> {
            if (rc.hasBackfills()) {
                this.vals.put(INST_NEXT, this.at.add(rc.getInstructionLength()).getAddressableWordOffset());
                // NOTE(review): inst_next2 receives exactly the same value as inst_next;
                // presumably a placeholder rather than a real "next-next" address - confirm.
                this.vals.put(INST_NEXT2, this.at.add(rc.getInstructionLength()).getAddressableWordOffset());
                DBG.println("Backfilling: " + rc);
                AssemblyResolution filled = rc.backfill(SOLVER, this.vals);
                DBG.println("Backfilled final: " + filled);
                return filled;
            }
            return rc;
        });
        return attempted.apply(rc -> {
            if (!rc.hasBackfills()) {
                return rc;
            }
            return AssemblyResolution.error("Solution is incomplete", "failed backfill", List.of(rc), null);
        });
    }

    /**
     * Combine every entry with this resolver's context.
     *
     * <p>A context-only resolution is built from {@link #context} and combined with each
     * entry. When {@code combine} returns {@code null}, the entry is replaced by an
     * "Incompatible context" error.
     *
     * @param temp the results to process
     * @return the combined results
     */
    protected AssemblyResolutionResults selectContext(AssemblyResolutionResults temp) {
        AssemblyResolvedPatterns ctx = AssemblyResolution.contextOnly(this.context, "Selecting context");
        return temp.apply(rc -> {
            AssemblyResolvedPatterns combined = rc.combine(ctx);
            if (combined != null) {
                return combined;
            }
            return AssemblyResolution.error("Incompatible context", "resolving", List.of(rc), null);
        });
    }

    /**
     * Run {@code checkNotForbidden()} on every entry of the given results.
     *
     * @param temp the results to process
     * @return the results produced by applying the check
     */
    protected AssemblyResolutionResults filterForbidden(AssemblyResolutionResults temp) {
        AssemblyResolutionResults checked = temp.apply(rc -> {
            return rc.checkNotForbidden();
        });
        return checked;
    }

    /**
     * Keep only the entries whose bytes disassemble back to an equivalent construct state.
     *
     * <p>Each entry's instruction bytes are placed in a buffer at {@link #at} and parsed
     * with {@link #lang}, using a default context whose context register is set from
     * {@link #context}. An entry is replaced by an error when parsing throws
     * {@link InsufficientBytesException} or {@link UnknownInstructionException}, or when
     * the parsed prototype's root state is not equivalent to the entry's.
     *
     * @param temp the results to process
     * @return the filtered results
     */
    protected AssemblyResolutionResults filterByDisassembly(AssemblyResolutionResults temp) {
        AssemblyDefaultContext asmCtx = new AssemblyDefaultContext(this.lang);
        asmCtx.setContextRegister(this.context);
        return temp.apply(rc -> {
            ByteMemBufferImpl buf = new ByteMemBufferImpl(this.at, rc.getInstruction().getVals(), this.lang.isBigEndian());
            try {
                SleighInstructionPrototype parsed = (SleighInstructionPrototype) this.lang.parse(buf, asmCtx, false);
                if (rc.equivalentConstructState(parsed.getRootState())) {
                    return rc;
                }
                return AssemblyResolution.error("Disassembly prototype mismatch", rc);
            }
            catch (InsufficientBytesException | UnknownInstructionException e) {
                return AssemblyResolution.error("Disassembly failed: " + e.getMessage(), rc);
            }
        });
    }

    /**
     * Choose the state generator for one operand.
     *
     * <p>A {@code null} node is delegated to {@link #getHiddenStateGenerator}. A numeric
     * node gets an operand generator and a constructor node gets a construct generator,
     * checked in that order.
     *
     * @param opSym the operand symbol
     * @param node the parse-tree node for the operand, or {@code null}
     * @param fromLeft the patterns accumulated so far, passed through to the generator
     * @return the generator for the operand
     * @throws AssertionError if the node is neither numeric nor a constructor
     */
    protected AbstractAssemblyStateGenerator<?> getStateGenerator(OperandSymbol opSym, AssemblyParseTreeNode node, AssemblyResolvedPatterns fromLeft) {
        final AbstractAssemblyStateGenerator<?> generator;
        if (node == null) {
            generator = this.getHiddenStateGenerator(opSym, fromLeft);
        }
        else if (node.isNumeric()) {
            generator = new AssemblyOperandStateGenerator(this, (AssemblyParseNumericToken) node, opSym, fromLeft);
        }
        else if (node.isConstructor()) {
            generator = new AssemblyConstructStateGenerator(this, (AssemblyParseBranch) node, fromLeft);
        }
        else {
            throw new AssertionError();
        }
        return generator;
    }

    /**
     * Choose the state generator for an operand that has no parse-tree node.
     *
     * <p>When the operand's defining symbol is a {@link SubtableSymbol}, a hidden
     * construct generator is returned for that subtable; otherwise a nop generator.
     *
     * @param opSym the operand symbol
     * @param fromLeft the patterns accumulated so far, passed through to the generator
     * @return the generator for the hidden operand
     */
    protected AbstractAssemblyStateGenerator<?> getHiddenStateGenerator(OperandSymbol opSym, AssemblyResolvedPatterns fromLeft) {
        TripleSymbol defining = opSym.getDefiningSymbol();
        if (!(defining instanceof SubtableSymbol)) {
            return new AssemblyNopStateGenerator(this, opSym, fromLeft);
        }
        return new AssemblyHiddenConstructStateGenerator(this, (SubtableSymbol) defining, fromLeft);
    }

    /**
     * Apply one constructor's semantics to the results coming from its children.
     *
     * <p>Runs, in order, {@link #applyMutations}, {@link #applyPatterns} and
     * {@link #tryResolveBackfills}.
     *
     * @param sem the constructor semantic to apply
     * @param shift the shift forwarded to {@link #applyPatterns}
     * @param fromChildren the results produced for the constructor's children
     * @return the results after all three stages
     */
    protected AssemblyResolutionResults resolvePatterns(AssemblyConstructorSemantic sem, int shift, AssemblyResolutionResults fromChildren) {
        AssemblyResolutionResults mutated = this.applyMutations(sem, fromChildren);
        AssemblyResolutionResults patterned = this.applyPatterns(sem, shift, mutated);
        return this.tryResolveBackfills(patterned);
    }

    /**
     * Call {@code parent(description, opCount)} on every entry and collect the outcomes.
     *
     * @param description the description forwarded to each entry
     * @param temp the results to process
     * @param opCount the operand count forwarded to each entry
     * @return a new result set holding the value returned for each entry
     */
    protected AssemblyResolutionResults parent(String description, AssemblyResolutionResults temp, int opCount) {
        AssemblyResolutionResults wrapped = new AssemblyResolutionResults();
        for (AssemblyResolution entry : temp) {
            wrapped.add(entry.parent(description, opCount));
        }
        return wrapped;
    }

    /**
     * Solve a constructor's context changes against every entry.
     *
     * <p>First pass: each entry is replaced by the outcome of
     * {@code sem.solveContextChanges}. Second pass: each entry is replaced by the
     * outcome of its own {@code solveContextChangesForForbids}. Both passes receive
     * {@link #vals}.
     *
     * @param sem the constructor semantic whose context changes are solved
     * @param temp the results to process
     * @return the results after both passes
     */
    protected AssemblyResolutionResults applyMutations(AssemblyConstructorSemantic sem, AssemblyResolutionResults temp) {
        DBG.println("Applying context mutations:");
        AssemblyResolutionResults mutated = temp.apply(rc -> {
            DBG.println("Current: " + rc.lineToString());
            AssemblyResolution solved = sem.solveContextChanges((AssemblyResolvedPatterns) rc, this.vals);
            DBG.println("Mutated: " + solved.lineToString());
            return solved;
        });
        return mutated.apply(rc -> rc.solveContextChangesForForbids(sem, this.vals));
    }

    /**
     * Combine a constructor's patterns, shifted by {@code shift}, with every entry.
     *
     * <p>The shifted patterns are computed once and offered for every entry through an
     * {@code Applicator}. That applicator keeps the combined result as it is, describes
     * conflicts as "The patterns conflict ...", never expects a backfill among the
     * patterns, and finishes each entry with {@code checkNotForbidden()}.
     *
     * @param sem the constructor semantic supplying the patterns
     * @param shift the amount passed to each pattern's {@code shift}
     * @param temp the results to process
     * @return the results of applying the patterns
     */
    protected AssemblyResolutionResults applyPatterns(AssemblyConstructorSemantic sem, int shift, AssemblyResolutionResults temp) {
        DBG.println("Applying patterns:");
        // Parameterized to match getPatterns' return type, replacing the raw Collection.
        final Collection<? extends AssemblyResolution> patterns = sem.getPatterns().stream().map(p -> p.shift(shift)).collect(Collectors.toList());
        return temp.apply(new AssemblyResolutionResults.Applicator(){

            @Override
            public Iterable<? extends AssemblyResolution> getPatterns(AssemblyResolvedPatterns cur) {
                // The same pre-shifted patterns are used regardless of the current entry.
                return patterns;
            }

            @Override
            public AssemblyResolvedPatterns setRight(AssemblyResolvedPatterns res, AssemblyResolvedPatterns cur) {
                return res;
            }

            @Override
            public String describeError(AssemblyResolvedPatterns rc, AssemblyResolution pat) {
                return "The patterns conflict " + pat.lineToString();
            }

            @Override
            public AssemblyResolvedPatterns combineBackfill(AssemblyResolvedPatterns cur, AssemblyResolvedBackfill bf) {
                // Reaching this is treated as a programming error.
                throw new AssertionError();
            }

            @Override
            public AssemblyResolution finish(AssemblyResolvedPatterns resolved) {
                return resolved.checkNotForbidden();
            }
        });
    }

    /**
     * Wrap a resolved child in the constructors of a recursion path.
     *
     * <p>Constructors are taken from the tail of {@code path} until it is empty, so the
     * deque is consumed by this call. For each one the current results are re-parented
     * with an operand count of 1, shifted by the relative offset of the constructor's
     * operand at {@code sem.getOperandIndex(0)}, resolved against the constructor via
     * {@link #resolvePatterns} with a shift of 0, and tagged with the constructor.
     *
     * @param path the constructor semantics to apply; emptied by this method
     * @param branch not used by this implementation
     * @param rec not used by this implementation
     * @param child the starting resolution
     * @return the results after the whole path has been applied
     * @throws AssertionError if an operand on the path has an offset base other than -1
     */
    protected AssemblyResolutionResults applyRecursionPath(Deque<AssemblyConstructorSemantic> path, AssemblyParseBranch branch, AssemblyProduction rec, AssemblyResolvedPatterns child) {
        AssemblyResolutionResults results = new AssemblyResolutionResults();
        results.add(child);
        while (!path.isEmpty()) {
            final AssemblyConstructorSemantic sem = path.pollLast();
            final int operandIndex = sem.getOperandIndex(0);
            final Constructor cons = sem.getConstructor();
            final OperandSymbol operand = cons.getOperand(operandIndex);
            if (operand.getOffsetBase() != -1) {
                // Operands positioned relative to another operand are not handled here.
                throw new AssertionError("TODO");
            }
            final int offset = operand.getRelativeOffset();
            final String description = "Resolving recursive constructor: " + cons.getSourceFile() + ":" + cons.getLineno();

            AssemblyResolutionResults wrapped = this.parent(description, results, 1);
            AssemblyResolutionResults shifted = wrapped.apply(rc -> rc.shift(offset));
            AssemblyResolutionResults resolved = this.resolvePatterns(sem, 0, shifted);
            results = resolved.apply(rc -> rc.withConstructor(cons));
        }
        return results;
    }

    /**
     * Try to solve the backfills of every entry with the values currently known.
     *
     * <p>Error entries are carried over unchanged. Every other entry is backfilled
     * repeatedly with {@link #SOLVER} and {@link #vals} until one of these holds, and
     * the resolution reached at that point is what gets recorded:
     * <ul>
     * <li>it has no backfills left;</li>
     * <li>backfilling returned an error or a backfill;</li>
     * <li>backfilling returned something equal to its input, i.e. no progress.</li>
     * </ul>
     *
     * @param results the results to process
     * @return a new result set with one recorded resolution per input entry
     */
    protected AssemblyResolutionResults tryResolveBackfills(AssemblyResolutionResults results) {
        AssemblyResolutionResults res = new AssemblyResolutionResults();
        for (AssemblyResolution entry : results) {
            AssemblyResolution ar = entry;
            if (!ar.isError()) {
                while (true) {
                    AssemblyResolvedPatterns rc = (AssemblyResolvedPatterns) ar;
                    if (!rc.hasBackfills()) {
                        // Nothing left to fill in.
                        break;
                    }
                    ar = rc.backfill(SOLVER, this.vals);
                    if (ar.isError() || ar.isBackfill() || ar.equals(rc)) {
                        // Failed outright, or another attempt would not make progress.
                        break;
                    }
                    // Some backfills were solved but others may remain: go around again.
                }
            }
            res.add(ar);
        }
        return res;
    }

    /**
     * Compute an operand's offset within its constructor.
     *
     * <p>The result is the operand's relative offset. When its offset base is not -1,
     * the minimum length of the base operand and the base operand's own offset
     * (computed recursively) are added.
     *
     * @param opsym the operand whose offset is wanted
     * @param cons the constructor used to look up base operands by index
     * @return the accumulated offset
     */
    public static int computeOffset(OperandSymbol opsym, Constructor cons) {
        final int relative = opsym.getRelativeOffset();
        final int baseIndex = opsym.getOffsetBase();
        if (baseIndex == -1) {
            return relative;
        }
        final OperandSymbol base = cons.getOperand(baseIndex);
        return relative + base.getMinimumLength() + AssemblyTreeResolver.computeOffset(base, cons);
    }

    /**
     * Solve an expression for a masked goal, or record it as a backfill.
     *
     * <p>The solve is delegated to {@link #SOLVER}. If it throws
     * {@link NeedsBackfillException}, a backfill resolution is returned instead, sized
     * by {@code SOLVER.getInstructionLength(exp)}.
     *
     * @param exp the expression to solve
     * @param goal the masked value the expression should take
     * @param vals the named values available to the solver
     * @param cur the resolution the solution is computed against
     * @param description a description forwarded to the result
     * @return the solver's result, or a backfill record
     */
    protected static AssemblyResolution solveOrBackfill(PatternExpression exp, MaskedLong goal, Map<String, Long> vals, AssemblyResolvedPatterns cur, String description) {
        AssemblyResolution outcome;
        try {
            outcome = SOLVER.solve(exp, goal, vals, cur, description);
        }
        catch (NeedsBackfillException needsBackfill) {
            // Not solvable with the values at hand: defer it as a backfill.
            int fieldLength = SOLVER.getInstructionLength(exp);
            outcome = AssemblyResolution.backfill(exp, goal, fieldLength, description);
        }
        return outcome;
    }

    /**
     * Solve an expression for a plain {@code long} goal, or record it as a backfill.
     *
     * <p>The goal is wrapped with {@code MaskedLong.fromLong} and passed to the
     * {@link MaskedLong} overload.
     *
     * @param exp the expression to solve
     * @param goal the value the expression should take
     * @param vals the named values available to the solver
     * @param cur the resolution the solution is computed against
     * @param description a description forwarded to the result
     * @return the solver's result, or a backfill record
     */
    protected static AssemblyResolution solveOrBackfill(PatternExpression exp, long goal, Map<String, Long> vals, AssemblyResolvedPatterns cur, String description) {
        MaskedLong wrapped = MaskedLong.fromLong(goal);
        return AssemblyTreeResolver.solveOrBackfill(exp, wrapped, vals, cur, description);
    }

    /**
     * Solve an expression for a goal limited to a number of bits, or record it as a backfill.
     *
     * <p>The mask has its low {@code bits} bits set. A {@code bits} value of 0, or of 64
     * and above, selects an all-ones mask instead. The mask and goal are wrapped with
     * {@code MaskedLong.fromMaskAndValue} and passed to the {@link MaskedLong} overload.
     *
     * @param exp the expression to solve
     * @param goal the value the expression should take
     * @param bits the number of low bits of {@code goal} that matter
     * @param vals the named values available to the solver
     * @param cur the resolution the solution is computed against
     * @param description a description forwarded to the result
     * @return the solver's result, or a backfill record
     */
    protected static AssemblyResolution solveOrBackfill(PatternExpression exp, long goal, int bits, Map<String, Long> vals, AssemblyResolvedPatterns cur, String description) {
        final long msk;
        if (bits == 0 || bits >= 64) {
            msk = -1L;
        }
        else {
            msk = ~(-1L << bits);
        }
        return AssemblyTreeResolver.solveOrBackfill(exp, MaskedLong.fromMaskAndValue(msk, goal), vals, cur, description);
    }
}

