
Datarax

Data Pipeline Framework for JAX

An extensible data pipeline framework built for JAX-based ML workflows with JIT compilation, differentiable processing, and multi-device distribution.

Repository Coming Soon

This project is under active development and will be open-sourced soon

Overview

Datarax (Data + Array/JAX) is an extensible data pipeline framework built for JAX-based machine learning workflows. It leverages JAX's JIT compilation, automatic differentiation, and hardware acceleration to build data loading, preprocessing, and augmentation pipelines that run on CPUs, GPUs, and TPUs.

JAX has mature libraries for models (Flax), optimizers (Optax), and checkpointing (Orbax), but lacks a dedicated data pipeline framework that operates at the same level of abstraction. Existing options are either framework-agnostic loaders that return NumPy arrays (losing JIT/autodiff benefits) or wrappers around tf.data/PyTorch that introduce cross-framework overhead. Datarax fills this gap.

Every component — sources, operators, batchers, samplers, sharders — is a Flax NNX module. Pipeline state is managed through NNX's variable system, which means operators can hold learnable parameters, be serialized with Orbax, and participate in JAX transformations (jit, vmap, grad) without special handling. Because operators are NNX modules, gradients flow through the entire pipeline, enabling approaches like gradient-based augmentation search and task-optimized preprocessing that are not possible with standard data loaders.
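The core claim here — that gradients can flow through preprocessing — can be illustrated in plain JAX, with no Datarax API involved. The sketch below differentiates a task loss with respect to a learnable augmentation knob (the `contrast` parameter and the loss are made up for illustration):

```python
import jax
import jax.numpy as jnp

# A preprocessing stage with a learnable parameter (here, a contrast scale).
def preprocess(params, images):
    return jnp.clip(images * params["contrast"], 0.0, 1.0)

def task_loss(params, images):
    # Stand-in for a downstream task loss computed on the processed data.
    processed = preprocess(params, images)
    return jnp.mean((processed - 0.5) ** 2)

params = {"contrast": jnp.float32(1.2)}
images = jnp.linspace(0.0, 1.0, 8).reshape(2, 4)

# The preprocessing knob receives a gradient from the task loss, which is
# what makes gradient-based augmentation search possible.
grads = jax.grad(task_loss)(params, images)
```

In Datarax the same effect falls out of operators being NNX modules: their parameters live in NNX state, so `jax.grad` over the whole pipeline reaches them without special handling.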

Pipelines are directed acyclic graphs, not linear chains. The >> operator composes sequential steps, | creates parallel branches, and control-flow nodes (Branch, Merge, SplitField) handle conditional and multi-path logic. The DAG executor manages scheduling, caching, and rebatching across the graph.
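The composition pattern is easy to see in miniature. The toy `Node` class below is not Datarax's implementation — just a minimal illustration of how `>>` (sequential) and `|` (parallel branches) build a graph out of operator overloads:

```python
class Node:
    """Toy composable node: `>>` chains, `|` fans out to parallel branches."""

    def __init__(self, fn):
        self.fn = fn

    def __rshift__(self, other):
        # a >> b : feed a's output into b
        return Node(lambda x: other.fn(self.fn(x)))

    def __or__(self, other):
        # a | b : run both branches on the same input
        return Node(lambda x: (self.fn(x), other.fn(x)))

    def __call__(self, x):
        return self.fn(x)

scale = Node(lambda x: x * 2)
shift = Node(lambda x: x + 1)

# One sequential step feeding two parallel branches.
pipeline = scale >> (shift | scale)
result = pipeline(3)  # scale(3) = 6, then (shift(6), scale(6))
```

Datarax's real executor adds what this toy version lacks: scheduling, caching, rebatching, and control-flow nodes like Branch and Merge.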

Key Features

JAX-Native Design

All core components are built on JAX with the Flax NNX module system. Pipelines are JIT-compiled via XLA, with built-in profiling and roofline analysis.

Differentiable Pipelines

Operators are NNX modules so gradients flow through the entire pipeline. Enables gradient-based augmentation search and task-optimized preprocessing.

DAG Execution Engine

Graph-based pipeline construction with branching, parallel execution, caching, and rebatching nodes. Use >> for sequential and | for parallel composition.

Multi-Device Distribution

Scale from laptop to cluster with device mesh sharding, ArraySharder, and JaxProcessSharder for multi-host training.
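Under the hood this builds on JAX's standard sharding machinery. The sketch below uses plain `jax.sharding` (not Datarax's `ArraySharder`, whose API is not yet public) to show the mechanism: a device mesh plus a partition spec that splits the batch dimension across devices. It runs even on a single CPU, where the "mesh" has one device:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D device mesh over whatever devices are available.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard the leading (batch) dimension across the "data" axis; trailing
# dimensions are replicated.
batch = jnp.arange(16.0).reshape(8, 2)
sharded = jax.device_put(batch, NamedSharding(mesh, PartitionSpec("data", None)))
```

A sharder component's job is to produce per-device batch slices consistent with such a spec, so the data pipeline and the model see the same mesh.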

Deterministic Reproducibility

Reproducible pipelines by default using Grain's Feistel cipher shuffling (O(1) memory) with explicit RNG key threading through every stochastic operator.
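The O(1)-memory trick is that a Feistel network turns a keyed hash into a bijection, so the i-th shuffled index can be computed directly instead of materializing a permuted index array. The sketch below shows the idea for a power-of-two domain; the hash choice and round count are illustrative, not Grain's exact scheme:

```python
import hashlib

def feistel_permute(index: int, num_bits: int, rounds: int, seed: int) -> int:
    """Map index -> permuted index over [0, 2**num_bits) via a balanced
    Feistel network. Any round function works; the structure alone
    guarantees the mapping is a bijection."""
    half = num_bits // 2
    mask = (1 << half) - 1
    left, right = index >> half, index & mask
    for r in range(rounds):
        digest = hashlib.sha256(f"{seed}:{r}:{right}".encode()).digest()
        round_fn = int.from_bytes(digest[:8], "big") & mask
        left, right = right, left ^ round_fn
    return (left << half) | right

# Because the network is a bijection, mapping every index yields a
# permutation of the full range -- no shuffle buffer needed.
perm = [feistel_permute(i, num_bits=8, rounds=4, seed=42) for i in range(256)]
```

For dataset sizes that are not a power of two, the standard fix is cycle-walking: re-apply the permutation until the result lands back inside the valid range.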

Ecosystem Integration

Works with Flax, Optax, Orbax, HuggingFace Datasets, TensorFlow Datasets, and ArrayRecord. Built-in benchmarking against 12+ frameworks.

Use Cases

1. High-throughput genomics data processing and quality control
2. Single-cell RNA-seq analysis pipelines with millions of cells
3. Gradient-based augmentation search replacing RL-based methods like AutoAugment
4. Task-optimized preprocessing by backpropagating task loss through processing stages
5. Multi-omics data integration and harmonization
6. Imaging data preprocessing for microscopy and spatial transcriptomics
7. Feature engineering for ML models in drug discovery
8. Distributed training data pipelines across GPUs and TPUs

Installation

# Basic installation
pip install datarax

# With data loading support (HuggingFace, TFDS, audio/image libs)
pip install datarax[data]

# With GPU support (CUDA 12)
pip install datarax[gpu]

# Full development installation
pip install datarax[all]

Quick Start

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx

from datarax import from_source
from datarax.dag.nodes import OperatorNode
from datarax.operators import ElementOperator, ElementOperatorConfig
from datarax.sources import MemorySource, MemorySourceConfig

# Create in-memory data source
data = {
    "image": np.random.randint(0, 255, (1000, 28, 28, 1)).astype(np.float32),
    "label": np.random.randint(0, 10, (1000,)).astype(np.int32),
}
source = MemorySource(MemorySourceConfig(), data=data, rngs=nnx.Rngs(0))

# Build pipeline with DAG-based API. The operator construction below is
# illustrative -- the exact ElementOperator signature may change before release.
normalizer = ElementOperator(
    ElementOperatorConfig(),
    fn=lambda batch: {**batch, "image": batch["image"] / 255.0},
)
augmenter = ElementOperator(
    ElementOperatorConfig(),
    fn=lambda batch: {**batch, "image": jnp.flip(batch["image"], axis=2)},
)
pipeline = (
    from_source(source, batch_size=32)
    >> OperatorNode(normalizer)
    >> OperatorNode(augmenter)
)

# Process batches
for i, batch in enumerate(pipeline):
    print(f"Batch {i}: images {batch['image'].shape}")

Built With

JAX · Flax NNX · Optax · Orbax · Grain · HuggingFace Datasets · TensorFlow Datasets · NumPy · Apache Arrow

Ready to Get Started?

Explore the documentation, try examples, or contribute to the project.