GridGain ML Advanced Examples

This document provides complete, working examples of advanced GridGain ML scenarios that require custom compute jobs, translators, and marshallers.

Overview

Complete working examples are available in the examples directory. Each example demonstrates:

  1. Downloading models (manually or programmatically)

  2. Packaging user-defined input types, translators, and translator factories (if used) into a JAR

  3. Deploying models along with the JAR containing user-defined code

  4. Running predictions via Client API or Embedded API
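
The usage steps in each example invoke predictions through an mlApi handle. How you obtain that handle depends on the API mode. The following is a minimal sketch for the Client API, assuming the ML facade is exposed as ml() by analogy with the context.ignite().ml() calls in the compute jobs below; the Embedded API would use a server-side Ignite instance the same way. Verify the accessor against your GridGain version's API documentation.

import org.apache.ignite.client.IgniteClient;

// Connect to the cluster via the thin client. The address and the ml()
// accessor below are illustrative assumptions, not confirmed API.
try (IgniteClient client = IgniteClient.builder()
        .addresses("127.0.0.1:10800")
        .build()) {
    var mlApi = client.ml(); // assumed ML facade accessor
    // ... build job parameters and call mlApi.predict(...) as shown below ...
}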

Example 1: Custom Translator with ONNX Model

This example demonstrates zero-shot classification using an ONNX model with fully custom input/output types, translators, and marshallers.

Key features:

  • ONNX model hosted on HuggingFace

  • User-defined input type (CustomInput) with multiple fields

  • User-defined output type (CustomOutput) with structured results

  • User-defined translator factory for model-specific processing

  • User-defined marshallers for serialization

  • Demonstrates both simple and batch predictions

Step 1: Download Model

You can download the model files using either manual or automated methods.

Option A: Manual Download

Download the zero-shot classification model files from HuggingFace:

curl --ssl-no-revoke -O \
  "https://huggingface.co/MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33/resolve/main/onnx/model_quantized.onnx"

curl --ssl-no-revoke -O \
  "https://huggingface.co/MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33/resolve/main/tokenizer.json?download=true"

After downloading, deploy the model to GridGain:

# Deploy model to cluster
cluster unit deploy zeroshot-model \
    --version 1.0.0 \
    --path /path/to/model/directory \
    --nodes ALL

Option B: Automated Download and Deployment

private static final List<String> urls = Arrays.asList(
    "https://huggingface.co/MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33/resolve/main/onnx/model_quantized.onnx",
    "https://huggingface.co/MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33/resolve/main/tokenizer.json?download=true"
);

// Download model files and deploy to GridGain cluster
tempModelDir = ModelUtils.downloadAndDeployModelAndJar(
    urls,           // URLs to download model files from
    MODEL_ID,       // Model identifier (e.g., "zeroshot")
    MODEL_VERSION,  // Model version (e.g., "1.0.0")
    sourceDir       // Path to custom classes for JAR packaging
);

This method automatically:

  • Downloads all required model files from the specified URLs

  • Creates a temporary directory for the model files

  • Packages custom classes (translators, marshallers) into a JAR

  • Deploys both the model and custom classes to the GridGain cluster

Step 2: Implement Custom Input/Output Types

public class CustomInput implements Serializable {
    private String text;
    private String[] candidates;
    private boolean multiLabel;

    public CustomInput(String text, String[] candidates, boolean multiLabel) {
        this.text = text;
        this.candidates = candidates;
        this.multiLabel = multiLabel;
    }

    // getters and setters...
}

public class CustomOutput implements Serializable {
    private String sequence;
    private String[] labels;
    private double[] scores;

    public CustomOutput(String sequence, String[] labels, double[] scores) {
        this.sequence = sequence;
        this.labels = labels;
        this.scores = scores;
    }

    // getters and setters...
}

Step 3: Implement Custom Translator and TranslatorFactory

Custom Translator

public class CustomTranslator implements NoBatchifyTranslator<CustomInput, CustomOutput> {

    private HuggingFaceTokenizer tokenizer;
    private boolean int32;
    private Predictor<NDList, NDList> predictor;

    private CustomTranslator(HuggingFaceTokenizer tokenizer, boolean int32) {
        this.tokenizer = tokenizer;
        this.int32 = int32;
    }

    @Override
    public void prepare(TranslatorContext ctx) throws IOException, ModelException {
        Model model = ctx.getModel();
        this.predictor = model.newPredictor(new NoopTranslator((Batchifier)null));
        ctx.getPredictorManager().attachInternal(
            UUID.randomUUID().toString(),
            new AutoCloseable[]{this.predictor}
        );
    }

    @Override
    public NDList processInput(TranslatorContext ctx, CustomInput input) {
        // Store input for later use in processOutput
        ctx.setAttachment("input", input);

        // Tokenize text with candidates
        String[] candidateArray = input.getCandidates();
        String text = input.getText();

        // Process input to create and return an NDList
        // ... implementation details (a hedged sketch follows this class) ...

        return ndList;
    }

    @Override
    public CustomOutput processOutput(TranslatorContext ctx, NDList list)
            throws TranslateException {
        CustomInput input = (CustomInput)ctx.getAttachment("input");

        // Extract logits from model output
        NDArray logits = list.get(0);

        // Process the output (NDList list) to extract labels and scores
        // ... implementation details (see the sketch after this class) ...

        return new CustomOutput(input.getText(), labels, scores);
    }
}
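
The elided implementation details above are model-specific. Below is a minimal, hedged sketch of one possible NLI-style approach for the two method bodies; the hypothesis template, the single-pair simplification, and the label/score ordering are illustrative assumptions, not part of the GridGain or DJL contract. The variables ctx, input, text, candidateArray, logits, tokenizer, and int32 come from the enclosing methods and fields; Encoding is ai.djl.huggingface.tokenizers.Encoding and Arrays is java.util.Arrays.

// Hypothetical body for processInput: encode the text against a
// hypothesis built from one candidate label. A full implementation
// would cover every candidate (e.g. via the internal predictor).
NDManager manager = ctx.getNDManager();
Encoding encoding = tokenizer.encode(text, "This example is " + candidateArray[0] + ".");
long[] ids = encoding.getIds();
long[] mask = encoding.getAttentionMask();
NDList ndList = int32
    ? new NDList(
        manager.create(Arrays.stream(ids).mapToInt(Math::toIntExact).toArray()),
        manager.create(Arrays.stream(mask).mapToInt(Math::toIntExact).toArray()))
    : new NDList(manager.create(ids), manager.create(mask));

// Hypothetical body for processOutput: turn the logits into
// per-candidate scores via softmax (ordering is an assumption).
float[] probabilities = logits.softmax(-1).toFloatArray();
String[] labels = input.getCandidates();
double[] scores = new double[probabilities.length];
for (int i = 0; i < probabilities.length; i++) {
    scores[i] = probabilities[i];
}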

Custom Translator Factory

public class CustomTranslatorFactory implements TranslatorFactory, Serializable {

    @Override
    public <I, O> Translator<I, O> newInstance(
            Class<I> input,
            Class<O> output,
            Model model,
            Map<String, ?> arguments) {

        Path modelPath = model.getModelPath();
        try {
            HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder(arguments)
                .optTokenizerPath(modelPath)
                .optManager(model.getNDManager())
                .build();

            boolean int32 = Boolean.parseBoolean(
                arguments.getOrDefault("int32", "false").toString()
            );

            return (Translator<I, O>) new CustomTranslator(tokenizer, int32);
        } catch (IOException e) {
            throw new RuntimeException("Failed to load tokenizer", e);
        }
    }
}

Step 4: Extend MlComputeJob

Create custom compute job classes that extend MlComputeJob and implement the predictAsync method.

For Simple Predictions

public class CustomComputeJob<I extends MlSimpleJobParameters, O>
        extends MlComputeJob<I, O> {

    @Override
    public CompletableFuture<O> predictAsync(JobExecutionContext context, I arg) {
        if (arg == null) {
            return CompletableFuture.completedFuture(null);
        }
        return context.ignite().ml().predictAsync(arg);
    }

    @Override
    public Marshaller<I, byte[]> inputMarshaller() {
        return new CustomInputMarshaller<>();
    }

    @Override
    public Marshaller<O, byte[]> resultMarshaller() {
        return new CustomOutputMarshaller<>();
    }
}

For Batch Predictions

public class CustomBatchComputeJob<I extends MlBatchJobParameters, O>
        extends MlComputeJob<I, List<O>> {

    @Override
    public CompletableFuture<List<O>> predictAsync(
            JobExecutionContext context, I arg) {
        if (arg == null) {
            return CompletableFuture.completedFuture(null);
        }
        return context.ignite().ml().batchPredictAsync(arg);
    }

    @Override
    public Marshaller<I, byte[]> inputMarshaller() {
        return new CustomInputMarshaller<>();
    }

    @Override
    public Marshaller<List<O>, byte[]> resultMarshaller() {
        return new CustomOutputListMarshaller<>();
    }
}

Step 5: Implement Custom Marshallers

Input Marshaller

public class CustomInputMarshaller<I extends MlJobParameters>
        implements ByteArrayMarshaller<I>, Serializable {

    private static final long serialVersionUID = 1L;

    @Override
    public byte[] marshal(I obj) throws MarshallingException {
        if (obj == null) {
            return null;
        }
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
             ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(obj);
            oos.flush();
            return baos.toByteArray();
        } catch (IOException e) {
            throw new MarshallingException("Failed to marshal MlJobParameters", e);
        }
    }

    @Override
    public I unmarshal(byte[] data) throws UnmarshallingException {
        if (data == null) {
            return null;
        }
        try (ByteArrayInputStream bais = new ByteArrayInputStream(data);
             ObjectInputStream ois = new ObjectInputStream(bais)) {
            return (I) ois.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new UnmarshallingException("Failed to unmarshal MlJobParameters", e);
        }
    }
}

Output Marshaller

public class CustomOutputMarshaller<O>
        implements ByteArrayMarshaller<O>, Serializable {

    private static final long serialVersionUID = 1L;

    @Override
    public byte[] marshal(O obj) throws MarshallingException {
        if (obj == null) {
            return null;
        }
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
             ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(obj);
            oos.flush();
            return baos.toByteArray();
        } catch (IOException e) {
            throw new MarshallingException("Failed to marshal Output", e);
        }
    }

    @Override
    public O unmarshal(byte[] data) throws UnmarshallingException {
        if (data == null) {
            return null;
        }
        try (ByteArrayInputStream bais = new ByteArrayInputStream(data);
             ObjectInputStream ois = new ObjectInputStream(bais)) {
            return (O) ois.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new UnmarshallingException("Failed to unmarshal Output", e);
        }
    }
}

Output List Marshaller (for Batch Operations)

public class CustomOutputListMarshaller<O>
        implements ByteArrayMarshaller<List<O>>, Serializable {

    private static final long serialVersionUID = 1L;

    @Override
    public byte[] marshal(List<O> obj) throws MarshallingException {
        if (obj == null) {
            return null;
        }
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
             ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(obj);
            oos.flush();
            return baos.toByteArray();
        } catch (IOException e) {
            throw new MarshallingException("Failed to marshal Output list", e);
        }
    }

    @Override
    public List<O> unmarshal(byte[] data) throws UnmarshallingException {
        if (data == null) {
            return null;
        }
        try (ByteArrayInputStream bais = new ByteArrayInputStream(data);
             ObjectInputStream ois = new ObjectInputStream(bais)) {
            return (List<O>) ois.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new UnmarshallingException("Failed to unmarshal Output list", e);
        }
    }
}

Step 6: Usage

Simple Prediction

CustomInput input = new CustomInput(
    "This is a great product!",
    new String[]{"positive", "negative", "neutral"},
    false
);

MlSimpleJobParameters<CustomInput> jobParams =
    MlSimpleJobParameters.<CustomInput>builder()
        .id(MODEL_ID)
        .version(MODEL_VERSION)
        .type(ModelType.ONNX)
        .name("model_quantized")
        .config(ModelConfig.builder().build())
        .inputClass("org.apache.ignite.example.ml.custom.CustomInput")
        .outputClass("org.apache.ignite.example.ml.custom.CustomOutput")
        .translatorFactory(
            "org.apache.ignite.example.ml.custom.CustomTranslatorFactory")
        .input(input)
        .customJobClass(CustomComputeJob.class)
        .customInputMarshaller(new CustomInputMarshaller<>())
        .customOutputMarshaller(new CustomOutputMarshaller<>())
        .build();

CustomOutput result = mlApi.predict(jobParams);

System.out.println("Classification Results:");
System.out.println("Text: " + result.getSequence());
System.out.println("Top label: " + result.getLabels()[0]);
System.out.println("Top score: " + result.getScores()[0]);

Batch Prediction

List<CustomInput> batchInputs = Arrays.asList(
    new CustomInput("Great product!", new String[]{"positive", "negative"}, false),
    new CustomInput("Poor quality.", new String[]{"positive", "negative"}, false),
    new CustomInput("Average experience.", new String[]{"positive", "negative"}, false)
);

MlBatchJobParameters<CustomInput> jobParams =
    MlBatchJobParameters.<CustomInput>builder()
        .id(MODEL_ID)
        .version(MODEL_VERSION)
        .type(ModelType.ONNX)
        .name("model_quantized")
        .config(ModelConfig.builder().batchSize(16).build())
        .inputClass("org.apache.ignite.example.ml.custom.CustomInput")
        .outputClass("org.apache.ignite.example.ml.custom.CustomOutput")
        .translatorFactory(
            "org.apache.ignite.example.ml.custom.CustomTranslatorFactory")
        .batchInput(batchInputs)
        .customJobClass(CustomBatchComputeJob.class)
        .customInputMarshaller(new CustomInputMarshaller<>())
        .customOutputMarshaller(new CustomOutputListMarshaller<>())
        .build();

List<CustomOutput> results = mlApi.batchPredict(jobParams);

// Process results
for (int i = 0; i < results.size(); i++) {
    CustomOutput output = results.get(i);
    System.out.println("Input: " + batchInputs.get(i).getText());
    System.out.println("Result: " + output.getLabels()[0]);
}

Example 2: PyTorch Question Answering

This example demonstrates question answering using a PyTorch model from HuggingFace with user-defined input type, translator, and translator factory.

Key features:

  • PyTorch model hosted with DJL via djl:// URLs

  • User-defined input type (PytorchQAInput) for question-context pairs

  • Standard output type (String) for extracted answers

  • User-defined translator factory (PytorchQATranslatorFactory) for QA-specific processing

  • User-defined compute job implementation (PytorchQAComputeJob)

  • User-defined marshallers for specialized serialization

  • Demonstrates simple predictions only

Execution Mode: MODE.EMBEDDED

Step 1: Download Model

This example uses automated download from DJL Model Zoo:

private static final String MODEL_URL =
    "djl://ai.djl.huggingface.pytorch/deepset/minilm-uncased-squad2";

// Download model files and deploy to GridGain cluster
tempModelDir = ModelUtils.downloadAndDeployDJLModelAndJar(
    MODEL_URL,
    PytorchQAInput.class,
    String.class,
    MODEL_ID,
    MODEL_VERSION,
    sourceDir,
    PytorchQATranslatorFactory.class.getName()
);

This method automatically:

  • Downloads the PyTorch model from DJL model zoo using the djl:// protocol

  • Creates a temporary directory to store the model files

  • Compiles custom classes (input types, translators, marshallers) from the source directory

  • Generates a JAR containing the custom compiled classes

  • Deploys both the model files and custom classes JAR to the GridGain cluster

Step 2: Implement Custom Input Type

public class PytorchQAInput implements Serializable {
    private static final long serialVersionUID = 1L;

    private final String question;
    private final String paragraph;
    private String context;

    public PytorchQAInput(String question, String paragraph) {
        this.question = question;
        this.paragraph = paragraph;
    }

    public String getQuestion() {
        return this.question;
    }

    public String getParagraph() {
        return this.paragraph == null ? this.context : this.paragraph;
    }
}

Step 3: Implement PyTorch Translator and TranslatorFactory

Translator Factory

public class PytorchQATranslatorFactory implements TranslatorFactory, Serializable {

    @Override
    public <I, O> Translator<I, O> newInstance(
            Class<I> input,
            Class<O> output,
            Model model,
            Map<String, ?> arguments) throws TranslateException {

        Path modelPath = model.getModelPath();
        try {
            HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder(arguments)
                .optTokenizerPath(modelPath)
                .optManager(model.getNDManager())
                .build();

            PytorchQATranslator translator =
                PytorchQATranslator.builder(tokenizer, arguments).build();

            return (Translator<I, O>) translator;
        } catch (IOException e) {
            throw new TranslateException("Failed to load tokenizer.", e);
        }
    }
}

Translator

public class PytorchQATranslator implements Translator<PytorchQAInput, String> {

    private HuggingFaceTokenizer tokenizer;
    private Batchifier batchifier;

    private PytorchQATranslator(HuggingFaceTokenizer tokenizer, Map<String, ?> arguments) {
        this.tokenizer = tokenizer;
        this.batchifier = Batchifier.STACK;
    }

    @Override
    public NDList processInput(TranslatorContext ctx, PytorchQAInput input) {
        // Tokenize question and paragraph together
        String question = input.getQuestion();
        String paragraph = input.getParagraph();

        // Process with tokenizer
        Encoding encoding = tokenizer.encode(question, paragraph);

        // Convert to NDArrays
        // ... implementation details (a hedged sketch follows this class) ...

        return ndList;
    }

    @Override
    public String processOutput(TranslatorContext ctx, NDList list) {
        // Extract start and end logits
        NDArray startLogits = list.get(0);
        NDArray endLogits = list.get(1);

        // Find best answer span
        // ... implementation details (see the sketch after this class) ...

        // Extract answer from tokens
        String answer = extractAnswer(startIdx, endIdx);
        return answer;
    }
}
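
As in Example 1, the elided sections are model-specific. A hedged sketch of one straightforward approach follows: greedy argmax over the start/end logits, with the Encoding stored as an attachment so that processOutput can rebuild the answer text. The attachment key and the inline span extraction (standing in for the extractAnswer helper) are assumptions of this sketch; encoding, startLogits, endLogits, and tokenizer come from the enclosing methods and fields.

// Hypothetical body for processInput (Encoding is
// ai.djl.huggingface.tokenizers.Encoding, produced by tokenizer.encode above):
NDManager manager = ctx.getNDManager();
ctx.setAttachment("encoding", encoding); // keep tokens for answer extraction
NDList ndList = new NDList(
    manager.create(encoding.getIds()),
    manager.create(encoding.getAttentionMask()));

// Hypothetical body for processOutput: pick the highest-scoring start
// and end positions, then rebuild the answer from that token span.
Encoding encoding = (Encoding) ctx.getAttachment("encoding");
int startIdx = (int) startLogits.argMax().getLong();
int endIdx = (int) endLogits.argMax().getLong();
String answer = tokenizer.buildSentence(
    Arrays.asList(encoding.getTokens()).subList(startIdx, endIdx + 1));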

Step 4: Extend MlComputeJob

public class PytorchQAComputeJob<I extends MlSimpleJobParameters, O>
        extends MlComputeJob<I, O> {

    @Override
    public CompletableFuture<O> predictAsync(JobExecutionContext context, I arg) {
        if (arg == null) {
            return CompletableFuture.completedFuture(null);
        }
        return context.ignite().ml().predictAsync(arg);
    }

    @Override
    public Marshaller<I, byte[]> inputMarshaller() {
        return new PytorchQAInputMarshaller<>();
    }

    @Override
    public Marshaller<O, byte[]> resultMarshaller() {
        return new PytorchQAOutputMarshaller<>();
    }
}

Step 5: Implement Custom Marshallers

Input Marshaller for PytorchQAInput

public class PytorchQAInputMarshaller<I extends MlJobParameters>
        implements ByteArrayMarshaller<I>, Serializable {

    private static final long serialVersionUID = 1L;

    @Override
    public byte[] marshal(I obj) throws MarshallingException {
        if (obj == null) {
            return null;
        }
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
             ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(obj);
            oos.flush();
            return baos.toByteArray();
        } catch (IOException e) {
            throw new MarshallingException("Failed to marshal MlJobParameters", e);
        }
    }

    @Override
    public I unmarshal(byte[] data) throws UnmarshallingException {
        if (data == null) {
            return null;
        }
        try (ByteArrayInputStream bais = new ByteArrayInputStream(data);
             ObjectInputStream ois = new ObjectInputStream(bais)) {
            return (I) ois.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new UnmarshallingException("Failed to unmarshal MlJobParameters", e);
        }
    }
}

Output Marshaller for String Results

public class PytorchQAOutputMarshaller<O>
        implements ByteArrayMarshaller<O>, Serializable {

    private static final long serialVersionUID = 1L;

    @Override
    public byte[] marshal(O obj) throws MarshallingException {
        if (obj == null) {
            return null;
        }
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
             ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(obj);
            oos.flush();
            return baos.toByteArray();
        } catch (IOException e) {
            throw new MarshallingException("Failed to marshal Output", e);
        }
    }

    @Override
    public O unmarshal(byte[] data) throws UnmarshallingException {
        if (data == null) {
            return null;
        }
        try (ByteArrayInputStream bais = new ByteArrayInputStream(data);
             ObjectInputStream ois = new ObjectInputStream(bais)) {
            return (O) ois.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new UnmarshallingException("Failed to unmarshal Output", e);
        }
    }
}

Step 6: Usage

String question = "When did BBC Japan start broadcasting?";
String paragraph = "BBC Japan was a general entertainment Channel. " +
    "Which operated between December 2004 and April 2006. " +
    "It ceased operations after its Japanese distributor folded.";

PytorchQAInput qaInput = new PytorchQAInput(question, paragraph);

MlSimpleJobParameters<PytorchQAInput> jobParams =
    MlSimpleJobParameters.<PytorchQAInput>builder()
        .id(MODEL_ID)
        .version(MODEL_VERSION)
        .type(ModelType.PYTORCH)
        .name("minilm-uncased-squad2")
        .inputClass(PytorchQAInput.class.getName())
        .outputClass(String.class.getName())
        .translatorFactory(PytorchQATranslatorFactory.class.getName())
        .property("detail", "true")
        .input(qaInput)
        .customJobClass(PytorchQAComputeJob.class)
        .customInputMarshaller(new PytorchQAInputMarshaller<>())
        .customOutputMarshaller(new PytorchQAOutputMarshaller<>())
        .build();

String answer = mlApi.predict(jobParams);

System.out.println("Question: \"" + question + "\"");
System.out.println("Answer: \"" + answer + "\"");
// Output: Answer: "December 2004"

Example 3: TensorFlow Sentence Encoder

This example demonstrates sentence embedding generation using a TensorFlow SavedModel from HuggingFace with user-defined translators and marshallers.

Key features:

  • TensorFlow model hosted on HuggingFace

  • Model files downloaded via URLs with automated deployment

  • TensorFlow folder structure handling (saved_model.pb + variables/ directory)

  • Standard input type (String[]) for batch sentence processing

  • Standard output type (float[][]) for embedding vectors

  • User-defined translator (SentenceEncoderTranslator) for TensorFlow-specific processing

  • User-defined compute job implementation (TensorflowSentenceEncoderComputeJob)

  • User-defined marshallers

Step 1: Download Model

Option A: Manual Download

Download the Universal Sentence Encoder model files from HuggingFace:

# Download the main model file
curl --ssl-no-revoke -O \
  "https://huggingface.co/hfarwah/universal-sentence-encoder/resolve/main/saved_model.pb"

# Download variable files
curl --ssl-no-revoke -O \
  "https://huggingface.co/hfarwah/universal-sentence-encoder/resolve/main/variables/variables.data-00000-of-00001"

curl --ssl-no-revoke -O \
  "https://huggingface.co/hfarwah/universal-sentence-encoder/resolve/main/variables/variables.index"

Organize the downloaded files into the following directory structure:

model-directory/
├── saved_model.pb
└── variables/
    ├── variables.data-00000-of-00001
    └── variables.index

After organizing, deploy the model to GridGain:

cluster unit deploy universal-sentence-encoder \
    --version 1.0.0 \
    --path /path/to/model/directory \
    --nodes ALL

Option B: Automated Download and Deployment

private static final List<String> urls = Arrays.asList(
    "https://huggingface.co/hfarwah/universal-sentence-encoder/resolve/main/saved_model.pb",
    "https://huggingface.co/hfarwah/universal-sentence-encoder/resolve/main/variables/variables.data-00000-of-00001",
    "https://huggingface.co/hfarwah/universal-sentence-encoder/resolve/main/variables/variables.index"
);

// Construct path to custom classes
Path workingDir = Paths.get(System.getProperty("user.dir"));
Path sourceDir = workingDir.resolve("java/src/main/java/" + CUSTOM_CLASSES_PACKAGE);

// Download model files with proper folder structure and deploy
tempModelDir = ModelUtils.downloadAndDeployModelAndJarWithFolderStructure(
    urls,           // URLs to download model files from
    MODEL_ID,       // Model identifier (e.g., "universal-sentence-encoder")
    MODEL_VERSION,  // Model version (e.g., "1.0.0")
    sourceDir       // Path to custom classes for JAR packaging
);

This method automatically:

  • Downloads all required TensorFlow model files from the specified URLs

  • Creates the proper folder structure (saved_model.pb in root, variables in variables/ subdirectory)

  • Creates a temporary directory for the model files

  • Packages custom classes (translators, marshallers) into a JAR

  • Deploys both the model and custom classes to the GridGain cluster

Step 2: Input/Output Types

For this example, we use standard Java types:

  • Input Type: String[] - Array of sentences to encode

  • Output Type: float[][] - 2D array of embeddings (one embedding vector per sentence)

Since we’re using standard types, no custom input/output classes are needed. However, custom marshallers are required for proper serialization.

Step 3: Implement Custom Translator

public class SentenceEncoderTranslator
        implements NoBatchifyTranslator<String[], float[][]> {

    @Override
    public NDList processInput(TranslatorContext ctx, String[] inputs) {
        // Convert string array to NDArray format expected by TensorFlow model
        NDManager manager = ctx.getNDManager();

        List<NDArray> inputList = Arrays.stream(inputs)
            .map(manager::create)
            .collect(Collectors.toList());

        NDArray inputArray = NDArrays.stack(new NDList(inputList));
        return new NDList(inputArray);
    }

    @Override
    public float[][] processOutput(TranslatorContext ctx, NDList list) {
        // Extract embeddings from model output
        NDArray embeddings = list.singletonOrThrow();

        // Convert to float[][] for easier handling
        NDList result = new NDList();
        long numOutputs = embeddings.getShape().get(0);
        for (int i = 0; i < numOutputs; i++) {
            result.add(embeddings.get(i));
        }

        return result.stream()
            .map(NDArray::toFloatArray)
            .toArray(float[][]::new);
    }
}
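
A note on the design: creating one NDArray per sentence and stacking keeps the steps explicit. If your DJL version supports NDManager.create(String[]) (an assumption here; check your DJL release), processInput can build the batch string tensor the SavedModel expects in a single call:

@Override
public NDList processInput(TranslatorContext ctx, String[] inputs) {
    // One call creates the 1-D string tensor for the whole batch.
    return new NDList(ctx.getNDManager().create(inputs));
}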

Step 4: Extend MlComputeJob

public class TensorflowSentenceEncoderComputeJob<I extends MlSimpleJobParameters, O>
        extends MlComputeJob<I, O> {

    @Override
    public CompletableFuture<O> predictAsync(JobExecutionContext context, I arg) {
        if (arg == null) {
            return CompletableFuture.completedFuture(null);
        }
        return context.ignite().ml().predictAsync(arg);
    }

    @Override
    public Marshaller<I, byte[]> inputMarshaller() {
        return new TensorFlowInputMarshaller<>();
    }

    @Override
    public Marshaller<O, byte[]> resultMarshaller() {
        return new TensorFlowOutputMarshaller<>();
    }
}

Step 5: Implement Custom Marshallers

Input Marshaller for String[] Input

public class TensorFlowInputMarshaller<I extends MlJobParameters>
        implements ByteArrayMarshaller<I>, Serializable {

    private static final long serialVersionUID = 1L;

    @Override
    public byte[] marshal(I obj) throws MarshallingException {
        if (obj == null) {
            return null;
        }
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
             ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(obj);
            oos.flush();
            return baos.toByteArray();
        } catch (IOException e) {
            throw new MarshallingException("Failed to marshal MlJobParameters", e);
        }
    }

    @Override
    public I unmarshal(byte[] data) throws UnmarshallingException {
        if (data == null) {
            return null;
        }
        try (ByteArrayInputStream bais = new ByteArrayInputStream(data);
             ObjectInputStream ois = new ObjectInputStream(bais)) {
            return (I) ois.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new UnmarshallingException("Failed to unmarshal MlJobParameters", e);
        }
    }
}

Output Marshaller for float[][] Output

public class TensorFlowOutputMarshaller<O>
        implements ByteArrayMarshaller<O>, Serializable {

    private static final long serialVersionUID = 1L;

    @Override
    public byte[] marshal(O obj) throws MarshallingException {
        if (obj == null) {
            return null;
        }
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
             ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(obj);
            oos.flush();
            return baos.toByteArray();
        } catch (IOException e) {
            throw new MarshallingException("Failed to marshal Output", e);
        }
    }

    @Override
    public O unmarshal(byte[] data) throws UnmarshallingException {
        if (data == null) {
            return null;
        }
        try (ByteArrayInputStream bais = new ByteArrayInputStream(data);
             ObjectInputStream ois = new ObjectInputStream(bais)) {
            return (O) ois.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new UnmarshallingException("Failed to unmarshal Output", e);
        }
    }
}

Step 6: Usage

String[] input = {"The quick brown fox jumps over the lazy dog."};

MlSimpleJobParameters<String[]> jobParams =
    MlSimpleJobParameters.<String[]>builder()
        .id(MODEL_ID)
        .version(MODEL_VERSION)
        .type(ModelType.TENSORFLOW)
        .name("saved_model") // TensorFlow saved_model format
        .inputClass("java.lang.String[]")
        .outputClass("float[][]")
        .translator(
            "org.apache.ignite.example.ml.tensorflowtranslator.SentenceEncoderTranslator")
        .input(input)
        .customJobClass(TensorflowSentenceEncoderComputeJob.class)
        .customInputMarshaller(new TensorFlowInputMarshaller<>())
        .customOutputMarshaller(new TensorFlowOutputMarshaller<>())
        .build();

float[][] result = mlApi.predict(jobParams);

System.out.println("Input: " + input[0] + "\n");
System.out.println("Embedding generated: ");
System.out.println("Dimensions: " + result[0].length);
System.out.println("First 5 values: " + Arrays.toString(Arrays.copyOf(result[0], 5)));

// Output example:
// Input: "The quick brown fox jumps over the lazy dog."
// Embedding generated:
// Dimensions: 512
// First 5 values: [0.0234, -0.0156, 0.0789, -0.0023, 0.0445]