diff --git a/README.md b/README.md
index 609a9f5..dd582e4 100644
--- a/README.md
+++ b/README.md
@@ -83,6 +83,7 @@ deepForceUpdate(rootInstance);
* Replaces static getters and setters
* Replaces unbound static methods
* Replaces static properties unless they were overwritten by code
+* Merges the initial state of new versions with existing component state
## Known Limitations
diff --git a/package.json b/package.json
index 0a11317..30a5148 100644
--- a/package.json
+++ b/package.json
@@ -66,6 +66,7 @@
"webpack": "1.4.8"
},
"dependencies": {
- "lodash": "^4.6.1"
+ "lodash": "^4.6.1",
+ "shallowequal": "^0.2.2"
}
}
diff --git a/src/createClassProxy.js b/src/createClassProxy.js
index b30f804..6094a39 100644
--- a/src/createClassProxy.js
+++ b/src/createClassProxy.js
@@ -2,6 +2,7 @@ import find from 'lodash/find';
import createPrototypeProxy from './createPrototypeProxy';
import bindAutoBindMethods from './bindAutoBindMethods';
import deleteUnknownAutoBindMethods from './deleteUnknownAutoBindMethods';
+import mergeState from './mergeState';
import supportsProtoAssignment from './supportsProtoAssignment';
const RESERVED_STATICS = [
@@ -173,6 +174,10 @@ function proxyClass(InitialComponent) {
// We might have added new methods that need to be auto-bound
mountedInstances.forEach(bindAutoBindMethods);
mountedInstances.forEach(deleteUnknownAutoBindMethods);
+
+ // Merge the initial state of the next component with
+ // the initial state of the current component
+ mountedInstances.forEach(instance => mergeState(instance, CurrentComponent));
}
};
diff --git a/src/mergeState.js b/src/mergeState.js
new file mode 100644
index 0000000..74c194d
--- /dev/null
+++ b/src/mergeState.js
@@ -0,0 +1,21 @@
+import assign from 'lodash/assign';
+import React, { Component } from 'react';
+import shallowEqual from 'shallowequal';
+
+export default function mergeState(component, NextComponent) {
+ if (component instanceof React.Component) {
+ // Modern components
+ const nextComponentInstance = new NextComponent(component.props);
+ const mergedState = assign({}, nextComponentInstance.state, component.state);
+ if (!shallowEqual(component.state || {}, mergedState)) {
+ component.setState(mergedState);
+ }
+ } else if (component.getInitialState) {
+ // Classic components
+ const mergedState = assign({}, component.getInitialState(), component.state);
+ if (!shallowEqual(component.state || {}, mergedState)) {
+ component.setState(mergedState);
+ }
+ }
+}
+
diff --git a/test/merge-state.js b/test/merge-state.js
new file mode 100644
index 0000000..1606dca
--- /dev/null
+++ b/test/merge-state.js
@@ -0,0 +1,150 @@
+import React, { Component } from 'react';
+import createShallowRenderer from './helpers/createShallowRenderer';
+import expect from 'expect';
+import createProxy from '../src';
+
+const fixtures = {
+ modern: {
+ VersionA: class VersionA extends React.Component {
+ constructor(props) {
+ super(props);
+ }
+
+ render() {
+ return
VersionA
;
+ }
+ },
+
+ VersionB: class VersionB extends React.Component {
+ constructor(props) {
+ super(props);
+
+ this.state = {
+ counter: 1,
+ };
+ }
+
+ render() {
+ return VersionB: {this.state.counter}
;
+ }
+ },
+
+ VersionC: class VersionC extends React.Component {
+ constructor(props) {
+ super(props);
+ this.state = {
+ counter: 1,
+ secondCounter: 1,
+ };
+ }
+
+ render() {
+ return VersionC: {this.state.counter}{this.state.secondCounter}
;
+ }
+ },
+
+ PropsDependent: class PropsDependent extends React.Component {
+ constructor(props) {
+ super(props);
+
+ this.state = { counter: props.defaultCounter };
+ }
+
+ render() {
+ return {this.state.counter}
;
+ }
+ },
+ },
+
+ classic: {
+ ClassicA: React.createClass({
+ render() {
+ return ClassicA
;
+ }
+ }),
+
+ ClassicB: React.createClass({
+ getInitialState() {
+ return { counter: 1 };
+ },
+
+ render() {
+ return ClassicB: {this.state.counter}
;
+ }
+ }),
+ },
+}
+
+describe('merging state', () => {
+ let renderer;
+ beforeEach(() => {
+ renderer = createShallowRenderer();
+ });
+
+ describe('modern', () => {
+ it('should merge initial state', () => {
+ const { VersionA, VersionB } = fixtures.modern;
+ const proxy = createProxy(VersionA);
+ const Proxy = proxy.get();
+ proxy.update(VersionB);
+ renderer.render();
+ expect(renderer.getRenderOutput().props.children).toEqual(['VersionB: ', 1]);
+ });
+
+ it('initializes state based on props', () => {
+ const { VersionA, PropsDependent } = fixtures.modern;
+ const proxy = createProxy(VersionA);
+ const Proxy = proxy.get();
+ renderer.render();
+ proxy.update(PropsDependent);
+ renderer.render();
+ expect(renderer.getRenderOutput().props.children).toEqual(3);
+ });
+
+ it('does not overwrite existing state', () => {
+ const { VersionA, VersionB } = fixtures.modern;
+ const proxy = createProxy(VersionA);
+ const Proxy = proxy.get();
+ const instance = renderer.render();
+ instance.setState({counter: 5});
+ proxy.update(VersionB);
+ renderer.render();
+ expect(renderer.getRenderOutput().props.children).toEqual(['VersionB: ', 5]);
+ });
+
+ it('supports multiple updates', () => {
+ const { VersionA, VersionB, VersionC } = fixtures.modern;
+ const proxy = createProxy(VersionA);
+ const Proxy = proxy.get();
+ const instance = renderer.render();
+ proxy.update(VersionB);
+ renderer.render();
+ proxy.update(VersionC);
+ renderer.render();
+ expect(renderer.getRenderOutput().props.children).toEqual(['VersionC: ', 1, 1]);
+ });
+ });
+
+ describe('classic', () => {
+ it('should merge initial state', () => {
+ const { ClassicA, ClassicB } = fixtures.classic;
+ const proxy = createProxy(ClassicA);
+ const Proxy = proxy.get();
+ proxy.update(ClassicB);
+ renderer.render();
+ expect(renderer.getRenderOutput().props.children).toEqual(['ClassicB: ', 1]);
+ });
+
+ it('does not overwrite existing state', () => {
+ const { ClassicA, ClassicB } = fixtures.classic;
+ const proxy = createProxy(ClassicA);
+ const Proxy = proxy.get();
+ const instance = renderer.render();
+ instance.setState({counter: 5});
+ proxy.update(ClassicB);
+ renderer.render();
+ expect(renderer.getRenderOutput().props.children).toEqual(['ClassicB: ', 5]);
+ });
+ });
+});
+