/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.sql.calcite.rule;

import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.rules.SubstitutionRule;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.query.extraction.ExtractionFn;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.query.lookup.LookupExtractionFn;
import org.apache.druid.query.lookup.LookupExtractor;
import org.apache.druid.sql.calcite.expression.builtin.MultiValueStringOperatorConversions;
import org.apache.druid.sql.calcite.expression.builtin.QueryLookupOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.ScalarInArrayOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.SearchOperatorConversion;
import org.apache.druid.sql.calcite.filtration.CollectComparisons;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.utils.CollectionUtils;

/**
 * Planner rule that rewrites filter conditions comparing the output of {@code LOOKUP(arg, 'name'[, 'replaceMissing'])}
 * against literals into equivalent conditions on the lookup's input argument ("reverse lookup").
 *
 * <p>The rule matches a {@link LogicalFilter} and walks its condition with a {@link ReverseLookupShuttle}. When the
 * shuttle changes the condition, the rule replaces the filter with a new one over the same input and prunes the
 * original. Computing the reversed value set is delegated to {@link InDimFilter#optimizeLookup}.
 */
public class ReverseLookupRule extends RelOptRule implements SubstitutionRule {
    /**
     * Query context key bounding how many reverse-lookup computations a single rule invocation may perform. Exceeding
     * it throws {@link ISE}. When absent, {@link Integer#MAX_VALUE} is used.
     */
    public static final String CTX_MAX_OPTIMIZE_COUNT = "maxOptimizeCountForDruidReverseLookupRule";

    /**
     * Query context key capping the size of a reversed value set. The effective cap is the minimum of this value
     * (default {@link #DEFAULT_THRESHOLD}) and the query's in-subquery threshold.
     */
    public static final String CTX_THRESHOLD = "sqlReverseLookupThreshold";

    /** Default value for {@link #CTX_THRESHOLD}. */
    public static final int DEFAULT_THRESHOLD = 10000;

    private final PlannerContext plannerContext;

    public ReverseLookupRule(final PlannerContext plannerContext) {
        super(operand(LogicalFilter.class, any()));
        this.plannerContext = plannerContext;
    }

    /**
     * Rewrites the matched filter's condition and, if anything changed, registers the rewritten filter as a
     * replacement and prunes the original.
     */
    @Override
    public void onMatch(final RelOptRuleCall call) {
        final Filter filter = call.rel(0);
        final int maxOptimizeCount =
            plannerContext.queryContext().getInt(CTX_MAX_OPTIMIZE_COUNT, Integer.MAX_VALUE);
        // Reversed sets are capped at the smaller of the in-subquery threshold and CTX_THRESHOLD.
        final int maxInSize = Math.min(
            plannerContext.queryContext().getInSubQueryThreshold(),
            plannerContext.queryContext().getInt(CTX_THRESHOLD, DEFAULT_THRESHOLD)
        );

        final RexVisitor<RexNode> shuttle = new ReverseLookupShuttle(
            plannerContext,
            filter.getCluster().getRexBuilder(),
            maxOptimizeCount,
            maxInSize
        );
        final RexNode newCondition = filter.getCondition().accept(shuttle);

        // Identity comparison: an unchanged condition comes back as the very same instance.
        if (newCondition != filter.getCondition()) {
            call.transformTo(call.builder().push(filter.getInput()).filter(newCondition).build());
            call.getPlanner().prune(filter);
        }
    }

    /**
     * Converts strings into literal nodes; {@code null} becomes a VARCHAR-typed null literal.
     *
     * @return a new, mutable list with one node per input string, in iteration order
     */
    private static List<RexNode> stringsToRexNodes(final Iterable<String> strings, final RexBuilder rexBuilder) {
        final List<RexNode> retVal = new ArrayList<>();
        for (final String s : strings) {
            if (s == null) {
                retVal.add(rexBuilder.makeNullLiteral(rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR)));
            } else {
                retVal.add(rexBuilder.makeLiteral(s));
            }
        }
        return retVal;
    }

    /** Whether {@code operator} is MV_CONTAINS or MV_OVERLAP. */
    private static boolean isMultiValueStringOperator(final SqlOperator operator) {
        return operator.equals(MultiValueStringOperatorConversions.CONTAINS.calciteOperator())
               || operator.equals(MultiValueStringOperatorConversions.OVERLAP.calciteOperator());
    }

    /**
     * Whether {@code operator} is MV_CONTAINS, MV_OVERLAP, or SCALAR_IN_ARRAY. For these operators, null values on the
     * literal side are kept as matches (see {@code getMatchValues}).
     */
    private static boolean isArrayMembershipOperator(final SqlOperator operator) {
        return isMultiValueStringOperator(operator) || operator.equals(ScalarInArrayOperatorConversion.SQL_FUNCTION);
    }

    /**
     * Whether {@code rexNode} is a call this rule treats as a two-operand comparison: {@code =}, {@code <>},
     * MV_CONTAINS, MV_OVERLAP, or SCALAR_IN_ARRAY.
     */
    private static boolean isBinaryComparison(final RexNode rexNode) {
        if (!(rexNode instanceof RexCall)) {
            return false;
        }
        final RexCall call = (RexCall) rexNode;
        return call.getKind() == SqlKind.EQUALS
               || call.getKind() == SqlKind.NOT_EQUALS
               || isArrayMembershipOperator(call.getOperator());
    }

    /** Whether {@code expr} is a call to the {@code LOOKUP} function. */
    static boolean isLookupCall(final RexNode expr) {
        return expr.isA(SqlKind.OTHER_FUNCTION)
               && ((RexCall) expr).getOperator().equals(QueryLookupOperatorConversion.SQL_FUNCTION);
    }

    /**
     * Converts the literal side of a lookup comparison into the set of strings it matches.
     *
     * @param literal    a null literal, a string literal, or a string-array constructor call
     * @param matchNulls whether null values count as matches; when false they are dropped from the result
     * @return the match set, or {@code null} if {@code literal} is not in one of the supported forms
     */
    @Nullable
    private static Set<String> toStringSet(final RexNode literal, final boolean matchNulls) {
        if (RexUtil.isNullLiteral(literal, true)) {
            return matchNulls ? Collections.singleton(null) : Collections.emptySet();
        }

        if (SqlTypeFamily.STRING.contains(literal.getType())) {
            final String s = RexLiteral.stringValue(literal);
            return s != null || matchNulls ? Collections.singleton(s) : Collections.emptySet();
        }

        // Only array *constructor calls* are supported; anything else is reported as unsupported (null)
        // rather than failing with a ClassCastException.
        if (literal.getType().getSqlTypeName() == SqlTypeName.ARRAY
            && SqlTypeFamily.STRING.contains(literal.getType().getComponentType())
            && literal instanceof RexCall) {
            final Set<String> elements = new HashSet<>();
            for (final RexNode element : ((RexCall) literal).getOperands()) {
                final String s = RexLiteral.stringValue(element);
                if (s != null || matchNulls) {
                    elements.add(s);
                }
            }
            return elements;
        }

        return null;
    }

    /**
     * Grouping key for lookup comparisons that can be reversed together. Two comparisons share a key when they agree on
     * the lookup argument, lookup name, replace-missing value, operator flavor (multi-value or not), and polarity
     * (negated or not).
     */
    private static final class ReverseLookupKey {
        private final RexNode arg;
        private final String lookupName;
        private final String replaceMissingValueWith;
        private final boolean multiValue;
        private final boolean negate;

        private ReverseLookupKey(
            final RexNode arg,
            final String lookupName,
            final String replaceMissingValueWith,
            final boolean multiValue,
            final boolean negate
        ) {
            this.arg = arg;
            this.lookupName = lookupName;
            this.replaceMissingValueWith = replaceMissingValueWith;
            this.multiValue = multiValue;
            this.negate = negate;
        }

        @Override
        public boolean equals(final Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            final ReverseLookupKey that = (ReverseLookupKey) o;
            return multiValue == that.multiValue
                   && negate == that.negate
                   && Objects.equals(arg, that.arg)
                   && Objects.equals(lookupName, that.lookupName)
                   && Objects.equals(replaceMissingValueWith, that.replaceMissingValueWith);
        }

        @Override
        public int hashCode() {
            return Objects.hash(arg, lookupName, replaceMissingValueWith, multiValue, negate);
        }
    }

    /**
     * Shuttle that rewrites lookup comparisons within a filter condition.
     *
     * <p>OR operands, negated AND operands, expanded SEARCH calls, and standalone comparisons are each handed to
     * {@link CollectReverseLookups}, so that several comparisons against the same lookup can be reversed together.
     * Instances carry mutable traversal state; {@link ReverseLookupRule#onMatch} creates a fresh one per invocation.
     */
    static class ReverseLookupShuttle extends RexShuttle {
        private final PlannerContext plannerContext;
        private final RexBuilder rexBuilder;
        private final int maxOptimizeCount;
        private final int maxInSize;

        /**
         * Comparisons already handed to a {@link CollectReverseLookups} as part of an OR/AND group. visitCall does not
         * treat these as standalone comparisons again.
         */
        private final Set<RexNode> consideredAsChild = new HashSet<>();

        /**
         * Flipped when descending below NOT and while collecting negated AND operands. It is XOR-ed with the key's
         * negate flag and passed to reverseLookup as {@code mayIncludeUnknown}, which only takes effect under
         * three-valued logic.
         */
        private boolean includeUnknown = false;

        /** Number of reverseLookup computations so far; bounded by {@link #maxOptimizeCount}. */
        private int optimizeCount = 0;

        public ReverseLookupShuttle(
            final PlannerContext plannerContext,
            final RexBuilder rexBuilder,
            final int maxOptimizeCount,
            final int maxInSize
        ) {
            this.plannerContext = plannerContext;
            this.rexBuilder = rexBuilder;
            this.maxOptimizeCount = maxOptimizeCount;
            this.maxInSize = maxInSize;
        }

        @Override
        public RexNode visitCall(final RexCall call) {
            if (call.getKind() == SqlKind.NOT) {
                return visitNot(call);
            }
            if (call.getKind() == SqlKind.AND) {
                return visitAnd(call);
            }
            if (call.getKind() == SqlKind.OR) {
                return visitOr(call);
            }
            if (call.isA(SqlKind.SEARCH)) {
                return visitSearch(call);
            }
            if ((call.isA(SqlKind.IS_NULL) || ReverseLookupRule.isBinaryComparison(call))
                && !this.consideredAsChild.contains(call)) {
                return visitComparison(call);
            }
            return super.visitCall(call);
        }

        /** Visits the operand of a NOT with {@link #includeUnknown} flipped (three-valued logic only), then restores it. */
        private RexNode visitNot(final RexCall call) {
            this.includeUnknown = NullHandling.useThreeValueLogic() && !this.includeUnknown;
            final RexNode retVal = super.visitCall(call);
            this.includeUnknown = NullHandling.useThreeValueLogic() && !this.includeUnknown;
            return retVal;
        }

        /** Collects reverse lookups across all OR operands at once. */
        private RexNode visitOr(final RexCall call) {
            this.consideredAsChild.addAll(call.getOperands());
            final List<RexNode> newOperands = new CollectReverseLookups(call.getOperands(), this.rexBuilder).collect();
            if (newOperands != call.getOperands()) {
                return RexUtil.composeDisjunction(this.rexBuilder, newOperands);
            }
            return super.visitCall(call);
        }

        /**
         * Applies De Morgan's law to the negated operands of an AND.
         *
         * <p>{@code AND(NOT a, x <> y, z IS NOT NULL, rest...)} is treated as
         * {@code AND(rest..., NOT(OR(a, x = y, z IS NULL)))}, so the negated operands can be collected as a single
         * disjunction. When that collection changes something, the rebuilt condition is returned and the
         * {@code rest...} operands are not visited in this pass.
         */
        private RexNode visitAnd(final RexCall call) {
            final List<RexNode> notOrs = new ArrayList<>();
            final List<RexNode> remainder = new ArrayList<>();

            for (final RexNode operand : call.getOperands()) {
                if (operand.isA(SqlKind.NOT)) {
                    // visitNot would pass the node *beneath* the NOT to visitCall, so that node is the one to mark.
                    final RexNode nodeBeneathNot = Iterables.getOnlyElement(((RexCall) operand).getOperands());
                    this.consideredAsChild.add(nodeBeneathNot);
                    notOrs.add(nodeBeneathNot);
                } else if (operand.isA(SqlKind.NOT_EQUALS)) {
                    this.consideredAsChild.add(operand);
                    notOrs.add(this.rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, ((RexCall) operand).getOperands()));
                } else if (operand.isA(SqlKind.IS_NOT_NULL)) {
                    this.consideredAsChild.add(operand);
                    notOrs.add(this.rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, ((RexCall) operand).getOperands()));
                } else {
                    remainder.add(operand);
                }
            }

            if (!notOrs.isEmpty()) {
                // The collected operands sit beneath an implied NOT, so flip includeUnknown while collecting.
                this.includeUnknown = !this.includeUnknown;
                final List<RexNode> newNotOrs = new CollectReverseLookups(notOrs, this.rexBuilder).collect();
                this.includeUnknown = !this.includeUnknown;

                if (newNotOrs != notOrs) {
                    RexNode retVal = this.rexBuilder.makeCall(
                        SqlStdOperatorTable.NOT,
                        RexUtil.composeDisjunction(this.rexBuilder, newNotOrs)
                    );
                    if (!remainder.isEmpty()) {
                        remainder.add(retVal);
                        retVal = this.rexBuilder.makeCall(SqlStdOperatorTable.AND, remainder);
                    }
                    return retVal;
                }
            }

            return super.visitCall(call);
        }

        /**
         * Expands a SEARCH call and tries to rewrite the expansion. If nothing changes, the original SEARCH call is
         * returned rather than the expanded form.
         */
        private RexNode visitSearch(final RexCall call) {
            final RexNode expanded = SearchOperatorConversion.expandSearch(
                call,
                this.rexBuilder,
                this.plannerContext.queryContext().getInFunctionThreshold()
            );
            if (expanded instanceof RexCall) {
                final RexNode converted = visitCall((RexCall) expanded);
                if (converted != expanded) {
                    return converted;
                }
            }
            return call;
        }

        /** Tries to reverse a single standalone comparison; otherwise it visits the comparison's children. */
        private RexNode visitComparison(final RexCall call) {
            final RexNode retVal = CollectionUtils.getOnlyElement(
                new CollectReverseLookups(Collections.singletonList(call), this.rexBuilder).collect(),
                ret -> new ISE("Expected to collect single node, got[%s]", ret)
            );
            if (retVal != call) {
                return retVal;
            }
            return super.visitCall(call);
        }

        /**
         * Supplies {@link CollectComparisons} with the pieces it needs:
         * <ul>
         *   <li>how to recognize lookup comparisons;</li>
         *   <li>how to key them by {@link ReverseLookupKey};</li>
         *   <li>how to extract their match values;</li>
         *   <li>how to build the reversed replacement condition.</li>
         * </ul>
         * The grouping itself happens in {@code CollectComparisons#collect}, which is not visible here.
         *
         * <p>References to the enclosing shuttle are qualified with {@code ReverseLookupShuttle.this}. Likewise,
         * static helpers are qualified with {@code ReverseLookupRule}. This keeps inherited members of
         * {@link CollectComparisons} from shadowing them.
         */
        private class CollectReverseLookups
            extends CollectComparisons<RexNode, RexCall, RexNode, ReverseLookupKey, String, InDimFilter.ValuesSet> {
            private final RexBuilder rexBuilder;

            private CollectReverseLookups(final List<RexNode> orExprs, final RexBuilder rexBuilder) {
                super(orExprs);
                this.rexBuilder = rexBuilder;
            }

            @Override
            @Nullable
            protected Pair<RexCall, List<RexNode>> getCollectibleComparison(final RexNode expr) {
                final RexCall asLookupComparison = getAsLookupComparison(expr);
                if (asLookupComparison != null) {
                    return Pair.of(asLookupComparison, Collections.<RexNode>emptyList());
                }
                return null;
            }

            @Override
            protected InDimFilter.ValuesSet makeCollection() {
                return new InDimFilter.ValuesSet();
            }

            /**
             * Returns the grouping key for a lookup comparison.
             *
             * <p>Returns {@code null} (leave the comparison as-is) when the lookup is unknown. It also returns
             * {@code null} when the lookup is not one-to-one and the comparison may target the replace-missing value,
             * because such comparisons cannot be reversed.
             */
            @Override
            @Nullable
            protected ReverseLookupKey getCollectionKey(final RexCall call) {
                final RexCall lookupCall = (RexCall) call.getOperands().get(0);
                final List<RexNode> lookupOperands = lookupCall.getOperands();
                final RexNode argument = lookupOperands.get(0);
                final String lookupName = RexLiteral.stringValue(lookupOperands.get(1));
                final String replaceMissingValueWith = lookupOperands.size() >= 3
                    ? NullHandling.emptyToNullIfNeeded(RexLiteral.stringValue(lookupOperands.get(2)))
                    : null;

                final LookupExtractor lookup = ReverseLookupShuttle.this.plannerContext.getLookup(lookupName);
                if (lookup == null) {
                    return null;
                }

                if (!lookup.isOneToOne()) {
                    final boolean isComparisonAgainstReplaceMissingValueWith;
                    if (replaceMissingValueWith == null) {
                        isComparisonAgainstReplaceMissingValueWith = call.isA(SqlKind.IS_NULL);
                    } else {
                        final Set<String> matchValues = getMatchValues(call);
                        // Unknown match values (unsupported literal) cannot be ruled out, so don't collect them.
                        isComparisonAgainstReplaceMissingValueWith =
                            matchValues == null || matchValues.contains(replaceMissingValueWith);
                    }
                    if (isComparisonAgainstReplaceMissingValueWith) {
                        return null;
                    }
                }

                final boolean multiValue = ReverseLookupRule.isMultiValueStringOperator(call.getOperator());
                final boolean negate = call.getKind() == SqlKind.NOT_EQUALS;
                return new ReverseLookupKey(argument, lookupName, replaceMissingValueWith, multiValue, negate);
            }

            /**
             * Returns the values a lookup comparison matches: {@code {null}} for IS NULL, otherwise the string set of
             * the literal operand. Returns {@code null} if the literal is in an unsupported form.
             */
            @Override
            @Nullable
            protected Set<String> getMatchValues(final RexCall call) {
                if (call.isA(SqlKind.IS_NULL)) {
                    return Collections.singleton(null);
                }
                final RexNode matchLiteral = call.getOperands().get(1);
                final boolean matchNulls = ReverseLookupRule.isArrayMembershipOperator(call.getOperator());
                return ReverseLookupRule.toStringSet(matchLiteral, matchNulls);
            }

            @Override
            @Nullable
            protected RexNode makeCollectedComparison(
                final ReverseLookupKey reverseLookupKey,
                final InDimFilter.ValuesSet matchValues
            ) {
                final LookupExtractor lookupExtractor =
                    ReverseLookupShuttle.this.plannerContext.getLookup(reverseLookupKey.lookupName);
                if (lookupExtractor == null) {
                    return null;
                }
                final Set<String> reversedMatchValues = reverseLookup(
                    lookupExtractor,
                    reverseLookupKey.replaceMissingValueWith,
                    matchValues,
                    ReverseLookupShuttle.this.includeUnknown ^ reverseLookupKey.negate
                );
                if (reversedMatchValues == null) {
                    return null;
                }
                return makeMatchCondition(reverseLookupKey, reversedMatchValues, this.rexBuilder);
            }

            /** Not supported by this collector; it always throws. */
            @Override
            protected RexNode makeAnd(final List<RexNode> exprs) {
                throw new UnsupportedOperationException();
            }

            /**
             * Normalizes {@code expr} into a lookup comparison and returns it. Two forms are accepted:
             * <ul>
             *   <li>{@code LOOKUP(...) IS NULL};</li>
             *   <li>a binary comparison with the lookup call first and a literal second.</li>
             * </ul>
             * A literal-first comparison is swapped, except for MV_CONTAINS, which is not symmetric. Nullability-only
             * casts are stripped from both sides. Returns {@code null} when {@code expr} is neither form.
             */
            @Nullable
            private RexCall getAsLookupComparison(final RexNode expr) {
                if (expr.isA(SqlKind.IS_NULL)
                    && ReverseLookupRule.isLookupCall(((RexCall) expr).getOperands().get(0))) {
                    return (RexCall) expr;
                }
                if (!ReverseLookupRule.isBinaryComparison(expr)) {
                    return null;
                }

                final RexCall call = (RexCall) expr;
                RexNode lookupCall = call.getOperands().get(0);
                RexNode literal = call.getOperands().get(1);

                if (literal instanceof RexCall && Calcites.isLiteral(lookupCall, true, true)) {
                    if (call.getOperator().equals(MultiValueStringOperatorConversions.CONTAINS.calciteOperator())) {
                        return null;
                    }
                    final RexNode tmp = lookupCall;
                    lookupCall = literal;
                    literal = tmp;
                }

                final RelDataTypeFactory typeFactory = this.rexBuilder.getTypeFactory();
                lookupCall = RexUtil.removeNullabilityCast(typeFactory, lookupCall);
                literal = RexUtil.removeNullabilityCast(typeFactory, literal);

                if (ReverseLookupRule.isLookupCall(lookupCall) && Calcites.isLiteral(literal, true, true)) {
                    return (RexCall) this.rexBuilder.makeCall(call.getOperator(), lookupCall, literal);
                }
                return null;
            }

            /**
             * Computes the set of lookup keys whose lookup results fall in {@code matchValues}.
             *
             * <p>The work is delegated to {@link InDimFilter#optimizeLookup}. A {@code null} result means it declined
             * to optimize, presumably when the reversed set would exceed {@code maxInSize}.
             *
             * @throws ISE if this shuttle has already performed {@code maxOptimizeCount} computations
             */
            @Nullable
            private Set<String> reverseLookup(
                final LookupExtractor lookupExtractor,
                @Nullable final String replaceMissingValueWith,
                final InDimFilter.ValuesSet matchValues,
                final boolean mayIncludeUnknown
            ) {
                ReverseLookupShuttle.this.optimizeCount++;
                if (ReverseLookupShuttle.this.optimizeCount > ReverseLookupShuttle.this.maxOptimizeCount) {
                    throw new ISE("Too many optimize calls[%s]", ReverseLookupShuttle.this.optimizeCount);
                }

                // Typed locals keep the same InDimFilter constructor overload as (Collection, ExtractionFn).
                final Collection<String> values = matchValues;
                final ExtractionFn extractionFn =
                    new LookupExtractionFn(lookupExtractor, false, replaceMissingValueWith, null, Boolean.TRUE);
                final InDimFilter filterToOptimize = new InDimFilter("__dummy__", values, extractionFn);

                return InDimFilter.optimizeLookup(
                    filterToOptimize,
                    mayIncludeUnknown && NullHandling.useThreeValueLogic(),
                    ReverseLookupShuttle.this.maxInSize
                );
            }

            /**
             * Builds a condition on the lookup argument that matches {@code reversedMatchValues}.
             *
             * <ul>
             *   <li>An empty set yields the literal {@code negate} (FALSE, or TRUE for a negated key).</li>
             *   <li>Multi-value keys produce MV_CONTAINS for a single value, otherwise MV_OVERLAP against an array
             *       (wrapped in NOT when negated).</li>
             *   <li>Other keys produce an IN-style condition via {@link SearchOperatorConversion#makeIn}.</li>
             * </ul>
             */
            private RexNode makeMatchCondition(
                final ReverseLookupKey reverseLookupKey,
                final Set<String> reversedMatchValues,
                final RexBuilder rexBuilder
            ) {
                if (reversedMatchValues.isEmpty()) {
                    return rexBuilder.makeLiteral(reverseLookupKey.negate);
                }

                if (reverseLookupKey.multiValue) {
                    final List<RexNode> valueNodes = ReverseLookupRule.stringsToRexNodes(reversedMatchValues, rexBuilder);
                    RexNode condition;
                    if (reversedMatchValues.size() == 1) {
                        condition = rexBuilder.makeCall(
                            MultiValueStringOperatorConversions.CONTAINS.calciteOperator(),
                            reverseLookupKey.arg,
                            Iterables.getOnlyElement(valueNodes)
                        );
                    } else {
                        condition = rexBuilder.makeCall(
                            MultiValueStringOperatorConversions.OVERLAP.calciteOperator(),
                            reverseLookupKey.arg,
                            rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, valueNodes)
                        );
                    }
                    if (reverseLookupKey.negate) {
                        condition = rexBuilder.makeCall(SqlStdOperatorTable.NOT, condition);
                    }
                    return condition;
                }

                return SearchOperatorConversion.makeIn(
                    reverseLookupKey.arg,
                    reversedMatchValues,
                    rexBuilder.getTypeFactory().createTypeWithNullability(
                        rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR),
                        true
                    ),
                    reverseLookupKey.negate,
                    reversedMatchValues.size()
                    >= ReverseLookupShuttle.this.plannerContext.queryContext().getInFunctionThreshold(),
                    rexBuilder
                );
            }
        }
    }
}

