JAX Integration #475

@patrickvonplaten

JAX Integration

This issue will be used as a tracker to integrate Stable Diffusion in JAX natively into diffusers. This will enable many cool use cases, notably running Stable Diffusion on a Google Colab.

General design:

We will loosen the forced PyTorch dependency and instead require the user to install either PyTorch or JAX. Then we will mirror the following "base" classes to be JAX compatible:

ModelMixin: patil-suraj/stable-diffusion-jax#10. We should add a FlaxModelMixin class here.
FlaxDiffusionPipeline: we should add a FlaxDiffusionPipeline alongside the existing class DiffusionPipeline(ConfigMixin).
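
To make the "either PyTorch or JAX" requirement concrete, here is a minimal sketch of how an optional-backend check could look. The helper names `is_available` and `require_any_backend` are hypothetical and only for illustration; they are not taken from diffusers.

```python
import importlib.util


def is_available(package: str) -> bool:
    """Return True if `package` can be found in the current environment."""
    # find_spec locates the package without importing it.
    return importlib.util.find_spec(package) is not None


def require_any_backend(candidates=("torch", "jax")) -> list:
    """Return the installed backends, raising if none of them is present.

    Hypothetical helper: the default candidates mirror the
    "PyTorch or JAX" requirement described above.
    """
    installed = [name for name in candidates if is_available(name)]
    if not installed:
        raise ImportError(
            "Install at least one backend: " + " or ".join(candidates)
        )
    return installed
```

With a check like this, PyTorch-only and Flax-only classes could each be defined or exported only when their backend is installed, so neither framework has to be a hard dependency.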

Note: ModelMixin should be made stateless by default, e.g. weights will not be saved. Also, contrary to transformers, should we maybe only work with flax.linen.Module classes here, @patil-suraj? I don't really think we need the UNetConditionModel and UNetConditionModule design here. We could just go for class UNetConditionModel(nn.Module): and make sure everything stays stateless, no?

TODO:

Happy to take over 1. and finish today and then look into 4. once 3. is done.

@mishig25 do you want to do 2.? (happy to guide you here a bit if you have questions. Also we need to discuss the design here a bit offline maybe)

  1. & 5. @pcuenca do you want to take this? (think 3. is more important here)

The other parts we can see tomorrow maybe :-)
