From 3aaad7dfbf92e86772cb500ac3e962d3c9c6132d Mon Sep 17 00:00:00 2001 From: Omkar Kabde Date: Tue, 9 Jun 2026 16:38:12 +0530 Subject: [PATCH 1/3] [DOC] Clarify MODEL_TYPE_MULTIMODAL is an alias for MODEL_TYPE_TEXT_VISION Both constants share the same value (2) in LlmModuleConfig.java, so listing them as separate model types implied behavior that does not exist. Reword the available-model-types sentence to drop the misleading third entry and note the alias explicitly. Flagged on the original PR (#19611). --- docs/source/llm/run-on-android.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/llm/run-on-android.md b/docs/source/llm/run-on-android.md index 81abd6a79d5..8ce8f25a51e 100644 --- a/docs/source/llm/run-on-android.md +++ b/docs/source/llm/run-on-android.md @@ -52,7 +52,7 @@ LlmModuleConfig config = LlmModuleConfig.create() LlmModule module = new LlmModule(config); ``` -Available load modes are `LOAD_MODE_FILE`, `LOAD_MODE_MMAP` (default), `LOAD_MODE_MMAP_USE_MLOCK`, and `LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS`. Available model types are `MODEL_TYPE_TEXT`, `MODEL_TYPE_TEXT_VISION`, and `MODEL_TYPE_MULTIMODAL`. +Available load modes are `LOAD_MODE_FILE`, `LOAD_MODE_MMAP` (default), `LOAD_MODE_MMAP_USE_MLOCK`, and `LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS`. Available model types are `MODEL_TYPE_TEXT` and `MODEL_TYPE_TEXT_VISION` (the `MODEL_TYPE_MULTIMODAL` constant is currently an alias for `MODEL_TYPE_TEXT_VISION` and selects the same runtime path). Construction itself is lightweight and does not load the program data immediately. From dbf02a42bf24a4a5ba4468d64e0622b0df8c8657 Mon Sep 17 00:00:00 2001 From: Omkar Kabde Date: Tue, 9 Jun 2026 16:39:58 +0530 Subject: [PATCH 2/3] [DOC] Add Kotlin examples for the Android LLM runner Mirrors the Java snippets in run-on-android.md with idiomatic Kotlin equivalents (object expression for the LlmCallback, IntArray/FloatArray for primitive arrays, Float.SIZE_BYTES, ByteBuffer.apply { ... }, trailing commas on multi-line calls). Follows the run-on-ios.md convention of putting a language label above each fenced block rather than tab-sets. Requested in review on the original PR (#19611). --- docs/source/llm/run-on-android.md | 150 +++++++++++++++++++++++++++++- 1 file changed, 149 insertions(+), 1 deletion(-) diff --git a/docs/source/llm/run-on-android.md b/docs/source/llm/run-on-android.md index 8ce8f25a51e..0cec9420f59 100644 --- a/docs/source/llm/run-on-android.md +++ b/docs/source/llm/run-on-android.md @@ -10,10 +10,11 @@ To add the `executorch-android` library to your app, see [Using ExecuTorch on An ## Runtime API -Once the `executorch-android` AAR is on your classpath, you can import the LLM runner classes from the `org.pytorch.executorch.extension.llm` package. +Once the `executorch-android` AAR is on your classpath, you can import the LLM runner classes from the `org.pytorch.executorch.extension.llm` package. The runner is callable from both Java and Kotlin; the rest of this guide shows both side by side. ### Importing +Java: ```java import org.pytorch.executorch.extension.llm.LlmModule; import org.pytorch.executorch.extension.llm.LlmModuleConfig; @@ -21,6 +22,14 @@ import org.pytorch.executorch.extension.llm.LlmGenerationConfig; import org.pytorch.executorch.extension.llm.LlmCallback; ``` +Kotlin: +```kotlin +import org.pytorch.executorch.extension.llm.LlmModule +import org.pytorch.executorch.extension.llm.LlmModuleConfig +import org.pytorch.executorch.extension.llm.LlmGenerationConfig +import org.pytorch.executorch.extension.llm.LlmCallback +``` + ### LlmModule The `LlmModule` class provides a simple Java interface for loading a text-generation model, configuring its tokenizer, generating token streams, and stopping execution. It also supports multimodal models that accept image and audio inputs alongside a text prompt. @@ -31,6 +40,7 @@ This API is experimental and subject to change. Create an `LlmModule` by specifying paths to your serialized model (`.pte`) and tokenizer files. For text-only models, the simple constructor is enough: +Java: ```java LlmModule module = new LlmModule( "/data/local/tmp/llama-3.2-instruct.pte", @@ -38,8 +48,18 @@ LlmModule module = new LlmModule( 0.8f); ``` +Kotlin: +```kotlin +val module = LlmModule( + "/data/local/tmp/llama-3.2-instruct.pte", + "/data/local/tmp/tokenizer.model", + 0.8f, +) +``` + For finer control (multimodal model type, BOS/EOS handling, supplementary data files, load mode), use `LlmModuleConfig` with the fluent builder: +Java: ```java LlmModuleConfig config = LlmModuleConfig.create() .modulePath("/data/local/tmp/llama-3.2-instruct.pte") @@ -52,6 +72,19 @@ LlmModuleConfig config = LlmModuleConfig.create() LlmModule module = new LlmModule(config); ``` +Kotlin: +```kotlin +val config = LlmModuleConfig.create() + .modulePath("/data/local/tmp/llama-3.2-instruct.pte") + .tokenizerPath("/data/local/tmp/tokenizer.model") + .temperature(0.8f) + .modelType(LlmModuleConfig.MODEL_TYPE_TEXT) + .loadMode(LlmModuleConfig.LOAD_MODE_MMAP) + .build() + +val module = LlmModule(config) +``` + Available load modes are `LOAD_MODE_FILE`, `LOAD_MODE_MMAP` (default), `LOAD_MODE_MMAP_USE_MLOCK`, and `LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS`. Available model types are `MODEL_TYPE_TEXT` and `MODEL_TYPE_TEXT_VISION` (the `MODEL_TYPE_MULTIMODAL` constant is currently an alias for `MODEL_TYPE_TEXT_VISION` and selects the same runtime path). Construction itself is lightweight and does not load the program data immediately. @@ -60,6 +93,7 @@ Construction itself is lightweight and does not load the program data immediatel Explicitly load the model before generation to avoid paying the load cost during your first `generate` call. +Java: ```java int status = module.load(); if (status != 0) { @@ -67,12 +101,21 @@ if (status != 0) { } ``` +Kotlin: +```kotlin +val status = module.load() +if (status != 0) { + // Handle load failure (status is an ExecuTorch runtime error code). +} +``` + If you skip this step, the model is loaded lazily on the first `generate` call. #### Generating Generate tokens from a text prompt by passing an `LlmCallback` that receives each token as it is produced. The same callback also receives a JSON-encoded statistics string when generation completes. +Java: ```java LlmCallback callback = new LlmCallback() { @Override @@ -97,8 +140,31 @@ LlmCallback callback = new LlmCallback() { module.generate("Once upon a time", callback); ``` +Kotlin: +```kotlin +val callback = object : LlmCallback { + override fun onResult(token: String) { + // Called once per generated token. Append to your UI buffer here. + print(token) + } + + override fun onStats(statsJson: String) { + // Called once when generation finishes. See extension/llm/runner/stats.h + // for the field definitions. + println("\n$statsJson") + } + + override fun onError(errorCode: Int, message: String) { + // Called if the runtime reports an error during generation. + } +} + +module.generate("Once upon a time", callback) +``` + For full control over generation parameters, use `LlmGenerationConfig`: +Java: ```java LlmGenerationConfig genConfig = LlmGenerationConfig.create() .seqLen(2048) @@ -109,26 +175,49 @@ LlmGenerationConfig genConfig = LlmGenerationConfig.create() module.generate("Once upon a time", genConfig, callback); ``` +Kotlin: +```kotlin +val genConfig = LlmGenerationConfig.create() + .seqLen(2048) + .temperature(0.8f) + .echo(false) + .build() + +module.generate("Once upon a time", genConfig, callback) +``` + `LlmGenerationConfig` exposes `echo`, `maxNewTokens`, `seqLen`, `temperature`, `numBos`, `numEos`, and `warming`. Defaults match the C++ `GenerationConfig` documented in [Running LLMs with C++](run-with-c-plus-plus.md). #### Stopping Generation If you need to interrupt a long-running generation, call `stop()` from another thread (or from inside the `onResult` callback): +Java: ```java module.stop(); ``` +Kotlin: +```kotlin +module.stop() +``` + Generation also runs synchronously on the calling thread, so make sure you invoke `generate()` off the main thread (for example, on a `HandlerThread` or via a `java.util.concurrent.Executor`). #### Resetting To clear the prefilled tokens from the KV cache and reset the start position to 0, call: +Java: ```java module.resetContext(); ``` +Kotlin: +```kotlin +module.resetContext() +``` + This is the equivalent of `reset()` on the iOS runner and `reset()` on the C++ `IRunner`. ### Multimodal Inputs @@ -139,6 +228,7 @@ For models declared as `MODEL_TYPE_TEXT_VISION` or `MODEL_TYPE_MULTIMODAL`, imag Raw uint8 pixel data in CHW order can be supplied as an `int[]`, or as a direct `ByteBuffer` to avoid JNI array copies: +Java: ```java // As int[] int[] pixels = ...; // length == channels * height * width @@ -150,8 +240,23 @@ buffer.put(rawBytes).rewind(); module.prefillImages(buffer, 336, 336, 3); ``` +Kotlin: +```kotlin +// As IntArray +val pixels: IntArray = ... // length == channels * height * width +module.prefillImages(pixels, /* width = */ 336, /* height = */ 336, /* channels = */ 3) + +// As direct ByteBuffer (preferred for large images) +val buffer = ByteBuffer.allocateDirect(3 * 336 * 336).apply { + put(rawBytes) + rewind() +} +module.prefillImages(buffer, 336, 336, 3) +``` + Pre-normalized float pixel data is also supported, both as a `float[]` and as a direct `ByteBuffer` in native byte order: +Java: ```java float[] normalized = ...; // length == channels * height * width module.prefillImages(normalized, 336, 336, 3); @@ -163,31 +268,63 @@ ByteBuffer floatBuffer = ByteBuffer module.prefillNormalizedImage(floatBuffer, 336, 336, 3); ``` +Kotlin: +```kotlin +val normalized: FloatArray = ... // length == channels * height * width +module.prefillImages(normalized, 336, 336, 3) + +val floatBuffer: ByteBuffer = ByteBuffer + .allocateDirect(3 * 336 * 336 * Float.SIZE_BYTES) + .order(ByteOrder.nativeOrder()) +// fill floatBuffer with normalized values, then: +module.prefillNormalizedImage(floatBuffer, 336, 336, 3) +``` + #### Audio Preprocessed audio features (for example mel spectrograms produced by a Whisper preprocessor) can be supplied as `byte[]` or `float[]`: +Java: ```java module.prefillAudio(features, /*batchSize=*/1, /*nBins=*/128, /*nFrames=*/3000); ``` +Kotlin: +```kotlin +module.prefillAudio(features, /* batchSize = */ 1, /* nBins = */ 128, /* nFrames = */ 3000) +``` + Raw audio samples can be supplied with `prefillRawAudio`: +Java: ```java module.prefillRawAudio(samples, /*batchSize=*/1, /*nChannels=*/1, /*nSamples=*/16000); ``` +Kotlin: +```kotlin +module.prefillRawAudio(samples, /* batchSize = */ 1, /* nChannels = */ 1, /* nSamples = */ 16000) +``` + #### Generating with Multimodal Prefill After prefilling each modality, run `generate()` with the text prompt as usual: +Java: ```java module.prefillImages(pixels, 336, 336, 3); module.generate("What's in this image?", callback); ``` +Kotlin: +```kotlin +module.prefillImages(pixels, 336, 336, 3) +module.generate("What's in this image?", callback) +``` + For text-vision models, a convenience overload accepts the image and prompt together: +Java: ```java module.generate( pixels, /*width=*/336, /*height=*/336, /*channels=*/3, @@ -197,6 +334,17 @@ module.generate( /*echo=*/false); ``` +Kotlin: +```kotlin +module.generate( + pixels, /* width = */ 336, /* height = */ 336, /* channels = */ 3, + "What's in this image?", + /* seqLen = */ 768, + callback, + /* echo = */ false, +) +``` + ## Demo See the [Llama Android demo app](https://github.com/meta-pytorch/executorch-examples/tree/main/llm/android/LlamaDemo) in `executorch-examples` for an end-to-end project that wires `LlmModule`, `LlmCallback`, and a `HandlerThread` into a chat UI. From a0b062b61c6dc390eb6892f1233c0b7bd5ca3a9d Mon Sep 17 00:00:00 2001 From: Omkar Kabde Date: Tue, 9 Jun 2026 18:19:27 +0530 Subject: [PATCH 3/3] [DOC] Add java.nio imports and rewind() to float ByteBuffer examples Address review feedback on the Android LLM runner page: * Add java.nio.ByteBuffer and java.nio.ByteOrder to the Java and Kotlin import blocks (with a one-line note that they are only used by the multimodal ByteBuffer paths), so the snippets in the Images section compile when copy-pasted. * Show floatBuffer.rewind() before prefillNormalizedImage in both the Java and Kotlin examples, matching the int-buffer example above and removing the asymmetry called out in review. --- docs/source/llm/run-on-android.md | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/source/llm/run-on-android.md b/docs/source/llm/run-on-android.md index 0cec9420f59..939a9c9f603 100644 --- a/docs/source/llm/run-on-android.md +++ b/docs/source/llm/run-on-android.md @@ -20,6 +20,10 @@ import org.pytorch.executorch.extension.llm.LlmModule; import org.pytorch.executorch.extension.llm.LlmModuleConfig; import org.pytorch.executorch.extension.llm.LlmGenerationConfig; import org.pytorch.executorch.extension.llm.LlmCallback; + +// Only needed for the multimodal ByteBuffer paths in the Images section. +import java.nio.ByteBuffer; +import java.nio.ByteOrder; ``` Kotlin: @@ -28,6 +32,10 @@ import org.pytorch.executorch.extension.llm.LlmModule import org.pytorch.executorch.extension.llm.LlmModuleConfig import org.pytorch.executorch.extension.llm.LlmGenerationConfig import org.pytorch.executorch.extension.llm.LlmCallback + +// Only needed for the multimodal ByteBuffer paths in the Images section. +import java.nio.ByteBuffer +import java.nio.ByteOrder ``` ### LlmModule @@ -264,7 +272,8 @@ module.prefillImages(normalized, 336, 336, 3); ByteBuffer floatBuffer = ByteBuffer .allocateDirect(3 * 336 * 336 * Float.BYTES) .order(ByteOrder.nativeOrder()); -// fill floatBuffer with normalized values, then: +// fill floatBuffer with normalized values, then rewind before the call: +floatBuffer.rewind(); module.prefillNormalizedImage(floatBuffer, 336, 336, 3); ``` @@ -276,7 +285,8 @@ module.prefillImages(normalized, 336, 336, 3) val floatBuffer: ByteBuffer = ByteBuffer .allocateDirect(3 * 336 * 336 * Float.SIZE_BYTES) .order(ByteOrder.nativeOrder()) -// fill floatBuffer with normalized values, then: +// fill floatBuffer with normalized values, then rewind before the call: +floatBuffer.rewind() module.prefillNormalizedImage(floatBuffer, 336, 336, 3) ```