Hello everyone, I am Ashok. I am a student working on an MNIST digit-classification project as part of my internship, and I would like to implement on-device machine-learning training in an Android app.
Reference:
I trained the model and am getting this warning:
WARNING:absl:Importing a function (__inference_internal_grad_fn_368181) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
Later, I converted it into a TensorFlow Lite model.
I got stuck while building the Android application. The error I am getting is:
java.lang.IllegalArgumentException: Cannot copy to a TensorFlowLite tensor (train_y:0) with 40 bytes from a Java Buffer with 8 byte
Please help me — I am new to Python and machine learning. I truly appreciate your help. Thank you.
Java code:
public class MainActivity extends AppCompatActivity { private ImageView imageView; private TextView textView; private Button selectImageButton, ProcessImage; private Button trainModelButton, Updateweights; private Button predictButton; private Bitmap image; private static final int NUM_EPOCHS = 100; private static final int BATCH_SIZE = 10; private static final int IMG_HEIGHT = 28; private static final int IMG_WIDTH = 28; private static final int NUM_TRAININGS = 60000; private static final int NUM_BATCHES = NUM_TRAININGS / BATCH_SIZE; private static final int NUM_IMAGES = 1; private static final int REQUEST_CODE_GALLERY = 1; private List<Bitmap> selectedImages; private List<FloatBuffer> trainImageBatches; private List<FloatBuffer> trainLabelBatches; private Button SelectImagesBtn, TrainModelBtn; // ByteBuffer modelBuffer; Interpreter modelBuffer; Bitmap bitmap; private static final int IMAGE_PICK_REQUEST_CODE = 1; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); imageView = findViewById(R.id.selected_image_view); textView = findViewById(R.id.improved_learning_rate_text_view); selectImageButton = findViewById(R.id.select_image_button); trainModelButton = findViewById(R.id.train_model_button); predictButton = findViewById(R.id.predict_number_button); ProcessImage = findViewById(R.id.process_image_button); try { modelBuffer = new Interpreter(loadModelFile()); } catch (Exception e) { Log.e("MainActivity", "Error loading TFLite model", e); } selectImageButton.setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { selectImagesFromGallery(); } }); trainModelButton.setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { trainModel(); Toast.makeText(getApplicationContext()," Train button is clicked",Toast.LENGTH_SHORT).show(); } }); predictButton.setOnClickListener(new View.OnClickListener() { @Override public void 
onClick(View v) { // LoadOndevicetrainedmodel(); // predictNumber(); Toast.makeText(getApplicationContext()," predict button is clicked",Toast.LENGTH_SHORT).show(); } }); ProcessImage.setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { prepareTrainingBatches(); Toast.makeText(getApplicationContext()," process button is clicked",Toast.LENGTH_SHORT).show(); } }); } // Method to select images from the gallery private void selectImagesFromGallery() { // Use an Intent to pick images from the gallery Intent intent = new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.EXTERNAL_CONTENT_URI); intent.setType("image/*"); intent.putExtra(Intent.EXTRA_ALLOW_MULTIPLE, true); intent.setAction(Intent.ACTION_GET_CONTENT); startActivityForResult(Intent.createChooser(intent, "Select Images"), REQUEST_CODE_GALLERY); } @Override protected void onActivityResult(int requestCode, int resultCode, Intent data) { super.onActivityResult(requestCode, resultCode, data); if (requestCode == REQUEST_CODE_GALLERY && resultCode == RESULT_OK) { ClipData clipData = data.getClipData(); if (clipData != null) { selectedImages = new ArrayList<>(); int count = clipData.getItemCount(); count = Math.min(count, NUM_IMAGES); for (int i = 0; i < count; i++) { Uri imageUri = clipData.getItemAt(i).getUri(); try { bitmap = MediaStore.Images.Media.getBitmap(this.getContentResolver(), imageUri); // selectedImages.add(bitmap); imageView.setImageBitmap(bitmap); bitmap = resizeImage(bitmap); Toast.makeText(getApplicationContext(),"image converted to bitmap",Toast.LENGTH_LONG).show(); } catch (IOException e) { e.printStackTrace(); } } } } } // Method to prepare training batches using the selected images private void prepareTrainingBatches() { try { trainImageBatches = new ArrayList<>(NUM_BATCHES); trainLabelBatches = new ArrayList<>(NUM_BATCHES); // Iterate over the selected images for (int i = 0; i < NUM_IMAGES; i++) { // Allocate a direct buffer to store the image data // 
ByteBuffer byteBuffer = ByteBuffer.allocateDirect(IMG_HEIGHT * IMG_WIDTH * BATCH_SIZE).order(ByteOrder.nativeOrder()); FloatBuffer trainImages = convertBitmapToFloatBuffer(bitmap); // Convert the resized image to grayscale Bitmap grayscaleImage = toGrayscale(bitmap); // Convert the grayscale image to a float buffer FloatBuffer floatBuffer = convertBitmapToFloatBuffer(grayscaleImage); // Add the float buffer to trainImageBatches trainImageBatches.add(floatBuffer); // Allocate a direct buffer to store the label data ByteBuffer labelBuffer = ByteBuffer.allocateDirect(10 * BATCH_SIZE).order(ByteOrder.nativeOrder()); FloatBuffer trainLabels = labelBuffer.asFloatBuffer(); // Fill the image and label data for the current batch // trainImageBatches.add((FloatBuffer) trainImages.rewind()); trainLabelBatches.add((FloatBuffer) trainLabels.rewind()); Toast.makeText(getApplicationContext(), "prepareTrainingBatches is done", Toast.LENGTH_LONG).show(); } } catch (Exception e) { e.printStackTrace(); Toast.makeText(getApplicationContext(), "Error :"+ e, Toast.LENGTH_LONG).show(); } } public void trainModel(){ try { // Run training for a few steps. float[] losses = new float[NUM_EPOCHS]; for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) { for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) { Map<String, Object> inputs = new HashMap<>(); inputs.put("x", trainImageBatches.get(batchIdx)); inputs.put("y", trainLabelBatches.get(batchIdx)); Map<String, Object> outputs = new HashMap<>(); FloatBuffer loss = FloatBuffer.allocate(1); outputs.put("loss", loss); modelBuffer.runSignature(inputs, outputs, "train"); // Record the last loss. if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0); } // Print the loss output for every 10 epochs. if ((epoch + 1) % 10 == 0) { System.out.println( "Finished " + (epoch + 1) + " epochs, current loss: " + losses[epoch]); textView.setText("Finished " + (epoch + 1) + " epochs, current loss: " + losses[epoch]); } } // ... 
File outputFile = new File(getFilesDir(), "checkpoint.ckpt"); Map<String, Object> inputs = new HashMap<>(); inputs.put("checkpoint_path", outputFile.getAbsolutePath()); Map<String, Object> outputs = new HashMap<>(); modelBuffer.runSignature(inputs, outputs, "save"); } catch (Exception e){ e.printStackTrace(); Log.d("TRAIN MODEL:", String.valueOf(e)); Toast.makeText(getApplicationContext(),"Error:"+e,Toast.LENGTH_LONG).show(); } } private MappedByteBuffer loadModelFile() throws IOException { // Load the TensorFlow Lite model from a file AssetFileDescriptor fileDescriptor = getAssets().openFd("model.tflite"); FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = inputStream.getChannel(); long startOffset = fileDescriptor.getStartOffset(); long declaredLength = fileDescriptor.getDeclaredLength(); return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); } private Bitmap resizeImage(Bitmap originalImage){ int width = originalImage.getWidth(); int height = originalImage.getHeight(); int newWidth = 28; int newHeight = 28; float scaleWidth = ((float) newWidth) / width; float scaleHeight = ((float) newHeight) / height; Matrix matrix = new Matrix(); matrix.postScale(scaleWidth, scaleHeight); // Bitmap resizedImage = Bitmap.createBitmap(originalImage, 0, 0, width, height, matrix, false); Bitmap resizedImage = Bitmap.createScaledBitmap(originalImage,newWidth,newHeight,true); return resizedImage; } // The toGrayscale() and convertBitmapToFloatBuffer() methods are defined as follows: public static Bitmap toGrayscale(Bitmap bmpOriginal) { int width, height; height = bmpOriginal.getHeight(); width = bmpOriginal.getWidth(); Bitmap bmpGrayscale = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888); Canvas c = new Canvas(bmpGrayscale); Paint paint = new Paint(); ColorMatrix cm = new ColorMatrix(); cm.setSaturation(0); ColorMatrixColorFilter f = new ColorMatrixColorFilter(cm); 
paint.setColorFilter(f); c.drawBitmap(bmpOriginal, 0, 0, paint); return bmpGrayscale; } public static FloatBuffer convertBitmapToFloatBuffer(Bitmap bitmap) { int width = bitmap.getWidth(); int height = bitmap.getHeight(); float[] floatValues = new float[width * height]; for (int i = 0; i < height; ++i) { for (int j = 0; j < width; ++j) { int pixelValue = bitmap.getPixel(j, i); floatValues[i * width + j] = (float)(pixelValue & 0xff) / 255.0f; } } FloatBuffer floatBuffer = FloatBuffer.wrap(floatValues); return floatBuffer; }