GridGain ML Advanced Examples
This document provides complete, working examples of advanced GridGain ML scenarios that require custom compute jobs, translators, and marshallers.
Overview
Complete working examples are available in the examples directory. Each example demonstrates:
- Downloading models (manually or programmatically)
- Packaging user-defined input types, translators, and translator factories (if used) into a JAR
- Deploying models along with the JAR containing user-defined code
- Running predictions via the Client API or Embedded API
Example 1: Custom Translator with ONNX Model
This example demonstrates zero-shot classification using an ONNX model with fully custom input/output types, translators, and marshallers.
Key features:
- ONNX model hosted on HuggingFace
- User-defined input type (CustomInput) with multiple fields
- User-defined output type (CustomOutput) with structured results
- User-defined translator factory for model-specific processing
- User-defined marshallers for serialization
- Demonstrates both simple and batch predictions
Step 1: Download Model
You can download the model files using either manual or automated methods.
Option A: Manual Download
Download the zero-shot classification model files from HuggingFace:
curl --ssl-no-revoke -O \
"https://huggingface.co/MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33/resolve/main/onnx/model_quantized.onnx"
curl --ssl-no-revoke -O \
"https://huggingface.co/MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33/resolve/main/tokenizer.json?download=true"
After download, deploy the model to GridGain:
# Deploy model to cluster
cluster unit deploy zeroshot-model \
--version 1.0.0 \
--path /path/to/model/directory \
--nodes ALL
Option B: Automated Download and Deployment
private static final List<String> urls = Arrays.asList(
"https://huggingface.co/MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33/resolve/main/onnx/model_quantized.onnx",
"https://huggingface.co/MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33/resolve/main/tokenizer.json?download=true"
);
// Download model files and deploy to GridGain cluster
tempModelDir = ModelUtils.downloadAndDeployModelAndJar(
urls, // URLs to download model files from
MODEL_ID, // Model identifier (e.g., "zeroshot")
MODEL_VERSION, // Model version (e.g., "1.0.0")
sourceDir // Path to custom classes for JAR packaging
);
This method automatically:
- Downloads all required model files from the specified URLs
- Creates a temporary directory for the model files
- Packages custom classes (translators, marshallers) into a JAR
- Deploys both the model and custom classes to the GridGain cluster
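The identifiers referenced by the snippet above are ordinary constants. A minimal sketch with illustrative values (matching the inline comments; adjust to your project layout):
private static final String MODEL_ID = "zeroshot";
private static final String MODEL_VERSION = "1.0.0";
// Directory containing the custom translator/marshaller sources to package
private static final Path sourceDir =
        Paths.get("src/main/java/org/apache/ignite/example/ml/custom");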
Step 2: Implement Custom Input/Output Types
public class CustomInput implements Serializable {
private String text;
private String[] candidates;
private boolean multiLabel;
public CustomInput(String text, String[] candidates, boolean multiLabel) {
this.text = text;
this.candidates = candidates;
this.multiLabel = multiLabel;
}
// getters and setters...
}
public class CustomOutput implements Serializable {
private String sequence;
private String[] labels;
private double[] scores;
public CustomOutput(String sequence, String[] labels, double[] scores) {
this.sequence = sequence;
this.labels = labels;
this.scores = scores;
}
// getters and setters...
}
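The elided accessors are plain getters. The translator below relies on at least the first two of these; a sketch matching the CustomInput fields above:
// Inside CustomInput (sketch):
public String getText() { return text; }
public String[] getCandidates() { return candidates; }
public boolean isMultiLabel() { return multiLabel; }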
Step 3: Implement Custom Translator and TranslatorFactory
Custom Translator
public class CustomTranslator implements NoBatchifyTranslator<CustomInput, CustomOutput> {
private HuggingFaceTokenizer tokenizer;
private boolean int32;
private Predictor<NDList, NDList> predictor;
private CustomTranslator(HuggingFaceTokenizer tokenizer, boolean int32) {
this.tokenizer = tokenizer;
this.int32 = int32;
}
@Override
public void prepare(TranslatorContext ctx) throws IOException, ModelException {
Model model = ctx.getModel();
this.predictor = model.newPredictor(new NoopTranslator((Batchifier)null));
ctx.getPredictorManager().attachInternal(
UUID.randomUUID().toString(),
new AutoCloseable[]{this.predictor}
);
}
@Override
public NDList processInput(TranslatorContext ctx, CustomInput input) {
// Store input for later use in processOutput
ctx.setAttachment("input", input);
// Tokenize text with candidates
String[] candidateArray = input.getCandidates();
String text = input.getText();
// Process input to create and return an NDList
// ... implementation details ...
return ndList;
}
@Override
public CustomOutput processOutput(TranslatorContext ctx, NDList list)
throws TranslateException {
CustomInput input = (CustomInput)ctx.getAttachment("input");
// Extract logits from model output
NDArray logits = list.get(0);
// Process the output (NDList list) to extract labels and scores
// ... implementation details ...
return new CustomOutput(input.getText(), labels, scores);
}
}
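For orientation, here is one way the elided bodies could look, loosely following the hypothesis-pairing approach common to zero-shot NLI models. The hypothesis template, the use of the internal predictor, and the entailment handling are assumptions for illustration, not the example's exact code:
// Sketch of the processInput body (assumed): pair the text with one hypothesis
// per candidate label and run each pair through the predictor from prepare().
NDList output = new NDList();
for (String candidate : input.getCandidates()) {
    Encoding encoding = tokenizer.encode(
            input.getText(), "This example is " + candidate + ".");
    NDList pair = encoding.toNDList(ctx.getNDManager(), false);
    output.addAll(predictor.predict(Batchifier.STACK.batchify(new NDList[] {pair})));
}
return output;
// Sketch of processOutput (assumed): collect one entailment logit per candidate,
// then normalize across candidates (e.g. softmax) when multiLabel is false.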
Custom Translator Factory
public class CustomTranslatorFactory implements TranslatorFactory, Serializable {
@Override
public <I, O> Translator<I, O> newInstance(
Class<I> input,
Class<O> output,
Model model,
Map<String, ?> arguments) {
Path modelPath = model.getModelPath();
try {
HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder(arguments)
.optTokenizerPath(modelPath)
.optManager(model.getNDManager())
.build();
boolean int32 = Boolean.parseBoolean(
arguments.getOrDefault("int32", "false").toString()
);
return (Translator<I, O>) new CustomTranslator(tokenizer, int32);
} catch (IOException e) {
throw new RuntimeException("Failed to load tokenizer", e);
}
}
}
Step 4: Extend MlComputeJob
Create custom compute job classes that extend MlComputeJob and implement the predictAsync method.
For Simple Predictions
public class CustomComputeJob<I extends MlSimpleJobParameters, O>
extends MlComputeJob<I, O> {
@Override
public CompletableFuture<O> predictAsync(JobExecutionContext context, I arg) {
if (arg == null) {
return CompletableFuture.completedFuture(null);
}
return context.ignite().ml().predictAsync(arg);
}
@Override
public Marshaller<I, byte[]> inputMarshaller() {
return new CustomInputMarshaller<>();
}
@Override
public Marshaller<O, byte[]> resultMarshaller() {
return new CustomOutputMarshaller<>();
}
}
For Batch Predictions
public class CustomBatchComputeJob<I extends MlBatchJobParameters, O>
extends MlComputeJob<I, List<O>> {
@Override
public CompletableFuture<List<O>> predictAsync(
JobExecutionContext context, I arg) {
if (arg == null) {
return CompletableFuture.completedFuture(null);
}
return context.ignite().ml().batchPredictAsync(arg);
}
@Override
public Marshaller<I, byte[]> inputMarshaller() {
return new CustomInputMarshaller<>();
}
@Override
public Marshaller<List<O>, byte[]> resultMarshaller() {
return new CustomOutputListMarshaller<>();
}
}
Step 5: Implement Custom Marshallers
Input Marshaller
public class CustomInputMarshaller<I extends MlJobParameters>
implements ByteArrayMarshaller<I>, Serializable {
private static final long serialVersionUID = 1L;
@Override
public byte[] marshal(I obj) throws MarshallingException {
if (obj == null) {
return null;
}
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos)) {
oos.writeObject(obj);
oos.flush();
return baos.toByteArray();
} catch (IOException e) {
throw new MarshallingException("Failed to marshal MlJobParameters", e);
}
}
@Override
public I unmarshal(byte[] data) throws UnmarshallingException {
if (data == null) {
return null;
}
try (ByteArrayInputStream bais = new ByteArrayInputStream(data);
ObjectInputStream ois = new ObjectInputStream(bais)) {
return (I) ois.readObject();
} catch (IOException | ClassNotFoundException e) {
throw new UnmarshallingException("Failed to unmarshal MlJobParameters", e);
}
}
}
Output Marshaller
public class CustomOutputMarshaller<O>
implements ByteArrayMarshaller<O>, Serializable {
private static final long serialVersionUID = 1L;
@Override
public byte[] marshal(O obj) throws MarshallingException {
if (obj == null) {
return null;
}
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos)) {
oos.writeObject(obj);
oos.flush();
return baos.toByteArray();
} catch (IOException e) {
throw new MarshallingException("Failed to marshal Output", e);
}
}
@Override
public O unmarshal(byte[] data) throws UnmarshallingException {
if (data == null) {
return null;
}
try (ByteArrayInputStream bais = new ByteArrayInputStream(data);
ObjectInputStream ois = new ObjectInputStream(bais)) {
return (O) ois.readObject();
} catch (IOException | ClassNotFoundException e) {
throw new UnmarshallingException("Failed to unmarshal Output", e);
}
}
}
Output List Marshaller (for Batch Operations)
public class CustomOutputListMarshaller<O>
implements ByteArrayMarshaller<List<O>>, Serializable {
private static final long serialVersionUID = 1L;
@Override
public byte[] marshal(List<O> obj) throws MarshallingException {
if (obj == null) {
return null;
}
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos)) {
oos.writeObject(obj);
oos.flush();
return baos.toByteArray();
} catch (IOException e) {
throw new MarshallingException("Failed to marshal Output list", e);
}
}
@Override
public List<O> unmarshal(byte[] data) throws UnmarshallingException {
if (data == null) {
return null;
}
try (ByteArrayInputStream bais = new ByteArrayInputStream(data);
ObjectInputStream ois = new ObjectInputStream(bais)) {
return (List<O>) ois.readObject();
} catch (IOException | ClassNotFoundException e) {
throw new UnmarshallingException("Failed to unmarshal Output list", e);
}
}
}
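All three marshallers use plain Java serialization, so a quick round-trip sanity check is straightforward. An illustrative example:
// Illustrative round-trip: serialize and deserialize a CustomOutput.
CustomOutputMarshaller<CustomOutput> marshaller = new CustomOutputMarshaller<>();
CustomOutput original = new CustomOutput(
        "sample text", new String[] {"positive"}, new double[] {0.93});
CustomOutput restored = marshaller.unmarshal(marshaller.marshal(original));
System.out.println(restored.getSequence()); // prints "sample text"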
Step 6: Usage
Simple Prediction
CustomInput input = new CustomInput(
"This is a great product!",
new String[]{"positive", "negative", "neutral"},
false
);
MlSimpleJobParameters<CustomInput> jobParams =
MlSimpleJobParameters.<CustomInput>builder()
.id(MODEL_ID)
.version(MODEL_VERSION)
.type(ModelType.ONNX)
.name("model_quantized")
.config(ModelConfig.builder().build())
.inputClass("org.apache.ignite.example.ml.custom.CustomInput")
.outputClass("org.apache.ignite.example.ml.custom.CustomOutput")
.translatorFactory(
"org.apache.ignite.example.ml.custom.CustomTranslatorFactory")
.input(input)
.customJobClass(CustomComputeJob.class)
.customInputMarshaller(new CustomInputMarshaller<>())
.customOutputMarshaller(new CustomOutputMarshaller<>())
.build();
CustomOutput result = mlApi.predict(jobParams);
System.out.println("Classification Results:");
System.out.println("Text: " + result.getSequence());
System.out.println("Predictions: " + result.getLabels()[0]);
System.out.println("Scores: " + result.getScores()[0]);
Batch Prediction
List<CustomInput> batchInputs = Arrays.asList(
new CustomInput("Great product!", new String[]{"positive", "negative"}, false),
new CustomInput("Poor quality.", new String[]{"positive", "negative"}, false),
new CustomInput("Average experience.", new String[]{"positive", "negative"}, false)
);
MlBatchJobParameters<CustomInput> jobParams =
MlBatchJobParameters.<CustomInput>builder()
.id(MODEL_ID)
.version(MODEL_VERSION)
.type(ModelType.ONNX)
.name("model_quantized")
.config(ModelConfig.builder().batchSize(16).build())
.inputClass("org.apache.ignite.example.ml.custom.CustomInput")
.outputClass("org.apache.ignite.example.ml.custom.CustomOutput")
.translatorFactory(
"org.apache.ignite.example.ml.custom.CustomTranslatorFactory")
.batchInput(batchInputs)
.customJobClass(CustomBatchComputeJob.class)
.customInputMarshaller(new CustomInputMarshaller<>())
.customOutputMarshaller(new CustomOutputListMarshaller<>())
.build();
List<CustomOutput> results = mlApi.batchPredict(jobParams);
// Process results
for (int i = 0; i < results.size(); i++) {
CustomOutput output = results.get(i);
System.out.println("Input: " + batchInputs.get(i).getText());
System.out.println("Result: " + output.getLabels()[0]);
}
Example 2: PyTorch Question Answering
This example demonstrates question answering using a PyTorch model from HuggingFace with user-defined input type, translator, and translator factory.
Key features:
- PyTorch model hosted with DJL via djl:// URLs
- User-defined input type (PytorchQAInput) for question-context pairs
- Standard output type (String) for extracted answers
- User-defined translator factory (PytorchQATranslatorFactory) for QA-specific processing
- User-defined compute job implementation (PytorchQAComputeJob)
- User-defined marshallers for specialized serialization
- Demonstrates simple predictions only
Execution Mode: MODE.EMBEDDED
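In embedded mode the ML API is obtained directly from the running node. A minimal sketch, assuming an already-started embedded node referenced by ignite:
// Assumes `ignite` is an already-running embedded node handle.
var mlApi = ignite.ml();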
Step 1: Download Model
This example uses automated download from DJL Model Zoo:
private static final String MODEL_URL =
"djl://ai.djl.huggingface.pytorch/deepset/minilm-uncased-squad2";
// Download model files and deploy to GridGain cluster
tempModelDir = ModelUtils.downloadAndDeployDJLModelAndJar(
MODEL_URL,
PytorchQAInput.class,
String.class,
MODEL_ID,
MODEL_VERSION,
sourceDir,
PytorchQATranslatorFactory.class.getName()
);
This method automatically:
- Downloads the PyTorch model from the DJL model zoo using the djl:// protocol
- Creates a temporary directory to store the model files
- Compiles custom classes (input types, translators, marshallers) from the source directory
- Generates a JAR containing the compiled custom classes
- Deploys both the model files and the custom-classes JAR to the GridGain cluster
Step 2: Implement Custom Input Type
public class PytorchQAInput implements Serializable {
private static final long serialVersionUID = 1L;
private final String question;
private final String paragraph;
// Alternate field name; getParagraph() falls back to it when paragraph is null.
private String context;
public PytorchQAInput(String question, String paragraph) {
this.question = question;
this.paragraph = paragraph;
}
public String getQuestion() {
return this.question;
}
public String getParagraph() {
return this.paragraph == null ? this.context : this.paragraph;
}
}
Step 3: Implement PyTorch Translator and TranslatorFactory
Translator Factory
public class PytorchQATranslatorFactory implements TranslatorFactory, Serializable {
@Override
public <I, O> Translator<I, O> newInstance(
Class<I> input,
Class<O> output,
Model model,
Map<String, ?> arguments) throws TranslateException {
Path modelPath = model.getModelPath();
try {
HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder(arguments)
.optTokenizerPath(modelPath)
.optManager(model.getNDManager())
.build();
PytorchQATranslator translator =
PytorchQATranslator.builder(tokenizer, arguments).build();
return (Translator<I, O>) translator;
} catch (IOException e) {
throw new TranslateException("Failed to load tokenizer.", e);
}
}
}
Translator
public class PytorchQATranslator implements Translator<PytorchQAInput, String> {
private HuggingFaceTokenizer tokenizer;
private Batchifier batchifier;
private PytorchQATranslator(HuggingFaceTokenizer tokenizer, Map<String, ?> arguments) {
this.tokenizer = tokenizer;
this.batchifier = Batchifier.STACK;
}
@Override
public NDList processInput(TranslatorContext ctx, PytorchQAInput input) {
// Tokenize question and paragraph together
String question = input.getQuestion();
String paragraph = input.getParagraph();
// Process with tokenizer
Encoding encoding = tokenizer.encode(question, paragraph);
// Convert to NDArrays
// ... implementation details ...
return ndList;
}
@Override
public String processOutput(TranslatorContext ctx, NDList list) {
// Extract start and end logits
NDArray startLogits = list.get(0);
NDArray endLogits = list.get(1);
// Find best answer span
// ... implementation details ...
// Extract answer from tokens
String answer = extractAnswer(startIdx, endIdx);
return answer;
}
}
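As a rough illustration of the elided span selection (an assumption, not the example's exact code): the indices typically come from an argmax over the two logit vectors, and the answer text is recovered from the tokens between them.
// Illustrative only: pick the most likely start/end token positions.
long startIdx = startLogits.argMax().getLong();
long endIdx = endLogits.argMax().getLong();
// The answer string is then decoded from the tokens in [startIdx, endIdx],
// e.g. via the Encoding produced during processInput.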
Step 4: Extend MlComputeJob
public class PytorchQAComputeJob<I extends MlSimpleJobParameters, O>
extends MlComputeJob<I, O> {
@Override
public CompletableFuture<O> predictAsync(JobExecutionContext context, I arg) {
if (arg == null) {
return CompletableFuture.completedFuture(null);
}
return context.ignite().ml().predictAsync(arg);
}
@Override
public Marshaller<I, byte[]> inputMarshaller() {
return new PytorchQAInputMarshaller<>();
}
@Override
public Marshaller<O, byte[]> resultMarshaller() {
return new PytorchQAOutputMarshaller<>();
}
}
Step 5: Implement Custom Marshallers
Input Marshaller for PytorchQAInput
public class PytorchQAInputMarshaller<I extends MlJobParameters>
implements ByteArrayMarshaller<I>, Serializable {
private static final long serialVersionUID = 1L;
@Override
public byte[] marshal(I obj) throws MarshallingException {
if (obj == null) {
return null;
}
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos)) {
oos.writeObject(obj);
oos.flush();
return baos.toByteArray();
} catch (IOException e) {
throw new MarshallingException("Failed to marshal MlJobParameters", e);
}
}
@Override
public I unmarshal(byte[] data) throws UnmarshallingException {
if (data == null) {
return null;
}
try (ByteArrayInputStream bais = new ByteArrayInputStream(data);
ObjectInputStream ois = new ObjectInputStream(bais)) {
return (I) ois.readObject();
} catch (IOException | ClassNotFoundException e) {
throw new UnmarshallingException("Failed to unmarshal MlJobParameters", e);
}
}
}
Output Marshaller for String Results
public class PytorchQAOutputMarshaller<O>
implements ByteArrayMarshaller<O>, Serializable {
private static final long serialVersionUID = 1L;
@Override
public byte[] marshal(O obj) throws MarshallingException {
if (obj == null) {
return null;
}
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos)) {
oos.writeObject(obj);
oos.flush();
return baos.toByteArray();
} catch (IOException e) {
throw new MarshallingException("Failed to marshal Output", e);
}
}
@Override
public O unmarshal(byte[] data) throws UnmarshallingException {
if (data == null) {
return null;
}
try (ByteArrayInputStream bais = new ByteArrayInputStream(data);
ObjectInputStream ois = new ObjectInputStream(bais)) {
return (O) ois.readObject();
} catch (IOException | ClassNotFoundException e) {
throw new UnmarshallingException("Failed to unmarshal Output", e);
}
}
}
Step 6: Usage
String question = "When did BBC Japan start broadcasting?";
String paragraph = "BBC Japan was a general entertainment Channel. " +
"Which operated between December 2004 and April 2006. " +
"It ceased operations after its Japanese distributor folded.";
PytorchQAInput qaInput = new PytorchQAInput(question, paragraph);
MlSimpleJobParameters<PytorchQAInput> jobParams =
MlSimpleJobParameters.<PytorchQAInput>builder()
.id(MODEL_ID)
.version(MODEL_VERSION)
.type(ModelType.PYTORCH)
.name("minilm-uncased-squad2")
.inputClass(PytorchQAInput.class.getName())
.outputClass(String.class.getName())
.translatorFactory(PytorchQATranslatorFactory.class.getName())
.property("detail", "true")
.input(qaInput)
.customJobClass(PytorchQAComputeJob.class)
.customInputMarshaller(new PytorchQAInputMarshaller<>())
.customOutputMarshaller(new PytorchQAOutputMarshaller<>())
.build();
String answer = mlApi.predict(jobParams);
System.out.println("Question: \"" + question + "\"");
System.out.println("Answer: \"" + answer + "\"");
// Output: Answer: "December 2004"
Example 3: TensorFlow Sentence Encoder
This example demonstrates sentence embedding generation using a TensorFlow SavedModel from HuggingFace with user-defined translators and marshallers.
Key features:
- TensorFlow model hosted on HuggingFace
- Model downloaded via URLs with automated deployment
- TensorFlow folder structure handling (saved_model.pb plus a variables/ directory)
- Standard input type (String[]) for batch sentence processing
- Standard output type (float[][]) for embedding vectors
- User-defined translator (SentenceEncoderTranslator) for TensorFlow-specific processing
- User-defined compute job implementation (TensorflowSentenceEncoderComputeJob)
- User-defined marshallers
Step 1: Download Model
Option A: Manual Download
Download the Universal Sentence Encoder model files from HuggingFace:
# Download the main model file
curl --ssl-no-revoke -O \
"https://huggingface.co/hfarwah/universal-sentence-encoder/resolve/main/saved_model.pb"
# Download variable files
curl --ssl-no-revoke -O \
"https://huggingface.co/hfarwah/universal-sentence-encoder/resolve/main/variables/variables.data-00000-of-00001"
curl --ssl-no-revoke -O \
"https://huggingface.co/hfarwah/universal-sentence-encoder/resolve/main/variables/variables.index"
Organize the downloaded files into the following directory structure:
model-directory/
├── saved_model.pb
└── variables/
├── variables.data-00000-of-00001
└── variables.index
After organizing, deploy the model to GridGain:
cluster unit deploy universal-sentence-encoder \
--version 1.0.0 \
--path /path/to/model/directory \
--nodes ALL
Option B: Automated Download and Deployment
private static final List<String> urls = Arrays.asList(
"https://huggingface.co/hfarwah/universal-sentence-encoder/resolve/main/saved_model.pb",
"https://huggingface.co/hfarwah/universal-sentence-encoder/resolve/main/variables/variables.data-00000-of-00001",
"https://huggingface.co/hfarwah/universal-sentence-encoder/resolve/main/variables/variables.index"
);
// Construct path to custom classes
Path workingDir = Paths.get(System.getProperty("user.dir"));
Path sourceDir = workingDir.resolve("java/src/main/java/" + CUSTOM_CLASSES_PACKAGE);
// Download model files with proper folder structure and deploy
tempModelDir = ModelUtils.downloadAndDeployModelAndJarWithFolderStructure(
urls, // URLs to download model files from
MODEL_ID, // Model identifier (e.g., "universal-sentence-encoder")
MODEL_VERSION, // Model version (e.g., "1.0.0")
sourceDir // Path to custom classes for JAR packaging
);
This method automatically:
- Downloads all required TensorFlow model files from the specified URLs
- Creates the proper folder structure (saved_model.pb in the root, variable files in the variables/ subdirectory)
- Creates a temporary directory for the model files
- Packages custom classes (translators, marshallers) into a JAR
- Deploys both the model and custom classes to the GridGain cluster
Step 2: Input/Output Types
For this example, we use standard Java types:
- Input type: String[], an array of sentences to encode
- Output type: float[][], a 2D array of embeddings (one embedding vector per sentence)
Since we're using standard types, no custom input/output classes are needed. However, custom marshallers are still required for proper serialization.
Step 3: Implement Custom Translator
public class SentenceEncoderTranslator
implements NoBatchifyTranslator<String[], float[][]> {
@Override
public NDList processInput(TranslatorContext ctx, String[] inputs) {
// Convert string array to NDArray format expected by TensorFlow model
NDManager manager = ctx.getNDManager();
List<NDArray> inputList = Arrays.stream(inputs)
.map(manager::create)
.collect(Collectors.toList());
NDArray inputArray = NDArrays.stack(new NDList(inputList));
return new NDList(inputArray);
}
@Override
public float[][] processOutput(TranslatorContext ctx, NDList list) {
// Extract embeddings from model output
NDArray embeddings = list.singletonOrThrow();
// Convert to float[][] for easier handling
NDList result = new NDList();
long numOutputs = embeddings.getShape().get(0);
for (int i = 0; i < numOutputs; i++) {
result.add(embeddings.get(i));
}
return result.stream()
.map(NDArray::toFloatArray)
.toArray(float[][]::new);
}
}
Step 4: Extend MlComputeJob
public class TensorflowSentenceEncoderComputeJob<I extends MlSimpleJobParameters, O>
extends MlComputeJob<I, O> {
@Override
public CompletableFuture<O> predictAsync(JobExecutionContext context, I arg) {
if (arg == null) {
return CompletableFuture.completedFuture(null);
}
return context.ignite().ml().predictAsync(arg);
}
@Override
public Marshaller<I, byte[]> inputMarshaller() {
return new TensorFlowInputMarshaller<>();
}
@Override
public Marshaller<O, byte[]> resultMarshaller() {
return new TensorFlowOutputMarshaller<>();
}
}
Step 5: Implement Custom Marshallers
Input Marshaller for String[] Input
public class TensorFlowInputMarshaller<I extends MlJobParameters>
implements ByteArrayMarshaller<I>, Serializable {
private static final long serialVersionUID = 1L;
@Override
public byte[] marshal(I obj) throws MarshallingException {
if (obj == null) {
return null;
}
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos)) {
oos.writeObject(obj);
oos.flush();
return baos.toByteArray();
} catch (IOException e) {
throw new MarshallingException("Failed to marshal MlJobParameters", e);
}
}
@Override
public I unmarshal(byte[] data) throws UnmarshallingException {
if (data == null) {
return null;
}
try (ByteArrayInputStream bais = new ByteArrayInputStream(data);
ObjectInputStream ois = new ObjectInputStream(bais)) {
return (I) ois.readObject();
} catch (IOException | ClassNotFoundException e) {
throw new UnmarshallingException("Failed to unmarshal MlJobParameters", e);
}
}
}
Output Marshaller for float[][] Output
public class TensorFlowOutputMarshaller<O>
implements ByteArrayMarshaller<O>, Serializable {
private static final long serialVersionUID = 1L;
@Override
public byte[] marshal(O obj) throws MarshallingException {
if (obj == null) {
return null;
}
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos)) {
oos.writeObject(obj);
oos.flush();
return baos.toByteArray();
} catch (IOException e) {
throw new MarshallingException("Failed to marshal Output", e);
}
}
@Override
public O unmarshal(byte[] data) throws UnmarshallingException {
if (data == null) {
return null;
}
try (ByteArrayInputStream bais = new ByteArrayInputStream(data);
ObjectInputStream ois = new ObjectInputStream(bais)) {
return (O) ois.readObject();
} catch (IOException | ClassNotFoundException e) {
throw new UnmarshallingException("Failed to unmarshal Output", e);
}
}
}
Step 6: Usage
String[] input = {"The quick brown fox jumps over the lazy dog."};
MlSimpleJobParameters<String[]> jobParams =
MlSimpleJobParameters.<String[]>builder()
.id(MODEL_ID)
.version(MODEL_VERSION)
.type(ModelType.TENSORFLOW)
.name("saved_model") // TensorFlow saved_model format
.inputClass("java.lang.String[]")
.outputClass("float[][]")
.translator(
"org.apache.ignite.example.ml.tensorflowtranslator.SentenceEncoderTranslator")
.input(input)
.customJobClass(TensorflowSentenceEncoderComputeJob.class)
.customInputMarshaller(new TensorFlowInputMarshaller<>())
.customOutputMarshaller(new TensorFlowOutputMarshaller<>())
.build();
float[][] result = mlApi.predict(jobParams);
System.out.println("Input: " + input[0] + "\n");
System.out.println("Embedding generated: ");
System.out.println("Dimensions: " + result[0].length);
System.out.println("First 5 values: " + Arrays.toString(Arrays.copyOf(result[0], 5)));
// Output example:
// Input: "The quick brown fox jumps over the lazy dog."
// Embedding generated:
// Dimensions: 512
// First 5 values: [0.0234, -0.0156, 0.0789, -0.0023, 0.0445]
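The returned vectors can be compared directly; for instance, a small cosine-similarity helper (purely illustrative, not part of the GridGain API) turns two embeddings into a similarity score:
// Cosine similarity between two embedding vectors (illustrative helper).
static double cosineSimilarity(float[] a, float[] b) {
    double dot = 0.0, normA = 0.0, normB = 0.0;
    for (int i = 0; i < a.length; i++) {
        dot += a[i] * b[i];
        normA += a[i] * a[i];
        normB += b[i] * b[i];
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}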