JAX Integration
This issue will be used to track the native integration of Stable Diffusion in JAX into diffusers. This will enable many cool use cases, notably running Stable Diffusion on Google Colab.
General design:
We will loosen the hard PyTorch dependency and instead require the user to install either PyTorch or JAX. Then we will mirror the following "base" classes to be JAX compatible:
ModelMixin (see patil-suraj/stable-diffusion-jax#10): we should add a FlaxModelMixin class here.
DiffusionPipeline (class DiffusionPipeline(ConfigMixin): in diffusers/src/diffusers/pipeline_utils.py, line 76 in 25a51b6): we should add a FlaxDiffusionPipeline here.
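As a rough illustration of how the "install either PyTorch or JAX" requirement could be checked, here is a minimal sketch that probes for the packages with importlib without actually importing them. The helper names (is_package_available, is_torch_available, is_flax_available, require_backend) and the choice to prefer PyTorch when both are installed are hypothetical, chosen only for this sketch.

```python
# Hypothetical sketch: detect the installed backend without importing it.
import importlib.util


def is_package_available(name: str) -> bool:
    """Return True if the top-level package `name` can be found."""
    return importlib.util.find_spec(name) is not None


def is_torch_available() -> bool:
    return is_package_available("torch")


def is_flax_available() -> bool:
    # The Flax backend needs both jax and flax installed.
    return is_package_available("jax") and is_package_available("flax")


def require_backend() -> str:
    """Return the backend to use, or raise if neither is installed.

    Assumption: PyTorch is preferred when both backends are present.
    """
    if is_torch_available():
        return "pytorch"
    if is_flax_available():
        return "flax"
    raise ImportError(
        "Either PyTorch or JAX/Flax is required: "
        "e.g. `pip install torch` or `pip install jax flax`."
    )
```

Using find_spec keeps the check cheap: neither framework is imported just to find out whether it exists.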
Note: ModelMixin should be made stateless by default, e.g. weights will not be stored in the model. Also, contrary to transformers, should we maybe only work with flax.linen.Module classes here? @patil-suraj I don't really think we need the UNetConditionModel / UNetConditionModule design here. We could just go for class UNetConditionModel(nn.Module): and make sure everything stays stateless, no?
TODO:
1. diffusers framework independence: this will require some general changes to setup.py and our automation tools.
2. FlaxModelMixin: Implement FlaxModelMixin (#493). Here we can take a lot from https://github.com/patil-suraj/stable-diffusion-jax/pull/10/files, but I'm not sure we should follow the transformers design here 1-to-1. Will also ask some Google folks here.
3. unet_2d_condition_flax.py
...
scheduling_pndm_flax.py
FlaxDiffusionPipeline

Happy to take over 1. and finish it today, then look into 4. once 3. is done.
@mishig25 do you want to do 2.? (Happy to guide you here a bit if you have questions. We also need to discuss the design here a bit offline, maybe.)
3. & 5. @pcuenca do you want to take this? (I think 3. is more important here.)
The other parts we can see tomorrow maybe :-)