/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
package org.elasticsearch.xpack.core.ml.inference;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xcontent.DeprecationHandler;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;

/**
 * Serialization and parsing tests for {@link TrainedModelDefinition}.
 *
 * <p>Besides the round-trip checks inherited from {@link AbstractSerializingTestCase}, this class
 * exposes random builders and two hand-written JSON fixtures ({@link #ENSEMBLE_MODEL} and
 * {@link #TREE_MODEL}) as public static members.
 */
public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<TrainedModelDefinition> {

    /**
     * Parses an instance for the inherited round-trip checks.
     *
     * <p>NOTE(review): the second argument ({@code true}) is presumably a "lenient" flag; the
     * explicit schema tests below pass {@code false} instead. Confirm against
     * {@code TrainedModelDefinition.fromXContent}.
     */
    @Override
    protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException {
        return TrainedModelDefinition.fromXContent(parser, true).build();
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    /**
     * Returns a predicate that matches every non-empty field path.
     *
     * <p>NOTE(review): presumably the base class skips random-field insertion wherever this
     * predicate matches, which would leave only the root object (empty path) eligible for
     * unknown fields. Confirm against the test framework.
     */
    @Override
    protected Predicate<String> getRandomFieldsExcludeFilter() {
        return field -> field.isEmpty() == false;
    }

    /**
     * Opts out of the base class' toXContent equivalence assertion.
     * NOTE(review): the reason is not visible in this file; confirm before re-enabling.
     */
    @Override
    protected boolean assertToXContentEquivalence() {
        return false;
    }

    /**
     * Creates a builder populated with random content for the given target type.
     *
     * <p>The preprocessor list is randomly either {@code null} or a list of
     * {@code randomIntBetween(1, 10)} encoders, each picked from frequency, one-hot and
     * target-mean encoding. The trained model is randomly either a tree or an ensemble created
     * for {@code targetType}.
     *
     * @param targetType target type handed to the random tree/ensemble factories
     * @return a builder that has not been built yet
     */
    public static TrainedModelDefinition.Builder createRandomBuilder(TargetType targetType) {
        // Drawn up front, even when the preprocessors end up null; moving this would change
        // the sequence of random draws for a given seed.
        int numberOfProcessors = randomIntBetween(1, 10);
        return new TrainedModelDefinition.Builder().setPreProcessors(
            randomBoolean()
                ? null
                : Stream.generate(
                    // All three candidates are created on every call; randomFrom then keeps one.
                    () -> randomFrom(
                        FrequencyEncodingTests.createRandom(),
                        OneHotEncodingTests.createRandom(),
                        TargetMeanEncodingTests.createRandom()
                    )
                ).limit(numberOfProcessors).collect(Collectors.toList())
        ).setTrainedModel(randomFrom(TreeTests.createRandom(targetType), EnsembleTests.createRandom(targetType)));
    }

    /**
     * Creates a random builder using a target type picked at random from {@link TargetType#values()}.
     *
     * @return a builder that has not been built yet
     */
    public static TrainedModelDefinition.Builder createRandomBuilder() {
        return createRandomBuilder(randomFrom(TargetType.values()));
    }

    /**
     * JSON fixture: three preprocessors (one-hot, target-mean and frequency encoding) followed by
     * an {@code ensemble} trained model that combines two regression trees through a
     * {@code weighted_sum} with weights 0.5 and 0.5.
     */
    public static final String ENSEMBLE_MODEL = ""
        + "{\n"
        + "  \"preprocessors\": [\n"
        + "    {\n"
        + "      \"one_hot_encoding\": {\n"
        + "        \"field\": \"col1\",\n"
        + "        \"hot_map\": {\n"
        + "          \"male\": \"col1_male\",\n"
        + "          \"female\": \"col1_female\"\n"
        + "        }\n"
        + "      }\n"
        + "    },\n"
        + "    {\n"
        + "      \"target_mean_encoding\": {\n"
        + "        \"field\": \"col2\",\n"
        + "        \"feature_name\": \"col2_encoded\",\n"
        + "        \"target_map\": {\n"
        + "          \"S\": 5.0,\n"
        + "          \"M\": 10.0,\n"
        + "          \"L\": 20\n"
        + "        },\n"
        + "        \"default_value\": 5.0\n"
        + "      }\n"
        + "    },\n"
        + "    {\n"
        + "      \"frequency_encoding\": {\n"
        + "        \"field\": \"col3\",\n"
        + "        \"feature_name\": \"col3_encoded\",\n"
        + "        \"frequency_map\": {\n"
        + "          \"none\": 0.75,\n"
        + "          \"true\": 0.10,\n"
        + "          \"false\": 0.15\n"
        + "        }\n"
        + "      }\n"
        + "    }\n"
        + "  ],\n"
        + "  \"trained_model\": {\n"
        + "    \"ensemble\": {\n"
        + "      \"feature_names\": [\n"
        + "        \"col1_male\",\n"
        + "        \"col1_female\",\n"
        + "        \"col2_encoded\",\n"
        + "        \"col3_encoded\",\n"
        + "        \"col4\"\n"
        + "      ],\n"
        + "      \"aggregate_output\": {\n"
        + "        \"weighted_sum\": {\n"
        + "          \"weights\": [\n"
        + "            0.5,\n"
        + "            0.5\n"
        + "          ]\n"
        + "        }\n"
        + "      },\n"
        + "      \"target_type\": \"regression\",\n"
        + "      \"trained_models\": [\n"
        + "        {\n"
        + "          \"tree\": {\n"
        + "            \"feature_names\": [\n"
        + "              \"col1_male\",\n"
        + "              \"col1_female\",\n"
        + "              \"col4\"\n"
        + "            ],\n"
        + "            \"tree_structure\": [\n"
        + "              {\n"
        + "                \"node_index\": 0,\n"
        + "                \"split_feature\": 0,\n"
        + "                \"split_gain\": 12.0,\n"
        + "                \"threshold\": 10.0,\n"
        + "                \"decision_type\": \"lte\",\n"
        + "                \"default_left\": true,\n"
        + "                \"left_child\": 1,\n"
        + "                \"right_child\": 2\n"
        + "              },\n"
        + "              {\n"
        + "                \"node_index\": 1,\n"
        + "                \"leaf_value\": 1\n"
        + "              },\n"
        + "              {\n"
        + "                \"node_index\": 2,\n"
        + "                \"leaf_value\": 2\n"
        + "              }\n"
        + "            ],\n"
        + "            \"target_type\": \"regression\"\n"
        + "          }\n"
        + "        },\n"
        + "        {\n"
        + "          \"tree\": {\n"
        + "            \"feature_names\": [\n"
        + "              \"col2_encoded\",\n"
        + "              \"col3_encoded\",\n"
        + "              \"col4\"\n"
        + "            ],\n"
        + "            \"tree_structure\": [\n"
        + "              {\n"
        + "                \"node_index\": 0,\n"
        + "                \"split_feature\": 0,\n"
        + "                \"split_gain\": 12.0,\n"
        + "                \"threshold\": 10.0,\n"
        + "                \"decision_type\": \"lte\",\n"
        + "                \"default_left\": true,\n"
        + "                \"left_child\": 1,\n"
        + "                \"right_child\": 2\n"
        + "              },\n"
        + "              {\n"
        + "                \"node_index\": 1,\n"
        + "                \"leaf_value\": 1\n"
        + "              },\n"
        + "              {\n"
        + "                \"node_index\": 2,\n"
        + "                \"leaf_value\": 2\n"
        + "              }\n"
        + "            ],\n"
        + "            \"target_type\": \"regression\"\n"
        + "          }\n"
        + "        }\n"
        + "      ]\n"
        + "    }\n"
        + "  }\n"
        + "}";

    /**
     * JSON fixture: the same three preprocessors as {@link #ENSEMBLE_MODEL}, followed by a single
     * regression {@code tree} trained model with one split node and two leaves.
     */
    public static final String TREE_MODEL = ""
        + "{\n"
        + "  \"preprocessors\": [\n"
        + "    {\n"
        + "      \"one_hot_encoding\": {\n"
        + "        \"field\": \"col1\",\n"
        + "        \"hot_map\": {\n"
        + "          \"male\": \"col1_male\",\n"
        + "          \"female\": \"col1_female\"\n"
        + "        }\n"
        + "      }\n"
        + "    },\n"
        + "    {\n"
        + "      \"target_mean_encoding\": {\n"
        + "        \"field\": \"col2\",\n"
        + "        \"feature_name\": \"col2_encoded\",\n"
        + "        \"target_map\": {\n"
        + "          \"S\": 5.0,\n"
        + "          \"M\": 10.0,\n"
        + "          \"L\": 20\n"
        + "        },\n"
        + "        \"default_value\": 5.0\n"
        + "      }\n"
        + "    },\n"
        + "    {\n"
        + "      \"frequency_encoding\": {\n"
        + "        \"field\": \"col3\",\n"
        + "        \"feature_name\": \"col3_encoded\",\n"
        + "        \"frequency_map\": {\n"
        + "          \"none\": 0.75,\n"
        + "          \"true\": 0.10,\n"
        + "          \"false\": 0.15\n"
        + "        }\n"
        + "      }\n"
        + "    }\n"
        + "  ],\n"
        + "  \"trained_model\": {\n"
        + "    \"tree\": {\n"
        + "      \"feature_names\": [\n"
        + "        \"col1_male\",\n"
        + "        \"col1_female\",\n"
        + "        \"col4\"\n"
        + "      ],\n"
        + "      \"tree_structure\": [\n"
        + "        {\n"
        + "          \"node_index\": 0,\n"
        + "          \"split_feature\": 0,\n"
        + "          \"split_gain\": 12.0,\n"
        + "          \"threshold\": 10.0,\n"
        + "          \"decision_type\": \"lte\",\n"
        + "          \"default_left\": true,\n"
        + "          \"left_child\": 1,\n"
        + "          \"right_child\": 2\n"
        + "        },\n"
        + "        {\n"
        + "          \"node_index\": 1,\n"
        + "          \"leaf_value\": 1\n"
        + "        },\n"
        + "        {\n"
        + "          \"node_index\": 2,\n"
        + "          \"leaf_value\": 2\n"
        + "        }\n"
        + "      ],\n"
        + "      \"target_type\": \"regression\"\n"
        + "    }\n"
        + "  }\n"
        + "}";

    /**
     * Parses {@code json} into a {@link TrainedModelDefinition} using this test's registry.
     *
     * <p>The parser is opened in a try-with-resources block so it is closed even when parsing
     * or building throws.
     *
     * <p>NOTE(review): {@code false} is presumably the non-lenient (strict) mode of
     * {@code fromXContent}; confirm against {@code TrainedModelDefinition}.
     *
     * @param json JSON document describing a trained model definition
     * @return the built definition
     * @throws IOException if the parser cannot be created, read or closed
     */
    private TrainedModelDefinition parseModelDefinition(String json) throws IOException {
        try (
            XContentParser parser = XContentFactory.xContent(XContentType.JSON)
                .createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json)
        ) {
            return TrainedModelDefinition.fromXContent(parser, false).build();
        }
    }

    /** Verifies that {@link #ENSEMBLE_MODEL} parses into a definition whose model is an {@link Ensemble}. */
    public void testEnsembleSchemaDeserialization() throws IOException {
        TrainedModelDefinition definition = parseModelDefinition(ENSEMBLE_MODEL);
        assertThat(definition.getTrainedModel().getClass(), equalTo(Ensemble.class));
    }

    /** Verifies that {@link #TREE_MODEL} parses into a definition whose model is a {@link Tree}. */
    public void testTreeSchemaDeserialization() throws IOException {
        TrainedModelDefinition definition = parseModelDefinition(TREE_MODEL);
        assertThat(definition.getTrainedModel().getClass(), equalTo(Tree.class));
    }

    /** Supplies a fully random definition (random target type) for the inherited checks. */
    @Override
    protected TrainedModelDefinition createTestInstance() {
        return createRandomBuilder().build();
    }

    @Override
    protected Writeable.Reader<TrainedModelDefinition> instanceReader() {
        return TrainedModelDefinition::new;
    }

    /**
     * Builds a registry from the ML inference named-xcontent parsers plus those of a
     * {@link SearchModule} created with empty settings and an empty list.
     */
    @Override
    protected NamedXContentRegistry xContentRegistry() {
        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
        namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
        return new NamedXContentRegistry(namedXContent);
    }

    /** Builds a registry containing only the ML inference named writeables. */
    @Override
    protected NamedWriteableRegistry getNamedWriteableRegistry() {
        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
        entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
        return new NamedWriteableRegistry(entries);
    }

    /** Checks that a random definition reports a strictly positive {@code ramBytesUsed()}. */
    public void testRamUsageEstimation() {
        TrainedModelDefinition test = createTestInstance();
        assertThat(test.ramBytesUsed(), greaterThan(0L));
    }
}
